初始化
This commit is contained in:
104
src/plugin_system/__init__.py
Normal file
104
src/plugin_system/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""
|
||||
MaiBot 插件系统
|
||||
|
||||
提供统一的插件开发和管理框架
|
||||
"""
|
||||
|
||||
# 导出主要的公共接口
|
||||
from .base import (
|
||||
BasePlugin,
|
||||
BaseAction,
|
||||
BaseCommand,
|
||||
BaseTool,
|
||||
ConfigField,
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
ChatMode,
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
PluginInfo,
|
||||
ToolInfo,
|
||||
PythonDependency,
|
||||
BaseEventHandler,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
ToolParamType,
|
||||
)
|
||||
|
||||
# 导入工具模块
|
||||
from .utils import (
|
||||
ManifestValidator,
|
||||
# ManifestGenerator,
|
||||
# validate_plugin_manifest,
|
||||
# generate_plugin_manifest,
|
||||
)
|
||||
|
||||
from .apis import (
|
||||
chat_api,
|
||||
tool_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
register_plugin,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
|
||||
__version__ = "2.0.0"
|
||||
|
||||
__all__ = [
|
||||
# API 模块
|
||||
"chat_api",
|
||||
"tool_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
"generator_api",
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"register_plugin",
|
||||
"get_logger",
|
||||
# 基础类
|
||||
"BasePlugin",
|
||||
"BaseAction",
|
||||
"BaseCommand",
|
||||
"BaseTool",
|
||||
"BaseEventHandler",
|
||||
# 类型定义
|
||||
"ComponentType",
|
||||
"ActionActivationType",
|
||||
"ChatMode",
|
||||
"ComponentInfo",
|
||||
"ActionInfo",
|
||||
"CommandInfo",
|
||||
"PluginInfo",
|
||||
"ToolInfo",
|
||||
"PythonDependency",
|
||||
"EventHandlerInfo",
|
||||
"EventType",
|
||||
"ToolParamType",
|
||||
# 消息
|
||||
"MaiMessages",
|
||||
# 装饰器
|
||||
"register_plugin",
|
||||
"ConfigField",
|
||||
# 工具函数
|
||||
"ManifestValidator",
|
||||
"get_logger",
|
||||
# "ManifestGenerator",
|
||||
# "validate_plugin_manifest",
|
||||
# "generate_plugin_manifest",
|
||||
]
|
||||
41
src/plugin_system/apis/__init__.py
Normal file
41
src/plugin_system/apis/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""
|
||||
插件系统API模块
|
||||
|
||||
提供了插件开发所需的各种API
|
||||
"""
|
||||
|
||||
# 导入所有API模块
|
||||
from src.plugin_system.apis import (
|
||||
chat_api,
|
||||
component_manage_api,
|
||||
config_api,
|
||||
database_api,
|
||||
emoji_api,
|
||||
generator_api,
|
||||
llm_api,
|
||||
message_api,
|
||||
person_api,
|
||||
plugin_manage_api,
|
||||
send_api,
|
||||
tool_api,
|
||||
)
|
||||
from .logging_api import get_logger
|
||||
from .plugin_register_api import register_plugin
|
||||
|
||||
# 导出所有API模块,使它们可以通过 apis.xxx 方式访问
|
||||
__all__ = [
|
||||
"chat_api",
|
||||
"component_manage_api",
|
||||
"config_api",
|
||||
"database_api",
|
||||
"emoji_api",
|
||||
"generator_api",
|
||||
"llm_api",
|
||||
"message_api",
|
||||
"person_api",
|
||||
"plugin_manage_api",
|
||||
"send_api",
|
||||
"get_logger",
|
||||
"register_plugin",
|
||||
"tool_api",
|
||||
]
|
||||
325
src/plugin_system/apis/chat_api.py
Normal file
325
src/plugin_system/apis/chat_api.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
聊天API模块
|
||||
|
||||
专门负责聊天信息的查询和管理,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import chat_api
|
||||
streams = chat_api.get_all_group_streams()
|
||||
chat_type = chat_api.get_stream_type(stream)
|
||||
|
||||
或者:
|
||||
from src.plugin_system.apis.chat_api import ChatManager as chat
|
||||
streams = chat.get_all_group_streams()
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
|
||||
class SpecialTypes(Enum):
|
||||
"""特殊枚举类型"""
|
||||
|
||||
ALL_PLATFORMS = "all_platforms"
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 群聊聊天流列表
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的群聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取群聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
List[ChatStream]: 私聊聊天流列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 platform 不是字符串或 SpecialTypes 枚举类型
|
||||
"""
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info:
|
||||
streams.append(stream)
|
||||
logger.debug(f"[ChatAPI] 获取到 {len(streams)} 个 {platform} 平台的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取私聊流失败: {e}")
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
group_id: 群聊ID
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 group_id 为空字符串
|
||||
TypeError: 如果 group_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||
"""
|
||||
if not isinstance(group_id, str):
|
||||
raise TypeError("group_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not group_id:
|
||||
raise ValueError("group_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
stream.group_info
|
||||
and str(stream.group_info.group_id) == str(group_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到群ID {group_id} 的聊天流")
|
||||
return stream
|
||||
logger.warning(f"[ChatAPI] 未找到群ID {group_id} 的聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 查找群聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: Optional[str] | SpecialTypes = "qq"
|
||||
) -> Optional[ChatStream]: # sourcery skip: remove-unnecessary-cast
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
platform: 平台筛选,默认为"qq", 可以使用 SpecialTypes.ALL_PLATFORMS 获取所有平台的群聊流
|
||||
|
||||
Returns:
|
||||
Optional[ChatStream]: 聊天流对象,如果未找到返回None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果 user_id 为空字符串
|
||||
TypeError: 如果 user_id 不是字符串类型或 platform 不是字符串或 SpecialTypes
|
||||
"""
|
||||
if not isinstance(user_id, str):
|
||||
raise TypeError("user_id 必须是字符串类型")
|
||||
if not isinstance(platform, (str, SpecialTypes)):
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
if not user_id:
|
||||
raise ValueError("user_id 不能为空")
|
||||
try:
|
||||
for _, stream in get_chat_manager().streams.items():
|
||||
if (
|
||||
not stream.group_info
|
||||
and str(stream.user_info.user_id) == str(user_id)
|
||||
and stream.platform == platform
|
||||
):
|
||||
logger.debug(f"[ChatAPI] 找到用户ID {user_id} 的私聊流")
|
||||
return stream
|
||||
logger.warning(f"[ChatAPI] 未找到用户ID {user_id} 的私聊流")
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 查找私聊流失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
"""获取聊天流类型
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
Returns:
|
||||
str: 聊天类型 ("group", "private", "unknown")
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
|
||||
if hasattr(chat_stream, "group_info"):
|
||||
return "group" if chat_stream.group_info else "private"
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
"""获取聊天流详细信息
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
|
||||
Returns:
|
||||
Dict ({str: Any}): 聊天流信息字典
|
||||
|
||||
Raises:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
|
||||
try:
|
||||
info: Dict[str, Any] = {
|
||||
"stream_id": chat_stream.stream_id,
|
||||
"platform": chat_stream.platform,
|
||||
"type": ChatManager.get_stream_type(chat_stream),
|
||||
}
|
||||
|
||||
if chat_stream.group_info:
|
||||
info.update(
|
||||
{
|
||||
"group_id": chat_stream.group_info.group_id,
|
||||
"group_name": getattr(chat_stream.group_info, "group_name", "未知群聊"),
|
||||
}
|
||||
)
|
||||
|
||||
if chat_stream.user_info:
|
||||
info.update(
|
||||
{
|
||||
"user_id": chat_stream.user_info.user_id,
|
||||
"user_name": chat_stream.user_info.user_nickname,
|
||||
}
|
||||
)
|
||||
|
||||
return info
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流信息失败: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
"""获取聊天流统计摘要
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 包含各种统计信息的字典
|
||||
"""
|
||||
try:
|
||||
all_streams = ChatManager.get_all_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
group_streams = ChatManager.get_group_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
private_streams = ChatManager.get_private_streams(SpecialTypes.ALL_PLATFORMS)
|
||||
|
||||
summary = {
|
||||
"total_streams": len(all_streams),
|
||||
"group_streams": len(group_streams),
|
||||
"private_streams": len(private_streams),
|
||||
"qq_streams": len([s for s in all_streams if s.platform == "qq"]),
|
||||
}
|
||||
|
||||
logger.debug(f"[ChatAPI] 聊天流统计: {summary}")
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"[ChatAPI] 获取聊天流统计失败: {e}")
|
||||
return {
|
||||
"total_streams": 0,
|
||||
"group_streams": 0,
|
||||
"private_streams": 0,
|
||||
"qq_streams": 0,
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 模块级别的便捷函数 - 类似 requests.get(), requests.post() 的设计
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: Optional[str] | SpecialTypes = "qq") -> List[ChatStream]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: Optional[str] | SpecialTypes = "qq") -> Optional[ChatStream]:
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> Dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
def get_streams_summary() -> Dict[str, int]:
|
||||
"""获取聊天流统计摘要的便捷函数"""
|
||||
return ChatManager.get_streams_summary()
|
||||
268
src/plugin_system/apis/component_manage_api.py
Normal file
268
src/plugin_system/apis/component_manage_api.py
Normal file
@@ -0,0 +1,268 @@
|
||||
from typing import Optional, Union, Dict
|
||||
from src.plugin_system.base.component_types import (
|
||||
CommandInfo,
|
||||
ActionInfo,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
ToolInfo,
|
||||
)
|
||||
|
||||
|
||||
# === 插件信息查询 ===
|
||||
def get_all_plugin_info() -> Dict[str, PluginInfo]:
|
||||
"""
|
||||
获取所有插件的信息。
|
||||
|
||||
Returns:
|
||||
dict: 包含所有插件信息的字典,键为插件名称,值为 PluginInfo 对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_all_plugins()
|
||||
|
||||
|
||||
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""
|
||||
获取指定插件的信息。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称。
|
||||
|
||||
Returns:
|
||||
PluginInfo: 插件信息对象,如果插件不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_plugin_info(plugin_name)
|
||||
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
component_name: str, component_type: ComponentType
|
||||
) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定组件的信息。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
Returns:
|
||||
Union[CommandInfo, ActionInfo, EventHandlerInfo]: 组件信息对象,如果组件不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_component_info(component_name, component_type) # type: ignore
|
||||
|
||||
|
||||
def get_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定类型的所有组件信息。
|
||||
|
||||
Args:
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
dict: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_components_by_type(component_type) # type: ignore
|
||||
|
||||
|
||||
def get_enabled_components_info_by_type(
|
||||
component_type: ComponentType,
|
||||
) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
|
||||
"""
|
||||
获取指定类型的所有启用的组件信息。
|
||||
|
||||
Args:
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
dict: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_enabled_components_by_type(component_type) # type: ignore
|
||||
|
||||
|
||||
# === Action 查询方法 ===
|
||||
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
|
||||
"""
|
||||
获取指定 Action 的注册信息。
|
||||
|
||||
Args:
|
||||
action_name (str): Action 名称。
|
||||
|
||||
Returns:
|
||||
ActionInfo: Action 信息对象,如果 Action 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_action_info(action_name)
|
||||
|
||||
|
||||
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
|
||||
"""
|
||||
获取指定 Command 的注册信息。
|
||||
|
||||
Args:
|
||||
command_name (str): Command 名称。
|
||||
|
||||
Returns:
|
||||
CommandInfo: Command 信息对象,如果 Command 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_command_info(command_name)
|
||||
|
||||
|
||||
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
|
||||
"""
|
||||
获取指定 Tool 的注册信息。
|
||||
|
||||
Args:
|
||||
tool_name (str): Tool 名称。
|
||||
|
||||
Returns:
|
||||
ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_tool_info(tool_name)
|
||||
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
def get_registered_event_handler_info(
|
||||
event_handler_name: str,
|
||||
) -> Optional[EventHandlerInfo]:
|
||||
"""
|
||||
获取指定 EventHandler 的注册信息。
|
||||
|
||||
Args:
|
||||
event_handler_name (str): EventHandler 名称。
|
||||
|
||||
Returns:
|
||||
EventHandlerInfo: EventHandler 信息对象,如果 EventHandler 不存在则返回 None。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.get_registered_event_handler_info(event_handler_name)
|
||||
|
||||
|
||||
# === 组件管理方法 ===
|
||||
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||
"""
|
||||
全局启用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 启用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return component_registry.enable_component(component_name, component_type)
|
||||
|
||||
|
||||
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
|
||||
"""
|
||||
全局禁用指定组件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
bool: 禁用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
return await component_registry.disable_component(component_name, component_type)
|
||||
|
||||
|
||||
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||
"""
|
||||
局部启用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
stream_id (str): 消息流 ID。
|
||||
|
||||
Returns:
|
||||
bool: 启用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
|
||||
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
|
||||
"""
|
||||
局部禁用指定组件。
|
||||
|
||||
Args:
|
||||
component_name (str): 组件名称。
|
||||
component_type (ComponentType): 组件类型。
|
||||
stream_id (str): 消息流 ID。
|
||||
|
||||
Returns:
|
||||
bool: 禁用成功返回 True,否则返回 False。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
|
||||
|
||||
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
|
||||
"""
|
||||
获取指定消息流中禁用的组件列表。
|
||||
|
||||
Args:
|
||||
stream_id (str): 消息流 ID。
|
||||
component_type (ComponentType): 组件类型。
|
||||
|
||||
Returns:
|
||||
list[str]: 禁用的组件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
return global_announcement_manager.get_disabled_chat_actions(stream_id)
|
||||
case ComponentType.COMMAND:
|
||||
return global_announcement_manager.get_disabled_chat_commands(stream_id)
|
||||
case ComponentType.TOOL:
|
||||
return global_announcement_manager.get_disabled_chat_tools(stream_id)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
|
||||
case _:
|
||||
raise ValueError(f"未知 component type: {component_type}")
|
||||
77
src/plugin_system/apis/config_api.py
Normal file
77
src/plugin_system/apis/config_api.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""配置API模块
|
||||
|
||||
提供了配置读取和用户信息获取等功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import config_api
|
||||
value = config_api.get_global_config("section.key")
|
||||
platform, user_id = await config_api.get_user_id_by_person_name("用户名")
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("config_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 配置访问API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_global_config(key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
安全地从全局配置中获取一个值。
|
||||
插件应使用此方法读取全局配置,以保证只读和隔离性。
|
||||
|
||||
Args:
|
||||
key: 命名空间式配置键名,使用嵌套访问,如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = global_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取全局配置 {key} 失败: {e}")
|
||||
return default
|
||||
|
||||
|
||||
def get_plugin_config(plugin_config: dict, key: str, default: Any = None) -> Any:
|
||||
"""
|
||||
从插件配置中获取值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
plugin_config: 插件配置字典
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key",大小写敏感
|
||||
default: 如果配置不存在时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = plugin_config
|
||||
|
||||
try:
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
elif hasattr(current, k):
|
||||
current = getattr(current, k)
|
||||
else:
|
||||
raise KeyError(f"配置中不存在子空间或键 '{k}'")
|
||||
return current
|
||||
except Exception as e:
|
||||
logger.warning(f"[ConfigAPI] 获取插件配置 {key} 失败: {e}")
|
||||
return default
|
||||
29
src/plugin_system/apis/database_api.py
Normal file
29
src/plugin_system/apis/database_api.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""数据库API模块
|
||||
|
||||
提供数据库操作相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import database_api
|
||||
records = await database_api.db_query(ActionRecords, query_type="get")
|
||||
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
|
||||
|
||||
注意:此模块现在使用SQLAlchemy实现,提供更好的连接管理和错误处理
|
||||
"""
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import (
|
||||
db_query,
|
||||
db_save,
|
||||
db_get,
|
||||
store_action_info,
|
||||
get_model_class,
|
||||
MODEL_MAPPING
|
||||
)
|
||||
|
||||
# 保持向后兼容性
|
||||
__all__ = [
|
||||
'db_query',
|
||||
'db_save',
|
||||
'db_get',
|
||||
'store_action_info',
|
||||
'get_model_class',
|
||||
'MODEL_MAPPING'
|
||||
]
|
||||
268
src/plugin_system/apis/emoji_api.py
Normal file
268
src/plugin_system/apis/emoji_api.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
表情API模块
|
||||
|
||||
提供表情包相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import emoji_api
|
||||
result = await emoji_api.get_by_description("开心")
|
||||
count = emoji_api.get_count()
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
from typing import Optional, Tuple, List
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.utils.utils_image import image_path_to_base64
|
||||
|
||||
logger = get_logger("emoji_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包获取API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据描述选择表情包
|
||||
|
||||
Args:
|
||||
description: 表情包的描述文本,例如"开心"、"难过"、"愤怒"等
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果描述为空字符串
|
||||
TypeError: 如果描述不是字符串类型
|
||||
"""
|
||||
if not description:
|
||||
raise ValueError("描述不能为空")
|
||||
if not isinstance(description, str):
|
||||
raise TypeError("描述必须是字符串类型")
|
||||
try:
|
||||
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
emoji_result = await emoji_manager.get_emoji_for_text(description)
|
||||
|
||||
if not emoji_result:
|
||||
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
|
||||
return None
|
||||
|
||||
emoji_path, emoji_description, matched_emotion = emoji_result
|
||||
emoji_base64 = image_path_to_base64(emoji_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法将表情包文件转换为base64: {emoji_path}")
|
||||
return None
|
||||
|
||||
logger.debug(f"[EmojiAPI] 成功获取表情包: {emoji_description}, 匹配情感: {matched_emotion}")
|
||||
return emoji_base64, emoji_description, matched_emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||
"""随机获取指定数量的表情包
|
||||
|
||||
Args:
|
||||
count: 要获取的表情包数量,默认为1
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
|
||||
|
||||
Raises:
|
||||
TypeError: 如果count不是整数类型
|
||||
ValueError: 如果count为负数
|
||||
"""
|
||||
if not isinstance(count, int):
|
||||
raise TypeError("count 必须是整数类型")
|
||||
if count < 0:
|
||||
raise ValueError("count 不能为负数")
|
||||
if count == 0:
|
||||
logger.warning("[EmojiAPI] count 为0,返回空列表")
|
||||
return []
|
||||
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 随机获取 {count} 个表情包")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
if not all_emojis:
|
||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
||||
return []
|
||||
|
||||
# 过滤有效表情包
|
||||
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
|
||||
if not valid_emojis:
|
||||
logger.warning("[EmojiAPI] 没有有效的表情包")
|
||||
return []
|
||||
|
||||
if len(valid_emojis) < count:
|
||||
logger.warning(
|
||||
f"[EmojiAPI] 有效表情包数量 ({len(valid_emojis)}) 少于请求的数量 ({count}),将返回所有有效表情包"
|
||||
)
|
||||
count = len(valid_emojis)
|
||||
|
||||
# 随机选择
|
||||
selected_emojis = random.sample(valid_emojis, count)
|
||||
|
||||
results = []
|
||||
for selected_emoji in selected_emojis:
|
||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
continue
|
||||
|
||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
||||
|
||||
if not results and count > 0:
|
||||
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
|
||||
return []
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||
"""根据情感标签获取表情包
|
||||
|
||||
Args:
|
||||
emotion: 情感标签,如"happy"、"sad"、"angry"等
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, str, str]]: (base64编码, 表情包描述, 匹配的情感标签) 或 None
|
||||
|
||||
Raises:
|
||||
ValueError: 如果情感标签为空字符串
|
||||
TypeError: 如果情感标签不是字符串类型
|
||||
"""
|
||||
if not emotion:
|
||||
raise ValueError("情感标签不能为空")
|
||||
if not isinstance(emotion, str):
|
||||
raise TypeError("情感标签必须是字符串类型")
|
||||
try:
|
||||
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
||||
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis = emoji_manager.emoji_objects
|
||||
|
||||
# 筛选匹配情感的表情包
|
||||
matching_emojis = []
|
||||
matching_emojis.extend(
|
||||
emoji_obj
|
||||
for emoji_obj in all_emojis
|
||||
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
|
||||
)
|
||||
if not matching_emojis:
|
||||
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
|
||||
return None
|
||||
|
||||
# 随机选择匹配的表情包
|
||||
selected_emoji = random.choice(matching_emojis)
|
||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
||||
|
||||
if not emoji_base64:
|
||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||
return None
|
||||
|
||||
# 记录使用次数
|
||||
emoji_manager.record_usage(selected_emoji.hash)
|
||||
|
||||
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
|
||||
return emoji_base64, selected_emoji.description, emotion
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 根据情感获取表情包失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 表情包信息查询API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_count() -> int:
|
||||
"""获取表情包数量
|
||||
|
||||
Returns:
|
||||
int: 当前可用的表情包数量
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
return emoji_manager.emoji_num
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
def get_info():
|
||||
"""获取表情包系统信息
|
||||
|
||||
Returns:
|
||||
dict: 包含表情包数量、最大数量、可用数量信息
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
return {
|
||||
"current_count": emoji_manager.emoji_num,
|
||||
"max_count": emoji_manager.emoji_num_max,
|
||||
"available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]),
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
|
||||
return {"current_count": 0, "max_count": 0, "available_emojis": 0}
|
||||
|
||||
|
||||
def get_emotions() -> List[str]:
|
||||
"""获取所有可用的情感标签
|
||||
|
||||
Returns:
|
||||
list: 所有表情包的情感标签列表(去重)
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
emotions = set()
|
||||
|
||||
for emoji_obj in emoji_manager.emoji_objects:
|
||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
||||
emotions.update(emoji_obj.emotion)
|
||||
|
||||
return sorted(list(emotions))
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取情感标签失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def get_descriptions() -> List[str]:
|
||||
"""获取所有表情包描述
|
||||
|
||||
Returns:
|
||||
list: 所有可用表情包的描述列表
|
||||
"""
|
||||
try:
|
||||
emoji_manager = get_emoji_manager()
|
||||
descriptions = []
|
||||
|
||||
descriptions.extend(
|
||||
emoji_obj.description
|
||||
for emoji_obj in emoji_manager.emoji_objects
|
||||
if not emoji_obj.is_deleted and emoji_obj.description
|
||||
)
|
||||
return descriptions
|
||||
except Exception as e:
|
||||
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
|
||||
return []
|
||||
280
src/plugin_system/apis/generator_api.py
Normal file
280
src/plugin_system/apis/generator_api.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
回复器API模块
|
||||
|
||||
提供回复器相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import generator_api
|
||||
replyer = generator_api.get_replyer(chat_stream)
|
||||
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
from typing import Tuple, Any, Dict, List, Optional
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.chat.replyer.replyer_manager import replyer_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("generator_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复器获取API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_replyer(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "replyer",
|
||||
) -> Optional[DefaultReplyer]:
|
||||
"""获取回复器对象
|
||||
|
||||
优先使用chat_stream,如果没有则使用chat_id直接查找。
|
||||
使用 ReplyerManager 来管理实例,避免重复创建。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
chat_id: 聊天ID(实际上就是stream_id)
|
||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||
request_type: 请求类型
|
||||
|
||||
Returns:
|
||||
Optional[DefaultReplyer]: 回复器对象,如果获取失败则返回None
|
||||
|
||||
Raises:
|
||||
ValueError: chat_stream 和 chat_id 均为空
|
||||
"""
|
||||
if not chat_id and not chat_stream:
|
||||
raise ValueError("chat_stream 和 chat_id 不可均为空")
|
||||
try:
|
||||
logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}")
|
||||
return replyer_manager.get_replyer(
|
||||
chat_stream=chat_stream,
|
||||
chat_id=chat_id,
|
||||
model_set_with_weight=model_set_with_weight,
|
||||
request_type=request_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 获取回复器时发生意外错误: {e}", exc_info=True)
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 回复生成API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def generate_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
action_data: Optional[Dict[str, Any]] = None,
|
||||
reply_to: str = "",
|
||||
extra_info: str = "",
|
||||
available_actions: Optional[Dict[str, ActionInfo]] = None,
|
||||
enable_tool: bool = False,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
return_prompt: bool = False,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
request_type: str = "generator_api",
|
||||
from_plugin: bool = True,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""生成回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
chat_id: 聊天ID(备用)
|
||||
action_data: 动作数据(向下兼容,包含reply_to和extra_info)
|
||||
reply_to: 回复对象,格式为 "发送者:消息内容"
|
||||
extra_info: 额外信息,用于补充上下文
|
||||
available_actions: 可用动作
|
||||
enable_tool: 是否启用工具调用
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
return_prompt: 是否返回提示词
|
||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||
request_type: 请求类型(可选,记录LLM使用)
|
||||
from_plugin: 是否来自插件
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(
|
||||
chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type
|
||||
)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
|
||||
logger.debug("[GeneratorAPI] 开始生成回复")
|
||||
|
||||
if not reply_to and action_data:
|
||||
reply_to = action_data.get("reply_to", "")
|
||||
if not extra_info and action_data:
|
||||
extra_info = action_data.get("extra_info", "")
|
||||
|
||||
# 调用回复器生成回复
|
||||
success, llm_response_dict, prompt = await replyer.generate_reply_with_context(
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
available_actions=available_actions,
|
||||
enable_tool=enable_tool,
|
||||
from_plugin=from_plugin,
|
||||
stream_id=chat_stream.stream_id if chat_stream else chat_id,
|
||||
)
|
||||
if not success:
|
||||
logger.warning("[GeneratorAPI] 回复生成失败")
|
||||
return False, [], None
|
||||
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
|
||||
if content := llm_response_dict.get("content", ""):
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
else:
|
||||
reply_set = []
|
||||
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
|
||||
|
||||
if return_prompt:
|
||||
return success, reply_set, prompt
|
||||
else:
|
||||
return success, reply_set, None
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except UserWarning as uw:
|
||||
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
|
||||
return False, [], None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False, [], None
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
reply_data: Optional[Dict[str, Any]] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
enable_splitter: bool = True,
|
||||
enable_chinese_typo: bool = True,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
raw_reply: str = "",
|
||||
reason: str = "",
|
||||
reply_to: str = "",
|
||||
return_prompt: bool = False,
|
||||
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
|
||||
"""重写回复
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象(优先)
|
||||
reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取)
|
||||
chat_id: 聊天ID(备用)
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
|
||||
raw_reply: 原始回复内容
|
||||
reason: 回复原因
|
||||
reply_to: 回复对象
|
||||
return_prompt: 是否返回提示词
|
||||
|
||||
Returns:
|
||||
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
|
||||
"""
|
||||
try:
|
||||
# 获取回复器
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return False, [], None
|
||||
|
||||
logger.info("[GeneratorAPI] 开始重写回复")
|
||||
|
||||
# 如果参数缺失,从reply_data中获取
|
||||
if reply_data:
|
||||
raw_reply = raw_reply or reply_data.get("raw_reply", "")
|
||||
reason = reason or reply_data.get("reason", "")
|
||||
reply_to = reply_to or reply_data.get("reply_to", "")
|
||||
|
||||
# 调用回复器重写回复
|
||||
success, content, prompt = await replyer.rewrite_reply_with_context(
|
||||
raw_reply=raw_reply,
|
||||
reason=reason,
|
||||
reply_to=reply_to,
|
||||
return_prompt=return_prompt,
|
||||
)
|
||||
reply_set = []
|
||||
if content:
|
||||
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
if success:
|
||||
logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 重写回复失败")
|
||||
|
||||
return success, reply_set, prompt if return_prompt else None
|
||||
|
||||
except ValueError as ve:
|
||||
raise ve
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
|
||||
return False, [], None
|
||||
|
||||
|
||||
def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
|
||||
"""将文本处理为更拟人化的文本
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
enable_splitter: 是否启用消息分割器
|
||||
enable_chinese_typo: 是否启用错字生成器
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
raise ValueError("content 必须是字符串类型")
|
||||
try:
|
||||
processed_response = process_llm_response(content, enable_splitter, enable_chinese_typo)
|
||||
|
||||
reply_set = []
|
||||
for text in processed_response:
|
||||
reply_seg = ("text", text)
|
||||
reply_set.append(reply_seg)
|
||||
|
||||
return reply_set
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: Optional[ChatStream] = None,
|
||||
chat_id: Optional[str] = None,
|
||||
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
|
||||
prompt: str = "",
|
||||
) -> Optional[str]:
|
||||
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
|
||||
if not replyer:
|
||||
logger.error("[GeneratorAPI] 无法获取回复器")
|
||||
return None
|
||||
|
||||
try:
|
||||
logger.debug("[GeneratorAPI] 开始生成自定义回复")
|
||||
response, _, _, _ = await replyer.llm_generate_content(prompt)
|
||||
if response:
|
||||
logger.debug("[GeneratorAPI] 自定义回复生成成功")
|
||||
return response
|
||||
else:
|
||||
logger.warning("[GeneratorAPI] 自定义回复生成失败")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
|
||||
return None
|
||||
122
src/plugin_system/apis/llm_api.py
Normal file
122
src/plugin_system/apis/llm_api.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""LLM API模块
|
||||
|
||||
提供了与LLM模型交互的功能
|
||||
使用方式:
|
||||
from src.plugin_system.apis import llm_api
|
||||
models = llm_api.get_available_models()
|
||||
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, List, Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config
|
||||
from src.config.api_ada_configs import TaskConfig
|
||||
|
||||
logger = get_logger("llm_api")
|
||||
|
||||
# =============================================================================
|
||||
# LLM模型API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_available_models() -> Dict[str, TaskConfig]:
|
||||
"""获取所有可用的模型配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 模型配置字典,key为模型名称,value为模型配置
|
||||
"""
|
||||
try:
|
||||
# 自动获取所有属性并转换为字典形式
|
||||
models = model_config.model_task_config
|
||||
attrs = dir(models)
|
||||
rets: Dict[str, TaskConfig] = {}
|
||||
for attr in attrs:
|
||||
if not attr.startswith("__"):
|
||||
try:
|
||||
value = getattr(models, attr)
|
||||
if not callable(value) and isinstance(value, TaskConfig):
|
||||
rets[attr] = value
|
||||
except Exception as e:
|
||||
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
|
||||
continue
|
||||
return rets
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[LLMAPI] 获取可用模型失败: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
async def generate_with_model(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str]:
|
||||
"""使用指定模型生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
request_type: 请求类型标识
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
|
||||
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
|
||||
return True, response, reasoning_content, model_name
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", ""
|
||||
|
||||
async def generate_with_model_with_tools(
|
||||
prompt: str,
|
||||
model_config: TaskConfig,
|
||||
tool_options: List[Dict[str, Any]] | None = None,
|
||||
request_type: str = "plugin.generate",
|
||||
temperature: Optional[float] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
|
||||
"""使用指定模型和工具生成内容
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
model_config: 模型配置(从 get_available_models 获取的模型配置)
|
||||
tool_options: 工具选项列表
|
||||
request_type: 请求类型标识
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
|
||||
"""
|
||||
try:
|
||||
model_name_list = model_config.model_list
|
||||
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
|
||||
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
|
||||
|
||||
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
|
||||
|
||||
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
|
||||
prompt,
|
||||
tools=tool_options,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
return True, response, reasoning_content, model_name, tool_call
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"生成内容时出错: {str(e)}"
|
||||
logger.error(f"[LLMAPI] {error_msg}")
|
||||
return False, error_msg, "", "", None
|
||||
3
src/plugin_system/apis/logging_api.py
Normal file
3
src/plugin_system/apis/logging_api.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.common.logger import get_logger
|
||||
|
||||
__all__ = ["get_logger"]
|
||||
483
src/plugin_system/apis/message_api.py
Normal file
483
src/plugin_system/apis/message_api.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
消息API模块
|
||||
|
||||
提供消息查询和构建成字符串的功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import message_api
|
||||
messages = message_api.get_messages_by_time_in_chat(chat_id, start_time, end_time)
|
||||
readable_text = message_api.build_readable_messages(messages)
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_by_timestamp_with_chat_users,
|
||||
get_raw_msg_by_timestamp_random,
|
||||
get_raw_msg_by_timestamp_with_users,
|
||||
get_raw_msg_before_timestamp,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_before_timestamp_with_users,
|
||||
num_new_messages_since,
|
||||
num_new_messages_since_with_users,
|
||||
build_readable_messages,
|
||||
build_readable_messages_with_list,
|
||||
get_person_id_list,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息查询API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_messages_by_time(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
filter_command: 是否过滤命令消息,默认为False
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command))
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_inclusive(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
filter_mai: bool = False,
|
||||
filter_command: bool = False,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间范围内的消息(包含边界)
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳(包含)
|
||||
end_time: 结束时间戳(包含)
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
)
|
||||
return get_raw_msg_by_timestamp_with_chat_inclusive(chat_id, start_time, end_time, limit, limit_mode, filter_command)
|
||||
|
||||
|
||||
def get_messages_by_time_in_chat_for_users(
|
||||
chat_id: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
person_ids: List[str],
|
||||
limit: int = 0,
|
||||
limit_mode: str = "latest",
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定用户在指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_random_chat_messages(
|
||||
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
随机选择一个聊天,返回该聊天在指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_by_time_for_users(
|
||||
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在所有聊天中指定时间范围内的消息
|
||||
|
||||
Args:
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
|
||||
|
||||
|
||||
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
|
||||
return get_raw_msg_before_timestamp(timestamp, limit)
|
||||
|
||||
|
||||
def get_messages_before_time_in_chat(
|
||||
chat_id: str, timestamp: float, limit: int = 0, filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
timestamp: 时间戳
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit))
|
||||
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
|
||||
|
||||
|
||||
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户在指定时间戳之前的消息
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
person_ids: 用户ID列表
|
||||
limit: 限制返回的消息数量,0为不限制
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(timestamp, (int, float)):
|
||||
raise ValueError("timestamp 必须是数字类型")
|
||||
if limit < 0:
|
||||
raise ValueError("limit 不能为负数")
|
||||
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
|
||||
|
||||
|
||||
def get_recent_messages(
|
||||
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定聊天中最近一段时间的消息
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
hours: 最近多少小时,默认24小时
|
||||
limit: 限制返回的消息数量,默认100条
|
||||
limit_mode: 当limit>0时生效,'earliest'表示获取最早的记录,'latest'表示获取最新的记录
|
||||
filter_mai: 是否过滤麦麦自身的消息,默认为False
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: 消息列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法s
|
||||
"""
|
||||
if not isinstance(hours, (int, float)) or hours < 0:
|
||||
raise ValueError("hours 不能是负数")
|
||||
if not isinstance(limit, int) or limit < 0:
|
||||
raise ValueError("limit 必须是非负整数")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
now = time.time()
|
||||
start_time = now - hours * 3600
|
||||
if filter_mai:
|
||||
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
|
||||
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息计数API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
|
||||
"""
|
||||
计算指定聊天中从开始时间到结束时间的新消息数量
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳,如果为None则使用当前时间
|
||||
|
||||
Returns:
|
||||
int: 新消息数量
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)):
|
||||
raise ValueError("start_time 必须是数字类型")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return num_new_messages_since(chat_id, start_time, end_time)
|
||||
|
||||
|
||||
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
|
||||
"""
|
||||
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
person_ids: 用户ID列表
|
||||
|
||||
Returns:
|
||||
int: 新消息数量
|
||||
|
||||
Raises:
|
||||
ValueError: 如果参数不合法
|
||||
"""
|
||||
if not isinstance(start_time, (int, float)) or not isinstance(end_time, (int, float)):
|
||||
raise ValueError("start_time 和 end_time 必须是数字类型")
|
||||
if not chat_id:
|
||||
raise ValueError("chat_id 不能为空")
|
||||
if not isinstance(chat_id, str):
|
||||
raise ValueError("chat_id 必须是字符串类型")
|
||||
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息格式化API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def build_readable_messages_to_str(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
将消息列表构建成可读的字符串
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
||||
merge_messages: 是否合并连续消息
|
||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
||||
read_mark: 已读标记时间戳,用于分割已读和未读消息
|
||||
truncate: 是否截断长消息
|
||||
show_actions: 是否显示动作记录
|
||||
|
||||
Returns:
|
||||
格式化后的可读字符串
|
||||
"""
|
||||
return build_readable_messages(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, read_mark, truncate, show_actions
|
||||
)
|
||||
|
||||
|
||||
async def build_readable_messages_with_details(
|
||||
messages: List[Dict[str, Any]],
|
||||
replace_bot_name: bool = True,
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
"""
|
||||
将消息列表构建成可读的字符串,并返回详细信息
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
replace_bot_name: 是否将机器人的名称替换为"你"
|
||||
merge_messages: 是否合并连续消息
|
||||
timestamp_mode: 时间戳显示模式,'relative'或'absolute'
|
||||
truncate: 是否截断长消息
|
||||
|
||||
Returns:
|
||||
格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
|
||||
"""
|
||||
return await build_readable_messages_with_list(messages, replace_bot_name, merge_messages, timestamp_mode, truncate)
|
||||
|
||||
|
||||
async def get_person_ids_from_messages(messages: List[Dict[str, Any]]) -> List[str]:
|
||||
"""
|
||||
从消息列表中提取不重复的用户ID列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
用户ID列表
|
||||
"""
|
||||
return await get_person_id_list(messages)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 消息过滤函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def filter_mai_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从消息列表中移除麦麦的消息
|
||||
Args:
|
||||
messages: 消息列表,每个元素是消息字典
|
||||
Returns:
|
||||
过滤后的消息列表
|
||||
"""
|
||||
return [msg for msg in messages if msg.get("user_id") != str(global_config.bot.qq_account)]
|
||||
154
src/plugin_system/apis/person_api.py
Normal file
154
src/plugin_system/apis/person_api.py
Normal file
@@ -0,0 +1,154 @@
|
||||
"""个人信息API模块
|
||||
|
||||
提供个人信息查询功能,用于插件获取用户相关信息
|
||||
使用方式:
|
||||
from src.plugin_system.apis import person_api
|
||||
person_id = person_api.get_person_id("qq", 123456)
|
||||
value = await person_api.get_person_value(person_id, "nickname")
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import get_person_info_manager, PersonInfoManager
|
||||
|
||||
logger = get_logger("person_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 个人信息API函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: int) -> str:
|
||||
"""根据平台和用户ID获取person_id
|
||||
|
||||
Args:
|
||||
platform: 平台名称,如 "qq", "telegram" 等
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
str: 唯一的person_id(MD5哈希值)
|
||||
|
||||
示例:
|
||||
person_id = person_api.get_person_id("qq", 123456)
|
||||
"""
|
||||
try:
|
||||
return PersonInfoManager.get_person_id(platform, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 获取person_id失败: platform={platform}, user_id={user_id}, error={e}")
|
||||
return ""
|
||||
|
||||
|
||||
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
|
||||
"""根据person_id和字段名获取某个值
|
||||
|
||||
Args:
|
||||
person_id: 用户的唯一标识ID
|
||||
field_name: 要获取的字段名,如 "nickname", "impression" 等
|
||||
default: 当字段不存在或获取失败时返回的默认值
|
||||
|
||||
Returns:
|
||||
Any: 字段值或默认值
|
||||
|
||||
示例:
|
||||
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
|
||||
impression = await person_api.get_person_value(person_id, "impression")
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
value = await person_info_manager.get_value(person_id, field_name)
|
||||
return value if value is not None else default
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 获取用户信息失败: person_id={person_id}, field={field_name}, error={e}")
|
||||
return default
|
||||
|
||||
|
||||
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
|
||||
"""批量获取用户信息字段值
|
||||
|
||||
Args:
|
||||
person_id: 用户的唯一标识ID
|
||||
field_names: 要获取的字段名列表
|
||||
default_dict: 默认值字典,键为字段名,值为默认值
|
||||
|
||||
Returns:
|
||||
dict: 字段名到值的映射字典
|
||||
|
||||
示例:
|
||||
values = await person_api.get_person_values(
|
||||
person_id,
|
||||
["nickname", "impression", "know_times"],
|
||||
{"nickname": "未知用户", "know_times": 0}
|
||||
)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
values = await person_info_manager.get_values(person_id, field_names)
|
||||
|
||||
# 如果获取成功,返回结果
|
||||
if values:
|
||||
return values
|
||||
|
||||
# 如果获取失败,构建默认值字典
|
||||
result = {}
|
||||
if default_dict:
|
||||
for field in field_names:
|
||||
result[field] = default_dict.get(field, None)
|
||||
else:
|
||||
for field in field_names:
|
||||
result[field] = None
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 批量获取用户信息失败: person_id={person_id}, fields={field_names}, error={e}")
|
||||
# 返回默认值字典
|
||||
result = {}
|
||||
if default_dict:
|
||||
for field in field_names:
|
||||
result[field] = default_dict.get(field, None)
|
||||
else:
|
||||
for field in field_names:
|
||||
result[field] = None
|
||||
return result
|
||||
|
||||
|
||||
async def is_person_known(platform: str, user_id: int) -> bool:
|
||||
"""判断是否认识某个用户
|
||||
|
||||
Args:
|
||||
platform: 平台名称
|
||||
user_id: 用户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否认识该用户
|
||||
|
||||
示例:
|
||||
known = await person_api.is_person_known("qq", 123456)
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return await person_info_manager.is_person_known(platform, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 检查用户是否已知失败: platform={platform}, user_id={user_id}, error={e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_person_id_by_name(person_name: str) -> str:
|
||||
"""根据用户名获取person_id
|
||||
|
||||
Args:
|
||||
person_name: 用户名
|
||||
|
||||
Returns:
|
||||
str: person_id,如果未找到返回空字符串
|
||||
|
||||
示例:
|
||||
person_id = person_api.get_person_id_by_name("张三")
|
||||
"""
|
||||
try:
|
||||
person_info_manager = get_person_info_manager()
|
||||
return person_info_manager.get_person_id_by_person_name(person_name)
|
||||
except Exception as e:
|
||||
logger.error(f"[PersonAPI] 根据用户名获取person_id失败: person_name={person_name}, error={e}")
|
||||
return ""
|
||||
120
src/plugin_system/apis/plugin_manage_api.py
Normal file
120
src/plugin_system/apis/plugin_manage_api.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import Tuple, List
|
||||
|
||||
|
||||
def list_loaded_plugins() -> List[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
Returns:
|
||||
List[str]: 当前加载的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.list_loaded_plugins()
|
||||
|
||||
|
||||
def list_registered_plugins() -> List[str]:
|
||||
"""
|
||||
列出所有已注册的插件。
|
||||
|
||||
Returns:
|
||||
List[str]: 已注册的插件名称列表。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.list_registered_plugins()
|
||||
|
||||
|
||||
def get_plugin_path(plugin_name: str) -> str:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称。
|
||||
|
||||
Returns:
|
||||
str: 插件目录的绝对路径。
|
||||
|
||||
Raises:
|
||||
ValueError: 如果插件不存在。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
|
||||
return plugin_path
|
||||
else:
|
||||
raise ValueError(f"插件 '{plugin_name}' 不存在。")
|
||||
|
||||
|
||||
async def remove_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
卸载指定的插件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要卸载的插件名称。
|
||||
|
||||
Returns:
|
||||
bool: 卸载是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return await plugin_manager.remove_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
async def reload_plugin(plugin_name: str) -> bool:
|
||||
"""
|
||||
重新加载指定的插件。
|
||||
|
||||
**此函数是异步的,确保在异步环境中调用。**
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要重新加载的插件名称。
|
||||
|
||||
Returns:
|
||||
bool: 重新加载是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return await plugin_manager.reload_registered_plugin(plugin_name)
|
||||
|
||||
|
||||
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
|
||||
"""
|
||||
加载指定的插件。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要加载的插件名称。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, int]: 加载是否成功,成功或失败个数。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.load_registered_plugin_classes(plugin_name)
|
||||
|
||||
|
||||
def add_plugin_directory(plugin_directory: str) -> bool:
|
||||
"""
|
||||
添加插件目录。
|
||||
|
||||
Args:
|
||||
plugin_directory (str): 要添加的插件目录路径。
|
||||
Returns:
|
||||
bool: 添加是否成功。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.add_plugin_directory(plugin_directory)
|
||||
|
||||
|
||||
def rescan_plugin_directory() -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件目录,加载新插件。
|
||||
Returns:
|
||||
Tuple[int, int]: 成功加载的插件数量和失败的插件数量。
|
||||
"""
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
return plugin_manager.rescan_plugin_directory()
|
||||
46
src/plugin_system/apis/plugin_register_api.py
Normal file
46
src/plugin_system/apis/plugin_register_api.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("plugin_manager") # 复用plugin_manager名称
|
||||
|
||||
|
||||
def register_plugin(cls):
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
|
||||
"""插件注册装饰器
|
||||
|
||||
用法:
|
||||
@register_plugin
|
||||
class MyPlugin(BasePlugin):
|
||||
plugin_name = "my_plugin"
|
||||
plugin_description = "我的插件"
|
||||
...
|
||||
"""
|
||||
if not issubclass(cls, BasePlugin):
|
||||
logger.error(f"类 {cls.__name__} 不是 BasePlugin 的子类")
|
||||
return cls
|
||||
|
||||
# 只是注册插件类,不立即实例化
|
||||
# 插件管理器会负责实例化和注册
|
||||
plugin_name: str = cls.plugin_name # type: ignore
|
||||
if "." in plugin_name:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
splitted_name = cls.__module__.split(".")
|
||||
root_path = Path(__file__)
|
||||
|
||||
# 查找项目根目录
|
||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
||||
root_path = root_path.parent
|
||||
|
||||
if not (root_path / "pyproject.toml").exists():
|
||||
logger.error(f"注册 {plugin_name} 无法找到项目根目录")
|
||||
return cls
|
||||
|
||||
plugin_manager.plugin_classes[plugin_name] = cls
|
||||
plugin_manager.plugin_paths[plugin_name] = str(Path(root_path, *splitted_name).resolve())
|
||||
logger.debug(f"插件类已注册: {plugin_name}, 路径: {plugin_manager.plugin_paths[plugin_name]}")
|
||||
|
||||
return cls
|
||||
369
src/plugin_system/apis/send_api.py
Normal file
369
src/plugin_system/apis/send_api.py
Normal file
@@ -0,0 +1,369 @@
|
||||
"""
|
||||
发送API模块
|
||||
|
||||
专门负责发送各种类型的消息,采用标准Python包设计模式
|
||||
|
||||
使用方式:
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
# 方式1:直接使用stream_id(推荐)
|
||||
await send_api.text_to_stream("hello", stream_id)
|
||||
await send_api.emoji_to_stream(emoji_base64, stream_id)
|
||||
await send_api.custom_to_stream("video", video_data, stream_id)
|
||||
|
||||
# 方式2:使用群聊/私聊指定函数
|
||||
await send_api.text_to_group("hello", "123456")
|
||||
await send_api.text_to_user("hello", "987654")
|
||||
|
||||
# 方式3:使用通用custom_message函数
|
||||
await send_api.custom_message("video", video_data, "123456", True)
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import difflib
|
||||
from typing import Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入依赖
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, replace_user_references_async
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from maim_message import Seg, UserInfo
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger("send_api")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 内部实现函数(不暴露给外部)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def _send_to_target(
|
||||
message_type: str,
|
||||
content: Union[str, dict],
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_platform_id: Optional[str] = None,
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定目标发送消息的内部实现
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
stream_id: 目标流ID
|
||||
display_message: 显示消息
|
||||
typing: 是否模拟打字等待。
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
|
||||
storage_message: 是否存储消息到数据库
|
||||
show_log: 发送是否显示日志
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
if show_log:
|
||||
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
|
||||
|
||||
# 查找目标聊天流
|
||||
target_stream = get_chat_manager().get_stream(stream_id)
|
||||
if not target_stream:
|
||||
logger.error(f"[SendAPI] 未找到聊天流: {stream_id}")
|
||||
return False
|
||||
|
||||
# 创建发送器
|
||||
heart_fc_sender = HeartFCSender()
|
||||
|
||||
# 生成消息ID
|
||||
current_time = time.time()
|
||||
message_id = f"send_api_{int(current_time * 1000)}"
|
||||
|
||||
# 构建机器人用户信息
|
||||
bot_user_info = UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=target_stream.platform,
|
||||
)
|
||||
|
||||
# 创建消息段
|
||||
message_segment = Seg(type=message_type, data=content) # type: ignore
|
||||
|
||||
# 处理回复消息
|
||||
anchor_message = None
|
||||
if reply_to:
|
||||
anchor_message = await _find_reply_message(target_stream, reply_to)
|
||||
if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id:
|
||||
reply_to_platform_id = (
|
||||
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
|
||||
)
|
||||
|
||||
# 构建发送消息对象
|
||||
bot_message = MessageSending(
|
||||
message_id=message_id,
|
||||
chat_stream=target_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
sender_info=target_stream.user_info,
|
||||
message_segment=message_segment,
|
||||
display_message=display_message,
|
||||
reply=anchor_message,
|
||||
is_head=True,
|
||||
is_emoji=(message_type == "emoji"),
|
||||
thinking_start_time=current_time,
|
||||
reply_to=reply_to_platform_id,
|
||||
)
|
||||
|
||||
# 发送消息
|
||||
sent_msg = await heart_fc_sender.send_message(
|
||||
bot_message,
|
||||
typing=typing,
|
||||
set_reply=(anchor_message is not None),
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
|
||||
if sent_msg:
|
||||
logger.debug(f"[SendAPI] 成功发送消息到 {stream_id}")
|
||||
return True
|
||||
else:
|
||||
logger.error("[SendAPI] 发送消息失败")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 发送消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageRecv]:
|
||||
# sourcery skip: inline-variable, use-named-expression
|
||||
"""查找要回复的消息
|
||||
|
||||
Args:
|
||||
target_stream: 目标聊天流
|
||||
reply_to: 回复格式,如"发送者:消息内容"或"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
Optional[MessageRecv]: 找到的消息,如果没找到则返回None
|
||||
"""
|
||||
try:
|
||||
# 解析reply_to参数
|
||||
if ":" in reply_to:
|
||||
parts = reply_to.split(":", 1)
|
||||
elif ":" in reply_to:
|
||||
parts = reply_to.split(":", 1)
|
||||
else:
|
||||
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
|
||||
return None
|
||||
|
||||
if len(parts) != 2:
|
||||
logger.warning(f"[SendAPI] reply_to格式不正确: {reply_to}")
|
||||
return None
|
||||
|
||||
sender = parts[0].strip()
|
||||
text = parts[1].strip()
|
||||
|
||||
# 获取聊天流的最新20条消息
|
||||
reverse_talking_message = get_raw_msg_before_timestamp_with_chat(
|
||||
target_stream.stream_id,
|
||||
time.time(), # 当前时间之前的消息
|
||||
20, # 最新的20条消息
|
||||
)
|
||||
|
||||
# 反转列表,使最新的消息在前面
|
||||
reverse_talking_message = list(reversed(reverse_talking_message))
|
||||
|
||||
find_msg = None
|
||||
for message in reverse_talking_message:
|
||||
user_id = message["user_id"]
|
||||
platform = message["chat_info_platform"]
|
||||
person_id = get_person_info_manager().get_person_id(platform, user_id)
|
||||
person_name = await get_person_info_manager().get_value(person_id, "person_name")
|
||||
if person_name == sender:
|
||||
translate_text = message["processed_plain_text"]
|
||||
|
||||
# 使用独立函数处理用户引用格式
|
||||
translate_text = await replace_user_references_async(translate_text, platform)
|
||||
|
||||
similarity = difflib.SequenceMatcher(None, text, translate_text).ratio()
|
||||
if similarity >= 0.9:
|
||||
find_msg = message
|
||||
break
|
||||
|
||||
if not find_msg:
|
||||
logger.info("[SendAPI] 未找到匹配的回复消息")
|
||||
return None
|
||||
|
||||
# 构建MessageRecv对象
|
||||
user_info = {
|
||||
"platform": find_msg.get("user_platform", ""),
|
||||
"user_id": find_msg.get("user_id", ""),
|
||||
"user_nickname": find_msg.get("user_nickname", ""),
|
||||
"user_cardname": find_msg.get("user_cardname", ""),
|
||||
}
|
||||
|
||||
group_info = {}
|
||||
if find_msg.get("chat_info_group_id"):
|
||||
group_info = {
|
||||
"platform": find_msg.get("chat_info_group_platform", ""),
|
||||
"group_id": find_msg.get("chat_info_group_id", ""),
|
||||
"group_name": find_msg.get("chat_info_group_name", ""),
|
||||
}
|
||||
|
||||
format_info = {"content_format": "", "accept_format": ""}
|
||||
template_info = {"template_items": {}}
|
||||
|
||||
message_info = {
|
||||
"platform": target_stream.platform,
|
||||
"message_id": find_msg.get("message_id"),
|
||||
"time": find_msg.get("time"),
|
||||
"group_info": group_info,
|
||||
"user_info": user_info,
|
||||
"additional_config": find_msg.get("additional_config"),
|
||||
"format_info": format_info,
|
||||
"template_info": template_info,
|
||||
}
|
||||
|
||||
message_dict = {
|
||||
"message_info": message_info,
|
||||
"raw_message": find_msg.get("processed_plain_text"),
|
||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||
}
|
||||
|
||||
find_rec_msg = MessageRecv(message_dict)
|
||||
find_rec_msg.update_chat_stream(target_stream)
|
||||
|
||||
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {sender}")
|
||||
return find_rec_msg
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SendAPI] 查找回复消息时出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 公共API函数 - 预定义类型的发送函数
|
||||
# =============================================================================
|
||||
|
||||
|
||||
async def text_to_stream(
|
||||
text: str,
|
||||
stream_id: str,
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
reply_to_platform_id: str = "",
|
||||
storage_message: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送文本消息
|
||||
|
||||
Args:
|
||||
text: 要发送的文本内容
|
||||
stream_id: 聊天流ID
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"text",
|
||||
text,
|
||||
stream_id,
|
||||
"",
|
||||
typing,
|
||||
reply_to,
|
||||
reply_to_platform_id=reply_to_platform_id,
|
||||
storage_message=storage_message,
|
||||
)
|
||||
|
||||
|
||||
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
|
||||
"""向指定流发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
stream_id: 聊天流ID
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool:
|
||||
"""向指定流发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
stream_id: 聊天流ID
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
|
||||
|
||||
|
||||
async def command_to_stream(
|
||||
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = ""
|
||||
) -> bool:
|
||||
"""向指定流发送命令
|
||||
|
||||
Args:
|
||||
command: 命令
|
||||
stream_id: 聊天流ID
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
"command", command, stream_id, display_message, typing=False, storage_message=storage_message
|
||||
)
|
||||
|
||||
|
||||
async def custom_to_stream(
|
||||
message_type: str,
|
||||
content: str | dict,
|
||||
stream_id: str,
|
||||
display_message: str = "",
|
||||
typing: bool = False,
|
||||
reply_to: str = "",
|
||||
storage_message: bool = True,
|
||||
show_log: bool = True,
|
||||
) -> bool:
|
||||
"""向指定流发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
|
||||
content: 消息内容(通常是base64编码或文本)
|
||||
stream_id: 聊天流ID
|
||||
display_message: 显示消息
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
storage_message: 是否存储消息到数据库
|
||||
show_log: 是否显示日志
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
return await _send_to_target(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
storage_message=storage_message,
|
||||
show_log=show_log,
|
||||
)
|
||||
34
src/plugin_system/apis/tool_api.py
Normal file
34
src/plugin_system/apis/tool_api.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from typing import Optional, Type
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_api")
|
||||
|
||||
|
||||
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
|
||||
"""获取公开工具实例"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
# 获取插件配置
|
||||
tool_info = component_registry.get_component_info(tool_name, ComponentType.TOOL)
|
||||
if tool_info:
|
||||
plugin_config = component_registry.get_plugin_config(tool_info.plugin_name)
|
||||
else:
|
||||
plugin_config = None
|
||||
|
||||
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||
return tool_class(plugin_config) if tool_class else None
|
||||
|
||||
|
||||
def get_llm_available_tool_definitions():
|
||||
"""获取LLM可用的工具定义列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
|
||||
"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
llm_available_tools = component_registry.get_llm_available_tools()
|
||||
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||
49
src/plugin_system/base/__init__.py
Normal file
49
src/plugin_system/base/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
插件基础类模块
|
||||
|
||||
提供插件开发的基础类和类型定义
|
||||
"""
|
||||
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_action import BaseAction
|
||||
from .base_tool import BaseTool
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .component_types import (
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
ChatMode,
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
ToolInfo,
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
ToolParamType,
|
||||
)
|
||||
from .config_types import ConfigField
|
||||
|
||||
__all__ = [
|
||||
"BasePlugin",
|
||||
"BaseAction",
|
||||
"BaseCommand",
|
||||
"BaseTool",
|
||||
"ComponentType",
|
||||
"ActionActivationType",
|
||||
"ChatMode",
|
||||
"ComponentInfo",
|
||||
"ActionInfo",
|
||||
"CommandInfo",
|
||||
"ToolInfo",
|
||||
"PluginInfo",
|
||||
"PythonDependency",
|
||||
"ConfigField",
|
||||
"EventHandlerInfo",
|
||||
"EventType",
|
||||
"BaseEventHandler",
|
||||
"MaiMessages",
|
||||
"ToolParamType",
|
||||
]
|
||||
437
src/plugin_system/base/base_action.py
Normal file
437
src/plugin_system/base/base_action.py
Normal file
@@ -0,0 +1,437 @@
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
||||
from src.plugin_system.apis import send_api, database_api, message_api
|
||||
|
||||
|
||||
logger = get_logger("base_action")
|
||||
|
||||
|
||||
class BaseAction(ABC):
|
||||
"""Action组件基类
|
||||
|
||||
Action是插件的一种组件类型,用于处理聊天中的动作逻辑
|
||||
|
||||
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
|
||||
- focus_activation_type: 专注模式激活类型
|
||||
- normal_activation_type: 普通模式激活类型
|
||||
- activation_keywords: 激活关键词列表
|
||||
- keyword_case_sensitive: 关键词是否区分大小写
|
||||
- mode_enable: 启用的聊天模式
|
||||
- parallel_action: 是否允许并行执行
|
||||
- random_activation_probability: 随机激活概率
|
||||
- llm_judge_prompt: LLM判断提示词
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str = "",
|
||||
plugin_config: Optional[dict] = None,
|
||||
action_message: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
|
||||
"""初始化Action组件
|
||||
|
||||
Args:
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
chat_stream: 聊天流对象
|
||||
log_prefix: 日志前缀
|
||||
plugin_config: 插件配置字典
|
||||
action_message: 消息数据
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
if plugin_config is None:
|
||||
plugin_config = {}
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
self.plugin_config = plugin_config or {}
|
||||
"""对应的插件配置"""
|
||||
|
||||
# 设置动作基本信息实例属性
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
"""Action的名字"""
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
"""Action的描述"""
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
"""FOCUS模式下的激活类型"""
|
||||
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
"""NORMAL模式下的激活类型"""
|
||||
self.activation_type = getattr(self.__class__, "activation_type", self.focus_activation_type)
|
||||
"""激活类型"""
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
"""当激活类型为RANDOM时的概率"""
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
"""协助LLM进行判断的Prompt"""
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||
# =============================================================================
|
||||
|
||||
# 获取聊天流对象
|
||||
self.chat_stream = chat_stream or kwargs.get("chat_stream")
|
||||
self.chat_id = self.chat_stream.stream_id
|
||||
self.platform = getattr(self.chat_stream, "platform", None)
|
||||
|
||||
# 初始化基础信息(带类型注解)
|
||||
self.action_message = action_message
|
||||
|
||||
self.group_id = None
|
||||
self.group_name = None
|
||||
self.user_id = None
|
||||
self.user_nickname = None
|
||||
self.is_group = False
|
||||
self.target_id = None
|
||||
self.has_action_message = False
|
||||
|
||||
if self.action_message:
|
||||
self.has_action_message = True
|
||||
else:
|
||||
self.action_message = {}
|
||||
|
||||
if self.has_action_message:
|
||||
if self.action_name != "no_reply":
|
||||
self.group_id = str(self.action_message.get("chat_info_group_id", None))
|
||||
self.group_name = self.action_message.get("chat_info_group_name", None)
|
||||
|
||||
self.user_id = str(self.action_message.get("user_id", None))
|
||||
self.user_nickname = self.action_message.get("user_nickname", None)
|
||||
if self.group_id:
|
||||
self.is_group = True
|
||||
self.target_id = self.group_id
|
||||
else:
|
||||
self.is_group = False
|
||||
self.target_id = self.user_id
|
||||
else:
|
||||
if self.chat_stream.group_info:
|
||||
self.group_id = self.chat_stream.group_info.group_id
|
||||
self.group_name = self.chat_stream.group_info.group_name
|
||||
self.is_group = True
|
||||
self.target_id = self.group_id
|
||||
else:
|
||||
self.user_id = self.chat_stream.user_info.user_id
|
||||
self.user_nickname = self.chat_stream.user_info.user_nickname
|
||||
self.is_group = False
|
||||
self.target_id = self.user_id
|
||||
|
||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
)
|
||||
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||
"""等待新消息或超时
|
||||
|
||||
在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。
|
||||
使用message_api检查self.chat_id对应的聊天中是否有新消息。
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒),默认1200秒
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
||||
"""
|
||||
try:
|
||||
# 获取循环开始时间,如果没有则使用当前时间
|
||||
loop_start_time = self.action_data.get("loop_start_time", time.time())
|
||||
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
|
||||
|
||||
# 确保有有效的chat_id
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
|
||||
return False, "没有有效的chat_id"
|
||||
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
# 检查关闭标志
|
||||
# shutting_down = self.get_action_context("shutting_down", False)
|
||||
# if shutting_down:
|
||||
# logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
|
||||
# return False, ""
|
||||
|
||||
# 检查新消息
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
|
||||
)
|
||||
|
||||
if new_message_count > 0:
|
||||
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息,聊天ID: {self.chat_id}")
|
||||
return True, ""
|
||||
|
||||
# 检查超时
|
||||
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
|
||||
if elapsed_time > timeout:
|
||||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒),聊天ID: {self.chat_id}")
|
||||
return False, ""
|
||||
|
||||
# 每30秒记录一次等待状态
|
||||
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
|
||||
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
|
||||
|
||||
# 短暂休眠
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
|
||||
async def send_text(
|
||||
self, content: str, reply_to: str = "", typing: bool = False
|
||||
) -> bool:
|
||||
"""发送文本消息
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.text_to_stream(
|
||||
text=content,
|
||||
stream_id=self.chat_id,
|
||||
reply_to=reply_to,
|
||||
typing=typing,
|
||||
)
|
||||
|
||||
async def send_emoji(self, emoji_base64: str) -> bool:
|
||||
"""发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.emoji_to_stream(emoji_base64, self.chat_id)
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.image_to_stream(image_base64, self.chat_id)
|
||||
|
||||
async def send_custom(self, message_type: str, content: str, typing: bool = False, reply_to: str = "") -> bool:
|
||||
"""发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"video"、"file"、"audio"等
|
||||
content: 消息内容
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.custom_to_stream(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=self.chat_id,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
|
||||
async def store_action_info(
|
||||
self,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
) -> None:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
Args:
|
||||
action_build_into_prompt: 是否构建到提示中
|
||||
action_prompt_display: 显示的action提示信息
|
||||
action_done: action是否完成
|
||||
"""
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=action_build_into_prompt,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=action_done,
|
||||
thinking_id=self.thinking_id,
|
||||
action_data=self.action_data,
|
||||
action_name=self.action_name,
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
使用stream API发送命令
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
args: 命令参数
|
||||
display_message: 显示消息
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
success = await send_api.command_to_stream(
|
||||
command=command_data,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=storage_message,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_action_info(cls) -> "ActionInfo":
|
||||
"""从类属性生成ActionInfo
|
||||
|
||||
所有信息都从类属性中读取,确保一致性和完整性。
|
||||
Action类必须定义所有必要的类属性。
|
||||
|
||||
Returns:
|
||||
ActionInfo: 生成的Action信息对象
|
||||
"""
|
||||
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
|
||||
if "." in name:
|
||||
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
# 获取focus_activation_type和normal_activation_type
|
||||
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
# 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type
|
||||
activation_type = getattr(cls, "activation_type", focus_activation_type)
|
||||
|
||||
return ActionInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=getattr(cls, "action_description", "Action动作"),
|
||||
focus_activation_type=focus_activation_type,
|
||||
normal_activation_type=normal_activation_type,
|
||||
activation_type=activation_type,
|
||||
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
|
||||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
||||
mode_enable=getattr(cls, "mode_enable", ChatMode.ALL),
|
||||
parallel_action=getattr(cls, "parallel_action", True),
|
||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
||||
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
|
||||
# 使用正确的字段名
|
||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||||
action_require=getattr(cls, "action_require", []).copy(),
|
||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行Action的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""兼容旧系统的handle_action接口,委托给execute方法
|
||||
|
||||
为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。
|
||||
此方法将调用委托给新的execute方法。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
return await self.execute()
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
228
src/plugin_system/base/base_command.py
Normal file
228
src/plugin_system/base/base_command.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("base_command")
|
||||
|
||||
|
||||
class BaseCommand(ABC):
|
||||
"""Command组件基类
|
||||
|
||||
Command是插件的一种组件类型,用于处理命令请求
|
||||
|
||||
子类可以通过类属性定义命令模式:
|
||||
- command_pattern: 命令匹配的正则表达式
|
||||
- command_help: 命令帮助信息
|
||||
- command_examples: 命令使用示例列表
|
||||
"""
|
||||
|
||||
command_name: str = ""
|
||||
"""Command组件的名称"""
|
||||
command_description: str = ""
|
||||
"""Command组件的描述"""
|
||||
# 默认命令设置
|
||||
command_pattern: str = r""
|
||||
"""命令匹配的正则表达式"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化Command组件
|
||||
|
||||
Args:
|
||||
message: 接收到的消息对象
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
self.message = message
|
||||
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
self.log_prefix = "[Command]"
|
||||
|
||||
logger.debug(f"{self.log_prefix} Command组件初始化完成")
|
||||
|
||||
def set_matched_groups(self, groups: Dict[str, str]) -> None:
|
||||
"""设置正则表达式匹配的命名组
|
||||
|
||||
Args:
|
||||
groups: 正则表达式匹配的命名组
|
||||
"""
|
||||
self.matched_groups = groups
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行Command的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], bool]: (是否执行成功, 可选的回复消息, 是否拦截消息 不进行 后续处理)
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
async def send_text(self, content: str, reply_to: str = "") -> bool:
|
||||
"""发送回复消息
|
||||
|
||||
Args:
|
||||
content: 回复内容
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
|
||||
|
||||
async def send_type(
|
||||
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
||||
) -> bool:
|
||||
"""发送指定类型的回复消息到当前聊天环境
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
display_message: 显示消息(可选)
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.custom_to_stream(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=chat_stream.stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
args: 命令参数
|
||||
display_message: 显示消息
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
success = await send_api.command_to_stream(
|
||||
command=command_data,
|
||||
stream_id=chat_stream.stream_id,
|
||||
storage_message=storage_message,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
return False
|
||||
|
||||
async def send_emoji(self, emoji_base64: str) -> bool:
|
||||
"""发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
|
||||
|
||||
@classmethod
|
||||
def get_command_info(cls) -> "CommandInfo":
|
||||
"""从类属性生成CommandInfo
|
||||
|
||||
Args:
|
||||
name: Command名称,如果不提供则使用类名
|
||||
description: Command描述,如果不提供则使用类文档字符串
|
||||
|
||||
Returns:
|
||||
CommandInfo: 生成的Command信息对象
|
||||
"""
|
||||
if "." in cls.command_name:
|
||||
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return CommandInfo(
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=cls.command_description,
|
||||
command_pattern=cls.command_pattern,
|
||||
)
|
||||
101
src/plugin_system/base/base_events_handler.py
Normal file
101
src/plugin_system/base/base_events_handler.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
||||
|
||||
logger = get_logger("base_event_handler")
|
||||
|
||||
|
||||
class BaseEventHandler(ABC):
|
||||
"""事件处理器基类
|
||||
|
||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||
"""
|
||||
|
||||
event_type: EventType = EventType.UNKNOWN
|
||||
"""事件类型,默认为未知"""
|
||||
handler_name: str = ""
|
||||
"""处理器名称"""
|
||||
handler_description: str = ""
|
||||
"""处理器描述"""
|
||||
weight: int = 0
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
self.plugin_name = ""
|
||||
"""对应插件名"""
|
||||
self.plugin_config: Optional[Dict] = None
|
||||
"""插件配置字典"""
|
||||
if self.event_type == EventType.UNKNOWN:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, Optional[str]]:
|
||||
"""执行事件处理的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息)
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现 execute 方法")
|
||||
|
||||
@classmethod
|
||||
def get_handler_info(cls) -> "EventHandlerInfo":
|
||||
"""获取事件处理器的信息"""
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
||||
if "." in name:
|
||||
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return EventHandlerInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.EVENT_HANDLER,
|
||||
description=getattr(cls, "handler_description", "events处理器"),
|
||||
event_type=cls.event_type,
|
||||
weight=cls.weight,
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
|
||||
def set_plugin_config(self, plugin_config: Dict) -> None:
|
||||
"""设置插件配置
|
||||
|
||||
Args:
|
||||
plugin_config (dict): 插件配置字典
|
||||
"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
def set_plugin_name(self, plugin_name: str) -> None:
|
||||
"""设置插件名称
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
76
src/plugin_system/base/base_plugin.py
Normal file
76
src/plugin_system/base/base_plugin.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
|
||||
from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_tool import BaseTool
|
||||
|
||||
logger = get_logger("base_plugin")
|
||||
|
||||
|
||||
class BasePlugin(PluginBase):
|
||||
"""基于Action和Command的插件基类
|
||||
|
||||
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
|
||||
- Action组件:处理聊天中的动作
|
||||
- Command组件:处理命令请求
|
||||
- 未来可扩展:Scheduler、Listener等
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def get_plugin_components(
|
||||
self,
|
||||
) -> List[
|
||||
Union[
|
||||
Tuple[ActionInfo, Type[BaseAction]],
|
||||
Tuple[CommandInfo, Type[BaseCommand]],
|
||||
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
|
||||
Tuple[ToolInfo, Type[BaseTool]],
|
||||
]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
子类必须实现此方法,返回组件信息和组件类的列表
|
||||
|
||||
Returns:
|
||||
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
def register_plugin(self) -> bool:
|
||||
"""注册插件及其所有组件"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
components = self.get_plugin_components()
|
||||
|
||||
# 检查依赖
|
||||
if not self._check_dependencies():
|
||||
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
|
||||
return False
|
||||
|
||||
# 注册所有组件
|
||||
registered_components = []
|
||||
for component_info, component_class in components:
|
||||
component_info.plugin_name = self.plugin_name
|
||||
if component_registry.register_component(component_info, component_class):
|
||||
registered_components.append(component_info)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败")
|
||||
|
||||
# 更新插件信息中的组件列表
|
||||
self.plugin_info.components = registered_components
|
||||
|
||||
# 注册插件
|
||||
if component_registry.register_plugin(self.plugin_info):
|
||||
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 插件注册失败")
|
||||
return False
|
||||
119
src/plugin_system/base/base_tool.py
Normal file
119
src/plugin_system/base/base_tool.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
"""所有工具的基类"""
|
||||
|
||||
name: str = ""
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
|
||||
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
|
||||
param_name: 参数名称
|
||||
param_type: 参数类型
|
||||
description: 参数描述
|
||||
required: 是否必填
|
||||
enum_values: 枚举值列表
|
||||
例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
|
||||
"""
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None):
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return ToolInfo(
|
||||
name=cls.name,
|
||||
tool_description=cls.description,
|
||||
enabled=cls.available_for_llm,
|
||||
tool_parameters=cls.parameters,
|
||||
component_type=ComponentType.TOOL,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数(供llm调用)
|
||||
通过该方法,maicore会通过llm的tool call来调用工具
|
||||
传入的是json格式的参数,符合parameters定义的格式
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""直接执行工具函数(供插件调用)
|
||||
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
|
||||
插件可以直接调用此方法,用更加明了的方式传入参数
|
||||
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
|
||||
|
||||
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
|
||||
|
||||
Args:
|
||||
**function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
|
||||
for param_name in parameter_required:
|
||||
if param_name not in function_args:
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
|
||||
|
||||
return await self.execute(function_args)
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
283
src/plugin_system/base/component_types.py
Normal file
283
src/plugin_system/base/component_types.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from maim_message import Seg
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
|
||||
# 组件类型枚举
|
||||
class ComponentType(Enum):
|
||||
"""组件类型枚举"""
|
||||
|
||||
ACTION = "action" # 动作组件
|
||||
COMMAND = "command" # 命令组件
|
||||
TOOL = "tool" # 服务组件(预留)
|
||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
# 动作激活类型枚举
|
||||
class ActionActivationType(Enum):
|
||||
"""动作激活类型枚举"""
|
||||
|
||||
NEVER = "never" # 从不激活(默认关闭)
|
||||
ALWAYS = "always" # 默认参与到planner
|
||||
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
|
||||
RANDOM = "random" # 随机启用action到planner
|
||||
KEYWORD = "keyword" # 关键词触发启用action到planner
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
# 聊天模式枚举
|
||||
class ChatMode(Enum):
|
||||
"""聊天模式枚举"""
|
||||
|
||||
FOCUS = "focus" # Focus聊天模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
PRIORITY = "priority" # 优先级聊天模式
|
||||
ALL = "all" # 所有聊天模式
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
事件类型枚举类
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
AFTER_LLM = "after_llm"
|
||||
POST_SEND = "post_send"
|
||||
AFTER_SEND = "after_send"
|
||||
UNKNOWN = "unknown" # 未知事件类型
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class PythonDependency:
|
||||
"""Python包依赖信息"""
|
||||
|
||||
package_name: str # 包名称
|
||||
version: str = "" # 版本要求,例如: ">=1.0.0", "==2.1.3", ""表示任意版本
|
||||
optional: bool = False # 是否为可选依赖
|
||||
description: str = "" # 依赖描述
|
||||
install_name: str = "" # 安装时的包名(如果与import名不同)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.install_name:
|
||||
self.install_name = self.package_name
|
||||
|
||||
def get_pip_requirement(self) -> str:
|
||||
"""获取pip安装格式的依赖字符串"""
|
||||
if self.version:
|
||||
return f"{self.install_name}{self.version}"
|
||||
return self.install_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComponentInfo:
|
||||
"""组件信息"""
|
||||
|
||||
name: str # 组件名称
|
||||
component_type: ComponentType # 组件类型
|
||||
description: str = "" # 组件描述
|
||||
enabled: bool = True # 是否启用
|
||||
plugin_name: str = "" # 所属插件名称
|
||||
is_built_in: bool = False # 是否为内置组件
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionInfo(ComponentInfo):
|
||||
"""动作组件信息"""
|
||||
|
||||
action_parameters: Dict[str, str] = field(
|
||||
default_factory=dict
|
||||
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||
# 激活类型相关
|
||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
random_activation_probability: float = 0.0
|
||||
llm_judge_prompt: str = ""
|
||||
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
|
||||
keyword_case_sensitive: bool = False
|
||||
# 模式和并行设置
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
parallel_action: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.activation_keywords is None:
|
||||
self.activation_keywords = []
|
||||
if self.action_parameters is None:
|
||||
self.action_parameters = {}
|
||||
if self.action_require is None:
|
||||
self.action_require = []
|
||||
if self.associated_types is None:
|
||||
self.associated_types = []
|
||||
self.component_type = ComponentType.ACTION
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandInfo(ComponentInfo):
|
||||
"""命令组件信息"""
|
||||
|
||||
command_pattern: str = "" # 命令匹配模式(正则表达式)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.COMMAND
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.TOOL
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandlerInfo(ComponentInfo):
|
||||
"""事件处理器组件信息"""
|
||||
|
||||
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
|
||||
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
|
||||
weight: int = 0 # 事件处理器权重,决定执行顺序
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
|
||||
display_name: str # 插件显示名称
|
||||
name: str # 插件名称
|
||||
description: str # 插件描述
|
||||
version: str = "1.0.0" # 插件版本
|
||||
author: str = "" # 插件作者
|
||||
enabled: bool = True # 是否启用
|
||||
is_built_in: bool = False # 是否为内置插件
|
||||
components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表
|
||||
dependencies: List[str] = field(default_factory=list) # 依赖的其他插件
|
||||
python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖
|
||||
config_file: str = "" # 配置文件路径
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
# 新增:manifest相关信息
|
||||
manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据
|
||||
license: str = "" # 插件许可证
|
||||
homepage_url: str = "" # 插件主页
|
||||
repository_url: str = "" # 插件仓库地址
|
||||
keywords: List[str] = field(default_factory=list) # 插件关键词
|
||||
categories: List[str] = field(default_factory=list) # 插件分类
|
||||
min_host_version: str = "" # 最低主机版本要求
|
||||
max_host_version: str = "" # 最高主机版本要求
|
||||
|
||||
def __post_init__(self):
|
||||
if self.components is None:
|
||||
self.components = []
|
||||
if self.dependencies is None:
|
||||
self.dependencies = []
|
||||
if self.python_dependencies is None:
|
||||
self.python_dependencies = []
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
if self.manifest_data is None:
|
||||
self.manifest_data = {}
|
||||
if self.keywords is None:
|
||||
self.keywords = []
|
||||
if self.categories is None:
|
||||
self.categories = []
|
||||
|
||||
def get_missing_packages(self) -> List[PythonDependency]:
|
||||
"""检查缺失的Python包"""
|
||||
missing = []
|
||||
for dep in self.python_dependencies:
|
||||
try:
|
||||
__import__(dep.package_name)
|
||||
except ImportError:
|
||||
if not dep.optional:
|
||||
missing.append(dep)
|
||||
return missing
|
||||
|
||||
def get_pip_requirements(self) -> List[str]:
|
||||
"""获取所有pip安装格式的依赖"""
|
||||
return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiMessages:
|
||||
"""MaiM插件消息"""
|
||||
|
||||
message_segments: List[Seg] = field(default_factory=list)
|
||||
"""消息段列表,支持多段消息"""
|
||||
|
||||
message_base_info: Dict[str, Any] = field(default_factory=dict)
|
||||
"""消息基本信息,包含平台,用户信息等数据"""
|
||||
|
||||
plain_text: str = ""
|
||||
"""纯文本消息内容"""
|
||||
|
||||
raw_message: Optional[str] = None
|
||||
"""原始消息内容"""
|
||||
|
||||
is_group_message: bool = False
|
||||
"""是否为群组消息"""
|
||||
|
||||
is_private_message: bool = False
|
||||
"""是否为私聊消息"""
|
||||
|
||||
stream_id: Optional[str] = None
|
||||
"""流ID,用于标识消息流"""
|
||||
|
||||
llm_prompt: Optional[str] = None
|
||||
"""LLM提示词"""
|
||||
|
||||
llm_response_content: Optional[str] = None
|
||||
"""LLM响应内容"""
|
||||
|
||||
llm_response_reasoning: Optional[str] = None
|
||||
"""LLM响应推理内容"""
|
||||
|
||||
llm_response_model: Optional[str] = None
|
||||
"""LLM响应模型名称"""
|
||||
|
||||
llm_response_tool_call: Optional[List[ToolCall]] = None
|
||||
"""LLM使用的工具调用"""
|
||||
|
||||
action_usage: Optional[List[str]] = None
|
||||
"""使用的Action"""
|
||||
|
||||
additional_data: Dict[Any, Any] = field(default_factory=dict)
|
||||
"""附加数据,可以存储额外信息"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.message_segments is None:
|
||||
self.message_segments = []
|
||||
18
src/plugin_system/base/config_types.py
Normal file
18
src/plugin_system/base/config_types.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
插件系统配置类型定义
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigField:
|
||||
"""配置字段定义"""
|
||||
|
||||
type: type # 字段类型
|
||||
default: Any # 默认值
|
||||
description: str # 字段描述
|
||||
example: Optional[str] = None # 示例值
|
||||
required: bool = False # 是否必需
|
||||
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表
|
||||
577
src/plugin_system/base/plugin_base.py
Normal file
577
src/plugin_system/base/plugin_base.py
Normal file
@@ -0,0 +1,577 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Union
|
||||
import os
|
||||
import inspect
|
||||
import toml
|
||||
import json
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
)
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
||||
|
||||
logger = get_logger("plugin_base")
|
||||
|
||||
|
||||
class PluginBase(ABC):
|
||||
"""插件总基类
|
||||
|
||||
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
|
||||
"""
|
||||
|
||||
# 插件基本信息(子类必须定义)
|
||||
@property
|
||||
@abstractmethod
|
||||
def plugin_name(self) -> str:
|
||||
return "" # 插件内部标识符(如 "hello_world_plugin")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def enable_plugin(self) -> bool:
|
||||
return True # 是否启用插件
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dependencies(self) -> List[str]:
|
||||
return [] # 依赖的其他插件
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def python_dependencies(self) -> List[PythonDependency]:
|
||||
return [] # Python包依赖
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config_file_name(self) -> str:
|
||||
return "" # 配置文件名
|
||||
|
||||
# manifest文件相关
|
||||
manifest_file_name: str = "_manifest.json" # manifest文件名
|
||||
manifest_data: Dict[str, Any] = {} # manifest数据
|
||||
|
||||
# 配置定义
|
||||
@property
|
||||
@abstractmethod
|
||||
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
||||
return {}
|
||||
|
||||
config_section_descriptions: Dict[str, str] = {}
|
||||
|
||||
def __init__(self, plugin_dir: str):
|
||||
"""初始化插件
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录路径,由插件管理器传递
|
||||
"""
|
||||
self.config: Dict[str, Any] = {} # 插件配置
|
||||
self.plugin_dir = plugin_dir # 插件目录路径
|
||||
self.log_prefix = f"[Plugin:{self.plugin_name}]"
|
||||
|
||||
# 加载manifest文件
|
||||
self._load_manifest()
|
||||
|
||||
# 验证插件信息
|
||||
self._validate_plugin_info()
|
||||
|
||||
# 加载插件配置
|
||||
self._load_plugin_config()
|
||||
|
||||
# 从manifest获取显示信息
|
||||
self.display_name = self.get_manifest_info("name", self.plugin_name)
|
||||
self.plugin_version = self.get_manifest_info("version", "1.0.0")
|
||||
self.plugin_description = self.get_manifest_info("description", "")
|
||||
self.plugin_author = self._get_author_name()
|
||||
|
||||
# 创建插件信息对象
|
||||
self.plugin_info = PluginInfo(
|
||||
name=self.plugin_name,
|
||||
display_name=self.display_name,
|
||||
description=self.plugin_description,
|
||||
version=self.plugin_version,
|
||||
author=self.plugin_author,
|
||||
enabled=self.enable_plugin,
|
||||
is_built_in=False,
|
||||
config_file=self.config_file_name or "",
|
||||
dependencies=self.dependencies.copy(),
|
||||
python_dependencies=self.python_dependencies.copy(),
|
||||
# manifest相关信息
|
||||
manifest_data=self.manifest_data.copy(),
|
||||
license=self.get_manifest_info("license", ""),
|
||||
homepage_url=self.get_manifest_info("homepage_url", ""),
|
||||
repository_url=self.get_manifest_info("repository_url", ""),
|
||||
keywords=self.get_manifest_info("keywords", []).copy() if self.get_manifest_info("keywords") else [],
|
||||
categories=self.get_manifest_info("categories", []).copy() if self.get_manifest_info("categories") else [],
|
||||
min_host_version=self.get_manifest_info("host_application.min_version", ""),
|
||||
max_host_version=self.get_manifest_info("host_application.max_version", ""),
|
||||
)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 插件基类初始化完成")
|
||||
|
||||
def _validate_plugin_info(self):
|
||||
"""验证插件基本信息"""
|
||||
if not self.plugin_name:
|
||||
raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name")
|
||||
|
||||
# 验证manifest中的必需信息
|
||||
if not self.get_manifest_info("name"):
|
||||
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少name字段")
|
||||
if not self.get_manifest_info("description"):
|
||||
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
|
||||
|
||||
def _load_manifest(self): # sourcery skip: raise-from-previous-error
|
||||
"""加载manifest文件(强制要求)"""
|
||||
if not self.plugin_dir:
|
||||
raise ValueError(f"{self.log_prefix} 没有插件目录路径,无法加载manifest")
|
||||
|
||||
manifest_path = os.path.join(self.plugin_dir, self.manifest_file_name)
|
||||
|
||||
if not os.path.exists(manifest_path):
|
||||
error_msg = f"{self.log_prefix} 缺少必需的manifest文件: {manifest_path}"
|
||||
logger.error(error_msg)
|
||||
raise FileNotFoundError(error_msg)
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
self.manifest_data = json.load(f)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
|
||||
|
||||
# 验证manifest格式
|
||||
self._validate_manifest()
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg) # noqa
|
||||
except IOError as e:
|
||||
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
|
||||
logger.error(error_msg)
|
||||
raise IOError(error_msg) # noqa
|
||||
|
||||
def _get_author_name(self) -> str:
|
||||
"""从manifest获取作者名称"""
|
||||
author_info = self.get_manifest_info("author", {})
|
||||
if isinstance(author_info, dict):
|
||||
return author_info.get("name", "")
|
||||
else:
|
||||
return str(author_info) if author_info else ""
|
||||
|
||||
def _validate_manifest(self):
|
||||
"""验证manifest文件格式(使用强化的验证器)"""
|
||||
if not self.manifest_data:
|
||||
raise ValueError(f"{self.log_prefix} manifest数据为空,验证失败")
|
||||
|
||||
validator = ManifestValidator()
|
||||
is_valid = validator.validate_manifest(self.manifest_data)
|
||||
|
||||
# 记录验证结果
|
||||
if validator.validation_errors or validator.validation_warnings:
|
||||
report = validator.get_validation_report()
|
||||
logger.info(f"{self.log_prefix} Manifest验证结果:\n{report}")
|
||||
|
||||
# 如果有验证错误,抛出异常
|
||||
if not is_valid:
|
||||
error_msg = f"{self.log_prefix} Manifest文件验证失败"
|
||||
if validator.validation_errors:
|
||||
error_msg += f": {'; '.join(validator.validation_errors)}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def get_manifest_info(self, key: str, default: Any = None) -> Any:
|
||||
"""获取manifest信息
|
||||
|
||||
Args:
|
||||
key: 信息键,支持点分割的嵌套键(如 "author.name")
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 对应的值
|
||||
"""
|
||||
if not self.manifest_data:
|
||||
return default
|
||||
|
||||
keys = key.split(".")
|
||||
value = self.manifest_data
|
||||
|
||||
for k in keys:
|
||||
if isinstance(value, dict) and k in value:
|
||||
value = value[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return value
|
||||
|
||||
def _generate_and_save_default_config(self, config_file_path: str):
|
||||
"""根据插件的Schema生成并保存默认配置文件"""
|
||||
if not self.config_schema:
|
||||
logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件")
|
||||
return
|
||||
|
||||
toml_str = f"# {self.plugin_name} - 自动生成的配置文件\n"
|
||||
plugin_description = self.get_manifest_info("description", "插件配置文件")
|
||||
toml_str += f"# {plugin_description}\n\n"
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
# 添加节描述
|
||||
if section in self.config_section_descriptions:
|
||||
toml_str += f"# {self.config_section_descriptions[section]}\n"
|
||||
|
||||
toml_str += f"[{section}]\n\n"
|
||||
|
||||
# 遍历节内的字段
|
||||
if isinstance(fields, dict):
|
||||
for field_name, field in fields.items():
|
||||
if isinstance(field, ConfigField):
|
||||
# 添加字段描述
|
||||
toml_str += f"# {field.description}"
|
||||
if field.required:
|
||||
toml_str += " (必需)"
|
||||
toml_str += "\n"
|
||||
|
||||
# 如果有示例值,添加示例
|
||||
if field.example:
|
||||
toml_str += f"# 示例: {field.example}\n"
|
||||
|
||||
# 如果有可选值,添加说明
|
||||
if field.choices:
|
||||
choices_str = ", ".join(map(str, field.choices))
|
||||
toml_str += f"# 可选值: {choices_str}\n"
|
||||
|
||||
# 添加字段值
|
||||
value = field.default
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
|
||||
try:
|
||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(toml_str)
|
||||
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
|
||||
except IOError as e:
|
||||
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
|
||||
|
||||
def _get_expected_config_version(self) -> str:
|
||||
"""获取插件期望的配置版本号"""
|
||||
# 从config_schema的plugin.config_version字段获取
|
||||
if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict):
|
||||
config_version_field = self.config_schema["plugin"].get("config_version")
|
||||
if isinstance(config_version_field, ConfigField):
|
||||
return config_version_field.default
|
||||
return "1.0.0"
|
||||
|
||||
def _get_current_config_version(self, config: Dict[str, Any]) -> str:
|
||||
"""从配置文件中获取当前版本号"""
|
||||
if "plugin" in config and "config_version" in config["plugin"]:
|
||||
return str(config["plugin"]["config_version"])
|
||||
# 如果没有config_version字段,视为最早的版本
|
||||
return "0.0.0"
|
||||
|
||||
def _backup_config_file(self, config_file_path: str) -> str:
|
||||
"""备份配置文件"""
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = f"{config_file_path}.backup_{timestamp}"
|
||||
|
||||
try:
|
||||
shutil.copy2(config_file_path, backup_path)
|
||||
logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}")
|
||||
return backup_path
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 备份配置文件失败: {e}")
|
||||
return ""
|
||||
|
||||
def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""将旧配置值迁移到新配置结构中
|
||||
|
||||
Args:
|
||||
old_config: 旧配置数据
|
||||
new_config: 基于新schema生成的默认配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 迁移后的配置
|
||||
"""
|
||||
|
||||
def migrate_section(
|
||||
old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""迁移单个配置节"""
|
||||
result = new_section.copy()
|
||||
|
||||
for key, value in old_section.items():
|
||||
if key in new_section:
|
||||
# 特殊处理:config_version字段总是使用新版本
|
||||
if section_name == "plugin" and key == "config_version":
|
||||
# 保持新的版本号,不迁移旧值
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})"
|
||||
)
|
||||
continue
|
||||
|
||||
# 键存在于新配置中,复制值
|
||||
if isinstance(value, dict) and isinstance(new_section[key], dict):
|
||||
# 递归处理嵌套字典
|
||||
result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}")
|
||||
else:
|
||||
result[key] = value
|
||||
logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}")
|
||||
else:
|
||||
# 键在新配置中不存在,记录警告
|
||||
logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除")
|
||||
|
||||
return result
|
||||
|
||||
migrated_config = {}
|
||||
|
||||
# 迁移每个配置节
|
||||
for section_name, new_section_data in new_config.items():
|
||||
if (
|
||||
section_name in old_config
|
||||
and isinstance(old_config[section_name], dict)
|
||||
and isinstance(new_section_data, dict)
|
||||
):
|
||||
migrated_config[section_name] = migrate_section(
|
||||
old_config[section_name], new_section_data, section_name
|
||||
)
|
||||
else:
|
||||
# 新增的节或类型不匹配,使用默认值
|
||||
migrated_config[section_name] = new_section_data
|
||||
if section_name in old_config:
|
||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
|
||||
|
||||
# 检查旧配置中是否有新配置没有的节
|
||||
for section_name in old_config:
|
||||
if section_name not in migrated_config:
|
||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
|
||||
|
||||
return migrated_config
|
||||
|
||||
def _generate_config_from_schema(self) -> Dict[str, Any]:
|
||||
# sourcery skip: dict-comprehension
|
||||
"""根据schema生成配置数据结构(不写入文件)"""
|
||||
if not self.config_schema:
|
||||
return {}
|
||||
|
||||
config_data = {}
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
if isinstance(fields, dict):
|
||||
section_data = {}
|
||||
|
||||
# 遍历节内的字段
|
||||
for field_name, field in fields.items():
|
||||
if isinstance(field, ConfigField):
|
||||
section_data[field_name] = field.default
|
||||
|
||||
config_data[section] = section_data
|
||||
|
||||
return config_data
|
||||
|
||||
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
|
||||
"""将配置数据保存为TOML文件(包含注释)"""
|
||||
if not self.config_schema:
|
||||
logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件")
|
||||
return
|
||||
|
||||
toml_str = f"# {self.plugin_name} - 配置文件\n"
|
||||
plugin_description = self.get_manifest_info("description", "插件配置文件")
|
||||
toml_str += f"# {plugin_description}\n"
|
||||
|
||||
# 获取当前期望的配置版本
|
||||
expected_version = self._get_expected_config_version()
|
||||
toml_str += f"# 配置版本: {expected_version}\n\n"
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
# 添加节描述
|
||||
if section in self.config_section_descriptions:
|
||||
toml_str += f"# {self.config_section_descriptions[section]}\n"
|
||||
|
||||
toml_str += f"[{section}]\n\n"
|
||||
|
||||
# 遍历节内的字段
|
||||
if isinstance(fields, dict) and section in config_data:
|
||||
section_data = config_data[section]
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if isinstance(field, ConfigField):
|
||||
# 添加字段描述
|
||||
toml_str += f"# {field.description}"
|
||||
if field.required:
|
||||
toml_str += " (必需)"
|
||||
toml_str += "\n"
|
||||
|
||||
# 如果有示例值,添加示例
|
||||
if field.example:
|
||||
toml_str += f"# 示例: {field.example}\n"
|
||||
|
||||
# 如果有可选值,添加说明
|
||||
if field.choices:
|
||||
choices_str = ", ".join(map(str, field.choices))
|
||||
toml_str += f"# 可选值: {choices_str}\n"
|
||||
|
||||
# 添加字段值(使用迁移后的值)
|
||||
value = section_data.get(field_name, field.default)
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
elif isinstance(value, list):
|
||||
# 格式化列表
|
||||
if all(isinstance(item, str) for item in value):
|
||||
formatted_list = "[" + ", ".join(f'"{item}"' for item in value) + "]"
|
||||
else:
|
||||
formatted_list = str(value)
|
||||
toml_str += f"{field_name} = {formatted_list}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
|
||||
try:
|
||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(toml_str)
|
||||
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
|
||||
except IOError as e:
|
||||
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
|
||||
|
||||
def _load_plugin_config(self): # sourcery skip: extract-method
|
||||
"""加载插件配置文件,支持版本检查和自动迁移"""
|
||||
if not self.config_file_name:
|
||||
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
|
||||
return
|
||||
|
||||
# 优先使用传入的插件目录路径
|
||||
if self.plugin_dir:
|
||||
plugin_dir = self.plugin_dir
|
||||
else:
|
||||
# fallback:尝试从类的模块信息获取路径
|
||||
try:
|
||||
plugin_module_path = inspect.getfile(self.__class__)
|
||||
plugin_dir = os.path.dirname(plugin_module_path)
|
||||
except (TypeError, OSError):
|
||||
# 最后的fallback:从模块的__file__属性获取
|
||||
module = inspect.getmodule(self.__class__)
|
||||
if module and hasattr(module, "__file__") and module.__file__:
|
||||
plugin_dir = os.path.dirname(module.__file__)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
|
||||
return
|
||||
|
||||
config_file_path = os.path.join(plugin_dir, self.config_file_name)
|
||||
|
||||
# 如果配置文件不存在,生成默认配置
|
||||
if not os.path.exists(config_file_path):
|
||||
logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
|
||||
self._generate_and_save_default_config(config_file_path)
|
||||
|
||||
if not os.path.exists(config_file_path):
|
||||
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(self.config_file_name)[1].lower()
|
||||
|
||||
if file_ext == ".toml":
|
||||
# 加载现有配置
|
||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||
existing_config = toml.load(f) or {}
|
||||
|
||||
# 检查配置版本
|
||||
current_version = self._get_current_config_version(existing_config)
|
||||
|
||||
# 如果配置文件没有版本信息,跳过版本检查
|
||||
if current_version == "0.0.0":
|
||||
logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
|
||||
self.config = existing_config
|
||||
else:
|
||||
expected_version = self._get_expected_config_version()
|
||||
|
||||
if current_version != expected_version:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
|
||||
)
|
||||
|
||||
# 生成新的默认配置结构
|
||||
new_config_structure = self._generate_config_from_schema()
|
||||
|
||||
# 迁移旧配置值到新结构
|
||||
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
|
||||
|
||||
# 保存迁移后的配置
|
||||
self._save_config_to_file(migrated_config, config_file_path)
|
||||
|
||||
logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
|
||||
|
||||
self.config = migrated_config
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
|
||||
self.config = existing_config
|
||||
|
||||
logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
|
||||
|
||||
# 从配置中更新 enable_plugin
|
||||
if "plugin" in self.config and "enabled" in self.config["plugin"]:
|
||||
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
|
||||
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
|
||||
self.config = {}
|
||||
|
||||
def _check_dependencies(self) -> bool:
|
||||
"""检查插件依赖"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
if not self.dependencies:
|
||||
return True
|
||||
|
||||
for dep in self.dependencies:
|
||||
if not component_registry.get_plugin_info(dep):
|
||||
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
@abstractmethod
|
||||
def register_plugin(self) -> bool:
|
||||
"""
|
||||
注册插件到插件管理器
|
||||
|
||||
子类必须实现此方法,返回注册是否成功
|
||||
|
||||
Returns:
|
||||
bool: 是否成功注册插件
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
19
src/plugin_system/core/__init__.py
Normal file
19
src/plugin_system/core/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
插件核心管理模块
|
||||
|
||||
提供插件的加载、注册和管理功能
|
||||
"""
|
||||
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
from src.plugin_system.core.events_manager import events_manager
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.plugin_system.core.plugin_hot_reload import hot_reload_manager
|
||||
|
||||
__all__ = [
|
||||
"plugin_manager",
|
||||
"component_registry",
|
||||
"events_manager",
|
||||
"global_announcement_manager",
|
||||
"hot_reload_manager",
|
||||
]
|
||||
688
src/plugin_system/core/component_registry.py
Normal file
688
src/plugin_system/core/component_registry.py
Normal file
@@ -0,0 +1,688 @@
|
||||
import re
|
||||
|
||||
from typing import Dict, List, Optional, Any, Pattern, Tuple, Union, Type
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
ToolInfo,
|
||||
CommandInfo,
|
||||
EventHandlerInfo,
|
||||
PluginInfo,
|
||||
ComponentType,
|
||||
)
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.plugin_system.base.base_action import BaseAction
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
|
||||
logger = get_logger("component_registry")
|
||||
|
||||
|
||||
class ComponentRegistry:
|
||||
"""统一的组件注册中心
|
||||
|
||||
负责管理所有插件组件的注册、查询和生命周期管理
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 命名空间式组件名构成法 f"{component_type}.{component_name}"
|
||||
self._components: Dict[str, ComponentInfo] = {}
|
||||
"""组件注册表 命名空间式组件名 -> 组件信息"""
|
||||
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
|
||||
"""类型 -> 组件原名称 -> 组件信息"""
|
||||
self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
|
||||
"""命名空间式组件名 -> 组件类"""
|
||||
|
||||
# 插件注册表
|
||||
self._plugins: Dict[str, PluginInfo] = {}
|
||||
"""插件名 -> 插件信息"""
|
||||
|
||||
# Action特定注册表
|
||||
self._action_registry: Dict[str, Type[BaseAction]] = {}
|
||||
"""Action注册表 action名 -> action类"""
|
||||
self._default_actions: Dict[str, ActionInfo] = {}
|
||||
"""默认动作集,即启用的Action集,用于重置ActionManager状态"""
|
||||
|
||||
# Command特定注册表
|
||||
self._command_registry: Dict[str, Type[BaseCommand]] = {}
|
||||
"""Command类注册表 command名 -> command类"""
|
||||
self._command_patterns: Dict[Pattern, str] = {}
|
||||
"""编译后的正则 -> command名"""
|
||||
|
||||
# 工具特定注册表
|
||||
self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
|
||||
self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
|
||||
|
||||
# EventHandler特定注册表
|
||||
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""event_handler名 -> event_handler类"""
|
||||
self._enabled_event_handlers: Dict[str, Type[BaseEventHandler]] = {}
|
||||
"""启用的事件处理器 event_handler名 -> event_handler类"""
|
||||
|
||||
logger.info("组件注册中心初始化完成")
|
||||
|
||||
# == 注册方法 ==
|
||||
|
||||
def register_plugin(self, plugin_info: PluginInfo) -> bool:
|
||||
"""注册插件
|
||||
|
||||
Args:
|
||||
plugin_info: 插件信息
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
plugin_name = plugin_info.name
|
||||
|
||||
if plugin_name in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 已存在,跳过注册")
|
||||
return False
|
||||
|
||||
self._plugins[plugin_name] = plugin_info
|
||||
logger.debug(f"已注册插件: {plugin_name} (组件数量: {len(plugin_info.components)})")
|
||||
return True
|
||||
|
||||
def register_component(
|
||||
self,
|
||||
component_info: ComponentInfo,
|
||||
component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
|
||||
) -> bool:
|
||||
"""注册组件
|
||||
|
||||
Args:
|
||||
component_info (ComponentInfo): 组件信息
|
||||
component_class (Type[Union[BaseCommand, BaseAction, BaseEventHandler]]): 组件类
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
component_name = component_info.name
|
||||
component_type = component_info.component_type
|
||||
plugin_name = getattr(component_info, "plugin_name", "unknown")
|
||||
if "." in component_name:
|
||||
logger.error(f"组件名称 '{component_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
if "." in plugin_name:
|
||||
logger.error(f"插件名称 '{plugin_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return False
|
||||
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
|
||||
if namespaced_name in self._components:
|
||||
existing_info = self._components[namespaced_name]
|
||||
existing_plugin = getattr(existing_info, "plugin_name", "unknown")
|
||||
|
||||
logger.warning(
|
||||
f"组件名冲突: '{plugin_name}' 插件的 {component_type} 类型组件 '{component_name}' 已被插件 '{existing_plugin}' 注册,跳过此组件注册"
|
||||
)
|
||||
return False
|
||||
|
||||
self._components[namespaced_name] = component_info # 注册到通用注册表(使用命名空间化的名称)
|
||||
self._components_by_type[component_type][component_name] = component_info # 类型内部仍使用原名
|
||||
self._components_classes[namespaced_name] = component_class
|
||||
|
||||
# 根据组件类型进行特定注册(使用原始名称)
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
assert isinstance(component_info, ActionInfo)
|
||||
assert issubclass(component_class, BaseAction)
|
||||
ret = self._register_action_component(component_info, component_class)
|
||||
case ComponentType.COMMAND:
|
||||
assert isinstance(component_info, CommandInfo)
|
||||
assert issubclass(component_class, BaseCommand)
|
||||
ret = self._register_command_component(component_info, component_class)
|
||||
case ComponentType.TOOL:
|
||||
assert isinstance(component_info, ToolInfo)
|
||||
assert issubclass(component_class, BaseTool)
|
||||
ret = self._register_tool_component(component_info, component_class)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(component_info, EventHandlerInfo)
|
||||
assert issubclass(component_class, BaseEventHandler)
|
||||
ret = self._register_event_handler_component(component_info, component_class)
|
||||
case _:
|
||||
logger.warning(f"未知组件类型: {component_type}")
|
||||
|
||||
if not ret:
|
||||
return False
|
||||
logger.debug(
|
||||
f"已注册{component_type}组件: '{component_name}' -> '{namespaced_name}' "
|
||||
f"({component_class.__name__}) [插件: {plugin_name}]"
|
||||
)
|
||||
return True
|
||||
|
||||
def _register_action_component(self, action_info: ActionInfo, action_class: Type[BaseAction]) -> bool:
|
||||
"""注册Action组件到Action特定注册表"""
|
||||
if not (action_name := action_info.name):
|
||||
logger.error(f"Action组件 {action_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(action_info, ActionInfo) or not issubclass(action_class, BaseAction):
|
||||
logger.error(f"注册失败: {action_name} 不是有效的Action")
|
||||
return False
|
||||
|
||||
self._action_registry[action_name] = action_class
|
||||
|
||||
# 如果启用,添加到默认动作集
|
||||
if action_info.enabled:
|
||||
self._default_actions[action_name] = action_info
|
||||
|
||||
return True
|
||||
|
||||
def _register_command_component(self, command_info: CommandInfo, command_class: Type[BaseCommand]) -> bool:
|
||||
"""注册Command组件到Command特定注册表"""
|
||||
if not (command_name := command_info.name):
|
||||
logger.error(f"Command组件 {command_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(command_info, CommandInfo) or not issubclass(command_class, BaseCommand):
|
||||
logger.error(f"注册失败: {command_name} 不是有效的Command")
|
||||
return False
|
||||
|
||||
self._command_registry[command_name] = command_class
|
||||
|
||||
# 如果启用了且有匹配模式
|
||||
if command_info.enabled and command_info.command_pattern:
|
||||
pattern = re.compile(command_info.command_pattern, re.IGNORECASE | re.DOTALL)
|
||||
if pattern not in self._command_patterns:
|
||||
self._command_patterns[pattern] = command_name
|
||||
else:
|
||||
logger.warning(
|
||||
f"'{command_name}' 对应的命令模式与 '{self._command_patterns[pattern]}' 重复,忽略此命令"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
|
||||
"""注册Tool组件到Tool特定注册表"""
|
||||
tool_name = tool_info.name
|
||||
|
||||
self._tool_registry[tool_name] = tool_class
|
||||
|
||||
# 如果是llm可用的且启用的工具,添加到 llm可用工具列表
|
||||
if tool_info.enabled:
|
||||
self._llm_available_tools[tool_name] = tool_class
|
||||
|
||||
return True
|
||||
|
||||
def _register_event_handler_component(
|
||||
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
|
||||
) -> bool:
|
||||
if not (handler_name := handler_info.name):
|
||||
logger.error(f"EventHandler组件 {handler_class.__name__} 必须指定名称")
|
||||
return False
|
||||
if not isinstance(handler_info, EventHandlerInfo) or not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"注册失败: {handler_name} 不是有效的EventHandler")
|
||||
return False
|
||||
|
||||
self._event_handler_registry[handler_name] = handler_class
|
||||
|
||||
if not handler_info.enabled:
|
||||
logger.warning(f"EventHandler组件 {handler_name} 未启用")
|
||||
return True # 未启用,但是也是注册成功
|
||||
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
if events_manager.register_event_subscriber(handler_info, handler_class):
|
||||
self._enabled_event_handlers[handler_name] = handler_class
|
||||
return True
|
||||
else:
|
||||
logger.error(f"注册事件处理器 {handler_name} 失败")
|
||||
return False
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
if not target_component_class:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法移除")
|
||||
return False
|
||||
try:
|
||||
# 根据组件类型进行特定的清理操作
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
# 移除Action注册
|
||||
self._action_registry.pop(component_name, None)
|
||||
self._default_actions.pop(component_name, None)
|
||||
logger.debug(f"已移除Action组件: {component_name}")
|
||||
|
||||
case ComponentType.COMMAND:
|
||||
# 移除Command注册和模式
|
||||
self._command_registry.pop(component_name, None)
|
||||
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
|
||||
for key in keys_to_remove:
|
||||
self._command_patterns.pop(key, None)
|
||||
logger.debug(f"已移除Command组件: {component_name} (清理了 {len(keys_to_remove)} 个模式)")
|
||||
|
||||
case ComponentType.TOOL:
|
||||
# 移除Tool注册
|
||||
self._tool_registry.pop(component_name, None)
|
||||
self._llm_available_tools.pop(component_name, None)
|
||||
logger.debug(f"已移除Tool组件: {component_name}")
|
||||
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
# 移除EventHandler注册和事件订阅
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
self._event_handler_registry.pop(component_name, None)
|
||||
self._enabled_event_handlers.pop(component_name, None)
|
||||
try:
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
logger.debug(f"已移除EventHandler组件: {component_name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"移除EventHandler事件订阅时出错: {e}")
|
||||
|
||||
case _:
|
||||
logger.warning(f"未知的组件类型: {component_type}")
|
||||
return False
|
||||
|
||||
# 移除通用注册信息
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components.pop(namespaced_name, None)
|
||||
self._components_by_type[component_type].pop(component_name, None)
|
||||
self._components_classes.pop(namespaced_name, None)
|
||||
|
||||
logger.info(f"组件 {component_name} ({component_type}) 已完全移除")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_name} ({component_type}) 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
def remove_plugin_registry(self, plugin_name: str) -> bool:
|
||||
"""移除插件注册信息
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功移除
|
||||
"""
|
||||
if plugin_name not in self._plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未注册,无法移除")
|
||||
return False
|
||||
del self._plugins[plugin_name]
|
||||
logger.info(f"插件 {plugin_name} 已移除")
|
||||
return True
|
||||
|
||||
# === 组件全局启用/禁用方法 ===
|
||||
|
||||
def enable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的启用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 启用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法启用")
|
||||
return False
|
||||
target_component_info.enabled = True
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
assert isinstance(target_component_info, ActionInfo)
|
||||
self._default_actions[component_name] = target_component_info
|
||||
case ComponentType.COMMAND:
|
||||
assert isinstance(target_component_info, CommandInfo)
|
||||
pattern = target_component_info.command_pattern
|
||||
self._command_patterns[re.compile(pattern)] = component_name
|
||||
case ComponentType.TOOL:
|
||||
assert isinstance(target_component_info, ToolInfo)
|
||||
assert issubclass(target_component_class, BaseTool)
|
||||
self._llm_available_tools[component_name] = target_component_class
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
assert isinstance(target_component_info, EventHandlerInfo)
|
||||
assert issubclass(target_component_class, BaseEventHandler)
|
||||
self._enabled_event_handlers[component_name] = target_component_class
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
events_manager.register_event_subscriber(target_component_info, target_component_class)
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
self._components[namespaced_name].enabled = True
|
||||
self._components_by_type[component_type][component_name].enabled = True
|
||||
logger.info(f"组件 {component_name} 已启用")
|
||||
return True
|
||||
|
||||
async def disable_component(self, component_name: str, component_type: ComponentType) -> bool:
|
||||
"""全局的禁用某个组件
|
||||
Parameters:
|
||||
component_name: 组件名称
|
||||
component_type: 组件类型
|
||||
Returns:
|
||||
bool: 禁用成功返回True,失败返回False
|
||||
"""
|
||||
target_component_class = self.get_component_class(component_name, component_type)
|
||||
target_component_info = self.get_component_info(component_name, component_type)
|
||||
if not target_component_class or not target_component_info:
|
||||
logger.warning(f"组件 {component_name} 未注册,无法禁用")
|
||||
return False
|
||||
target_component_info.enabled = False
|
||||
try:
|
||||
match component_type:
|
||||
case ComponentType.ACTION:
|
||||
self._default_actions.pop(component_name)
|
||||
case ComponentType.COMMAND:
|
||||
self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
|
||||
case ComponentType.TOOL:
|
||||
self._llm_available_tools.pop(component_name)
|
||||
case ComponentType.EVENT_HANDLER:
|
||||
self._enabled_event_handlers.pop(component_name)
|
||||
from .events_manager import events_manager # 延迟导入防止循环导入问题
|
||||
|
||||
await events_manager.unregister_event_subscriber(component_name)
|
||||
self._components[component_name].enabled = False
|
||||
self._components_by_type[component_type][component_name].enabled = False
|
||||
logger.info(f"组件 {component_name} 已禁用")
|
||||
return True
|
||||
except KeyError as e:
|
||||
logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
|
||||
return False
|
||||
|
||||
# === 组件查询方法 ===
|
||||
def get_component_info(
|
||||
self, component_name: str, component_type: Optional[ComponentType] = None
|
||||
) -> Optional[ComponentInfo]:
|
||||
# sourcery skip: class-extract-method
|
||||
"""获取组件信息,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
||||
|
||||
Returns:
|
||||
Optional[ComponentInfo]: 组件信息或None
|
||||
"""
|
||||
# 1. 如果已经是命名空间化的名称,直接查找
|
||||
if "." in component_name:
|
||||
return self._components.get(component_name)
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
namespaced_name = f"{component_type}.{component_name}"
|
||||
return self._components.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in [types.value for types in ComponentType]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
if component_info := self._components.get(namespaced_name):
|
||||
candidates.append((namespace_prefix, namespaced_name, component_info))
|
||||
|
||||
if len(candidates) == 1:
|
||||
# 只有一个匹配,直接返回
|
||||
return candidates[0][2]
|
||||
elif len(candidates) > 1:
|
||||
# 多个匹配,记录警告并返回第一个
|
||||
namespaces = [ns for ns, _, _ in candidates]
|
||||
logger.warning(
|
||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
|
||||
)
|
||||
return candidates[0][2]
|
||||
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_component_class(
|
||||
self,
|
||||
component_name: str,
|
||||
component_type: Optional[ComponentType] = None,
|
||||
) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
|
||||
"""获取组件类,支持自动命名空间解析
|
||||
|
||||
Args:
|
||||
component_name: 组件名称,可以是原始名称或命名空间化的名称
|
||||
component_type: 组件类型,如果提供则优先在该类型中查找
|
||||
|
||||
Returns:
|
||||
Optional[Union[BaseCommand, BaseAction]]: 组件类或None
|
||||
"""
|
||||
# 1. 如果已经是命名空间化的名称,直接查找
|
||||
if "." in component_name:
|
||||
return self._components_classes.get(component_name)
|
||||
|
||||
# 2. 如果指定了组件类型,构造命名空间化的名称查找
|
||||
if component_type:
|
||||
namespaced_name = f"{component_type.value}.{component_name}"
|
||||
return self._components_classes.get(namespaced_name)
|
||||
|
||||
# 3. 如果没有指定类型,尝试在所有命名空间中查找
|
||||
candidates = []
|
||||
for namespace_prefix in [types.value for types in ComponentType]:
|
||||
namespaced_name = f"{namespace_prefix}.{component_name}"
|
||||
if component_class := self._components_classes.get(namespaced_name):
|
||||
candidates.append((namespace_prefix, namespaced_name, component_class))
|
||||
|
||||
if len(candidates) == 1:
|
||||
# 只有一个匹配,直接返回
|
||||
_, full_name, cls = candidates[0]
|
||||
logger.debug(f"自动解析组件: '{component_name}' -> '{full_name}'")
|
||||
return cls
|
||||
elif len(candidates) > 1:
|
||||
# 多个匹配,记录警告并返回第一个
|
||||
namespaces = [ns for ns, _, _ in candidates]
|
||||
logger.warning(
|
||||
f"组件名称 '{component_name}' 在多个命名空间中存在: {namespaces},使用第一个匹配项: {candidates[0][1]}"
|
||||
)
|
||||
return candidates[0][2]
|
||||
|
||||
# 4. 都没找到
|
||||
return None
|
||||
|
||||
def get_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有组件"""
|
||||
return self._components_by_type.get(component_type, {}).copy()
|
||||
|
||||
def get_enabled_components_by_type(self, component_type: ComponentType) -> Dict[str, ComponentInfo]:
|
||||
"""获取指定类型的所有启用组件"""
|
||||
components = self.get_components_by_type(component_type)
|
||||
return {name: info for name, info in components.items() if info.enabled}
|
||||
|
||||
# === Action特定查询方法 ===
|
||||
|
||||
def get_action_registry(self) -> Dict[str, Type[BaseAction]]:
|
||||
"""获取Action注册表"""
|
||||
return self._action_registry.copy()
|
||||
|
||||
def get_registered_action_info(self, action_name: str) -> Optional[ActionInfo]:
|
||||
"""获取Action信息"""
|
||||
info = self.get_component_info(action_name, ComponentType.ACTION)
|
||||
return info if isinstance(info, ActionInfo) else None
|
||||
|
||||
def get_default_actions(self) -> Dict[str, ActionInfo]:
|
||||
"""获取默认动作集"""
|
||||
return self._default_actions.copy()
|
||||
|
||||
# === Command特定查询方法 ===
|
||||
|
||||
def get_command_registry(self) -> Dict[str, Type[BaseCommand]]:
|
||||
"""获取Command注册表"""
|
||||
return self._command_registry.copy()
|
||||
|
||||
def get_registered_command_info(self, command_name: str) -> Optional[CommandInfo]:
|
||||
"""获取Command信息"""
|
||||
info = self.get_component_info(command_name, ComponentType.COMMAND)
|
||||
return info if isinstance(info, CommandInfo) else None
|
||||
|
||||
def get_command_patterns(self) -> Dict[Pattern, str]:
|
||||
"""获取Command模式注册表"""
|
||||
return self._command_patterns.copy()
|
||||
|
||||
def find_command_by_text(self, text: str) -> Optional[Tuple[Type[BaseCommand], dict, CommandInfo]]:
|
||||
# sourcery skip: use-named-expression, use-next
|
||||
"""根据文本查找匹配的命令
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
Returns:
|
||||
Tuple: (命令类, 匹配的命名组, 是否拦截消息, 插件名) 或 None
|
||||
"""
|
||||
|
||||
candidates = [pattern for pattern in self._command_patterns if pattern.match(text)]
|
||||
if not candidates:
|
||||
return None
|
||||
if len(candidates) > 1:
|
||||
logger.warning(f"文本 '{text}' 匹配到多个命令模式: {candidates},使用第一个匹配")
|
||||
command_name = self._command_patterns[candidates[0]]
|
||||
command_info: CommandInfo = self.get_registered_command_info(command_name) # type: ignore
|
||||
return (
|
||||
self._command_registry[command_name],
|
||||
candidates[0].match(text).groupdict(), # type: ignore
|
||||
command_info,
|
||||
)
|
||||
|
||||
# === Tool 特定查询方法 ===
|
||||
def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
|
||||
"""获取Tool注册表"""
|
||||
return self._tool_registry.copy()
|
||||
|
||||
def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
|
||||
"""获取LLM可用的Tool列表"""
|
||||
return self._llm_available_tools.copy()
|
||||
|
||||
def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
|
||||
"""获取Tool信息
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
Returns:
|
||||
ToolInfo: 工具信息对象,如果工具不存在则返回 None
|
||||
"""
|
||||
info = self.get_component_info(tool_name, ComponentType.TOOL)
|
||||
return info if isinstance(info, ToolInfo) else None
|
||||
|
||||
# === EventHandler 特定查询方法 ===
|
||||
|
||||
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
"""获取事件处理器注册表"""
|
||||
return self._event_handler_registry.copy()
|
||||
|
||||
def get_registered_event_handler_info(self, handler_name: str) -> Optional[EventHandlerInfo]:
|
||||
"""获取事件处理器信息"""
|
||||
info = self.get_component_info(handler_name, ComponentType.EVENT_HANDLER)
|
||||
return info if isinstance(info, EventHandlerInfo) else None
|
||||
|
||||
def get_enabled_event_handlers(self) -> Dict[str, Type[BaseEventHandler]]:
|
||||
"""获取启用的事件处理器"""
|
||||
return self._enabled_event_handlers.copy()
|
||||
|
||||
# === 插件查询方法 ===
|
||||
|
||||
def get_plugin_info(self, plugin_name: str) -> Optional[PluginInfo]:
|
||||
"""获取插件信息"""
|
||||
return self._plugins.get(plugin_name)
|
||||
|
||||
def get_all_plugins(self) -> Dict[str, PluginInfo]:
|
||||
"""获取所有插件"""
|
||||
return self._plugins.copy()
|
||||
|
||||
# def get_enabled_plugins(self) -> Dict[str, PluginInfo]:
|
||||
# """获取所有启用的插件"""
|
||||
# return {name: info for name, info in self._plugins.items() if info.enabled}
|
||||
|
||||
def get_plugin_components(self, plugin_name: str) -> List[ComponentInfo]:
|
||||
"""获取插件的所有组件"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
return plugin_info.components if plugin_info else []
|
||||
|
||||
def get_plugin_config(self, plugin_name: str) -> Optional[dict]:
|
||||
"""获取插件配置
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[dict]: 插件配置字典或None
|
||||
"""
|
||||
# 从插件管理器获取插件实例的配置
|
||||
from src.plugin_system.core.plugin_manager import plugin_manager
|
||||
|
||||
plugin_instance = plugin_manager.get_plugin_instance(plugin_name)
|
||||
return plugin_instance.config if plugin_instance else None
|
||||
|
||||
def get_registry_stats(self) -> Dict[str, Any]:
|
||||
"""获取注册中心统计信息"""
|
||||
action_components: int = 0
|
||||
command_components: int = 0
|
||||
tool_components: int = 0
|
||||
events_handlers: int = 0
|
||||
for component in self._components.values():
|
||||
if component.component_type == ComponentType.ACTION:
|
||||
action_components += 1
|
||||
elif component.component_type == ComponentType.COMMAND:
|
||||
command_components += 1
|
||||
elif component.component_type == ComponentType.TOOL:
|
||||
tool_components += 1
|
||||
elif component.component_type == ComponentType.EVENT_HANDLER:
|
||||
events_handlers += 1
|
||||
return {
|
||||
"action_components": action_components,
|
||||
"command_components": command_components,
|
||||
"tool_components": tool_components,
|
||||
"event_handlers": events_handlers,
|
||||
"total_components": len(self._components),
|
||||
"total_plugins": len(self._plugins),
|
||||
"components_by_type": {
|
||||
component_type.value: len(components) for component_type, components in self._components_by_type.items()
|
||||
},
|
||||
"enabled_components": len([c for c in self._components.values() if c.enabled]),
|
||||
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
|
||||
}
|
||||
|
||||
# === 组件移除相关 ===
|
||||
|
||||
async def unregister_plugin(self, plugin_name: str) -> bool:
|
||||
"""卸载插件及其所有组件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功卸载
|
||||
"""
|
||||
plugin_info = self.get_plugin_info(plugin_name)
|
||||
if not plugin_info:
|
||||
logger.warning(f"插件 {plugin_name} 未注册,无法卸载")
|
||||
return False
|
||||
|
||||
logger.info(f"开始卸载插件: {plugin_name}")
|
||||
|
||||
# 记录卸载失败的组件
|
||||
failed_components = []
|
||||
|
||||
# 逐个移除插件的所有组件
|
||||
for component_info in plugin_info.components:
|
||||
try:
|
||||
success = await self.remove_component(
|
||||
component_info.name,
|
||||
component_info.component_type,
|
||||
plugin_name,
|
||||
)
|
||||
if not success:
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"移除组件 {component_info.name} 时发生异常: {e}")
|
||||
failed_components.append(f"{component_info.component_type}.{component_info.name}")
|
||||
|
||||
# 移除插件注册信息
|
||||
plugin_removed = self.remove_plugin_registry(plugin_name)
|
||||
|
||||
if failed_components:
|
||||
logger.warning(f"插件 {plugin_name} 部分组件卸载失败: {failed_components}")
|
||||
return False
|
||||
elif not plugin_removed:
|
||||
logger.error(f"插件 {plugin_name} 注册信息移除失败")
|
||||
return False
|
||||
else:
|
||||
logger.info(f"插件 {plugin_name} 卸载成功")
|
||||
return True
|
||||
|
||||
|
||||
# 创建全局组件注册中心实例
|
||||
component_registry = ComponentRegistry()
|
||||
262
src/plugin_system/core/events_manager.py
Normal file
262
src/plugin_system/core/events_manager.py
Normal file
@@ -0,0 +1,262 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import List, Dict, Optional, Type, Tuple, Any
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import EventType, EventHandlerInfo, MaiMessages
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
from .global_announcement_manager import global_announcement_manager
|
||||
|
||||
logger = get_logger("events_manager")
|
||||
|
||||
|
||||
class EventsManager:
|
||||
def __init__(self):
|
||||
# 有权重的 events 订阅者注册表
|
||||
self._events_subscribers: Dict[EventType, List[BaseEventHandler]] = {event: [] for event in EventType}
|
||||
self._handler_mapping: Dict[str, Type[BaseEventHandler]] = {} # 事件处理器映射表
|
||||
self._handler_tasks: Dict[str, List[asyncio.Task]] = {} # 事件处理器正在处理的任务
|
||||
|
||||
def register_event_subscriber(self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""注册事件处理器
|
||||
|
||||
Args:
|
||||
handler_info (EventHandlerInfo): 事件处理器信息
|
||||
handler_class (Type[BaseEventHandler]): 事件处理器类
|
||||
|
||||
Returns:
|
||||
bool: 是否注册成功
|
||||
"""
|
||||
handler_name = handler_info.name
|
||||
|
||||
if handler_name in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 已存在,跳过注册")
|
||||
return True
|
||||
|
||||
if not issubclass(handler_class, BaseEventHandler):
|
||||
logger.error(f"类 {handler_class.__name__} 不是 BaseEventHandler 的子类")
|
||||
return False
|
||||
|
||||
self._handler_mapping[handler_name] = handler_class
|
||||
return self._insert_event_handler(handler_class, handler_info)
|
||||
|
||||
async def handle_mai_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
message: Optional[MessageRecv] = None,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[Dict[str, Any]] = None,
|
||||
stream_id: Optional[str] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> bool:
|
||||
"""处理 events"""
|
||||
from src.plugin_system.core import component_registry
|
||||
|
||||
continue_flag = True
|
||||
transformed_message: Optional[MaiMessages] = None
|
||||
if not message:
|
||||
assert stream_id, "如果没有消息,必须提供流ID"
|
||||
if event_type in [EventType.ON_MESSAGE, EventType.ON_PLAN, EventType.POST_LLM, EventType.AFTER_LLM]:
|
||||
transformed_message = self._build_message_from_stream(stream_id, llm_prompt, llm_response)
|
||||
else:
|
||||
transformed_message = self._transform_event_without_message(
|
||||
stream_id, llm_prompt, llm_response, action_usage
|
||||
)
|
||||
else:
|
||||
transformed_message = self._transform_event_message(message, llm_prompt, llm_response)
|
||||
for handler in self._events_subscribers.get(event_type, []):
|
||||
if transformed_message.stream_id:
|
||||
stream_id = transformed_message.stream_id
|
||||
if handler.handler_name in global_announcement_manager.get_disabled_chat_event_handlers(stream_id):
|
||||
continue
|
||||
handler.set_plugin_config(component_registry.get_plugin_config(handler.plugin_name) or {})
|
||||
if handler.intercept_message:
|
||||
try:
|
||||
success, continue_processing, result = await handler.execute(transformed_message)
|
||||
if not success:
|
||||
logger.error(f"EventHandler {handler.handler_name} 执行失败: {result}")
|
||||
else:
|
||||
logger.debug(f"EventHandler {handler.handler_name} 执行成功: {result}")
|
||||
continue_flag = continue_flag and continue_processing
|
||||
except Exception as e:
|
||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}")
|
||||
continue
|
||||
else:
|
||||
try:
|
||||
handler_task = asyncio.create_task(handler.execute(transformed_message))
|
||||
handler_task.add_done_callback(self._task_done_callback)
|
||||
handler_task.set_name(f"{handler.plugin_name}-{handler.handler_name}")
|
||||
if handler.handler_name not in self._handler_tasks:
|
||||
self._handler_tasks[handler.handler_name] = []
|
||||
self._handler_tasks[handler.handler_name].append(handler_task)
|
||||
except Exception as e:
|
||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}")
|
||||
continue
|
||||
return continue_flag
|
||||
|
||||
def _insert_event_handler(self, handler_class: Type[BaseEventHandler], handler_info: EventHandlerInfo) -> bool:
|
||||
"""插入事件处理器到对应的事件类型列表中并设置其插件配置"""
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.error(f"事件处理器 {handler_class.__name__} 的事件类型未知,无法注册")
|
||||
return False
|
||||
|
||||
handler_instance = handler_class()
|
||||
handler_instance.set_plugin_name(handler_info.plugin_name or "unknown")
|
||||
self._events_subscribers[handler_class.event_type].append(handler_instance)
|
||||
self._events_subscribers[handler_class.event_type].sort(key=lambda x: x.weight, reverse=True)
|
||||
|
||||
return True
|
||||
|
||||
def _remove_event_handler_instance(self, handler_class: Type[BaseEventHandler]) -> bool:
|
||||
"""从事件类型列表中移除事件处理器"""
|
||||
display_handler_name = handler_class.handler_name or handler_class.__name__
|
||||
if handler_class.event_type == EventType.UNKNOWN:
|
||||
logger.warning(f"事件处理器 {display_handler_name} 的事件类型未知,不存在于处理器列表中")
|
||||
return False
|
||||
|
||||
handlers = self._events_subscribers[handler_class.event_type]
|
||||
for i, handler in enumerate(handlers):
|
||||
if isinstance(handler, handler_class):
|
||||
del handlers[i]
|
||||
logger.debug(f"事件处理器 {display_handler_name} 已移除")
|
||||
return True
|
||||
|
||||
logger.warning(f"未找到事件处理器 {display_handler_name},无法移除")
|
||||
return False
|
||||
|
||||
def _transform_event_message(
|
||||
self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
|
||||
) -> MaiMessages:
|
||||
"""转换事件消息格式"""
|
||||
# 直接赋值部分内容
|
||||
transformed_message = MaiMessages(
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=llm_response.get("content") if llm_response else None,
|
||||
llm_response_reasoning=llm_response.get("reasoning") if llm_response else None,
|
||||
llm_response_model=llm_response.get("model") if llm_response else None,
|
||||
llm_response_tool_call=llm_response.get("tool_calls") if llm_response else None,
|
||||
raw_message=message.raw_message,
|
||||
additional_data=message.message_info.additional_config or {},
|
||||
)
|
||||
|
||||
# 消息段处理
|
||||
if message.message_segment.type == "seglist":
|
||||
transformed_message.message_segments = list(message.message_segment.data) # type: ignore
|
||||
else:
|
||||
transformed_message.message_segments = [message.message_segment]
|
||||
|
||||
# stream_id 处理
|
||||
if hasattr(message, "chat_stream") and message.chat_stream:
|
||||
transformed_message.stream_id = message.chat_stream.stream_id
|
||||
|
||||
# 处理后文本
|
||||
transformed_message.plain_text = message.processed_plain_text
|
||||
|
||||
# 基本信息
|
||||
if hasattr(message, "message_info") and message.message_info:
|
||||
if message.message_info.platform:
|
||||
transformed_message.message_base_info["platform"] = message.message_info.platform
|
||||
if message.message_info.group_info:
|
||||
transformed_message.is_group_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"group_id": message.message_info.group_info.group_id,
|
||||
"group_name": message.message_info.group_info.group_name,
|
||||
}
|
||||
)
|
||||
if message.message_info.user_info:
|
||||
if not transformed_message.is_group_message:
|
||||
transformed_message.is_private_message = True
|
||||
transformed_message.message_base_info.update(
|
||||
{
|
||||
"user_id": message.message_info.user_info.user_id,
|
||||
"user_cardname": message.message_info.user_info.user_cardname, # 用户群昵称
|
||||
"user_nickname": message.message_info.user_info.user_nickname, # 用户昵称(用户名)
|
||||
}
|
||||
)
|
||||
|
||||
return transformed_message
|
||||
|
||||
def _build_message_from_stream(
|
||||
self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
|
||||
) -> MaiMessages:
|
||||
"""从流ID构建消息"""
|
||||
chat_stream = get_chat_manager().get_stream(stream_id)
|
||||
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
|
||||
message = chat_stream.context.get_last_message()
|
||||
return self._transform_event_message(message, llm_prompt, llm_response)
|
||||
|
||||
def _transform_event_without_message(
|
||||
self,
|
||||
stream_id: str,
|
||||
llm_prompt: Optional[str] = None,
|
||||
llm_response: Optional[Dict[str, Any]] = None,
|
||||
action_usage: Optional[List[str]] = None,
|
||||
) -> MaiMessages:
|
||||
"""没有message对象时进行转换"""
|
||||
chat_stream = get_chat_manager().get_stream(stream_id)
|
||||
assert chat_stream, f"未找到流ID为 {stream_id} 的聊天流"
|
||||
return MaiMessages(
|
||||
stream_id=stream_id,
|
||||
llm_prompt=llm_prompt,
|
||||
llm_response_content=(llm_response.get("content") if llm_response else None),
|
||||
llm_response_reasoning=(llm_response.get("reasoning") if llm_response else None),
|
||||
llm_response_model=llm_response.get("model") if llm_response else None,
|
||||
llm_response_tool_call=(llm_response.get("tool_calls") if llm_response else None),
|
||||
is_group_message=(not (not chat_stream.group_info)),
|
||||
is_private_message=(not chat_stream.group_info),
|
||||
action_usage=action_usage,
|
||||
additional_data={"response_is_processed": True},
|
||||
)
|
||||
|
||||
def _task_done_callback(self, task: asyncio.Task[Tuple[bool, bool, str | None]]):
|
||||
"""任务完成回调"""
|
||||
task_name = task.get_name() or "Unknown Task"
|
||||
try:
|
||||
success, _, result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
||||
if success:
|
||||
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
||||
else:
|
||||
logger.error(f"事件处理任务 {task_name} 执行失败: {result}")
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"事件处理任务 {task_name} 发生异常: {e}")
|
||||
finally:
|
||||
with contextlib.suppress(ValueError, KeyError):
|
||||
self._handler_tasks[task_name].remove(task)
|
||||
|
||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||
if remaining_tasks := [task for task in tasks_to_be_cancelled if not task.done()]:
|
||||
for task in remaining_tasks:
|
||||
task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(asyncio.gather(*remaining_tasks, return_exceptions=True), timeout=5)
|
||||
logger.info(f"已取消事件处理器 {handler_name} 的所有任务")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"取消事件处理器 {handler_name} 的任务超时,开始强制取消")
|
||||
except Exception as e:
|
||||
logger.error(f"取消事件处理器 {handler_name} 的任务时发生异常: {e}")
|
||||
if handler_name in self._handler_tasks:
|
||||
del self._handler_tasks[handler_name]
|
||||
|
||||
async def unregister_event_subscriber(self, handler_name: str) -> bool:
|
||||
"""取消注册事件处理器"""
|
||||
if handler_name not in self._handler_mapping:
|
||||
logger.warning(f"事件处理器 {handler_name} 不存在,无法取消注册")
|
||||
return False
|
||||
|
||||
await self.cancel_handler_tasks(handler_name)
|
||||
|
||||
handler_class = self._handler_mapping.pop(handler_name)
|
||||
if not self._remove_event_handler_instance(handler_class):
|
||||
return False
|
||||
|
||||
logger.info(f"事件处理器 {handler_name} 已成功取消注册")
|
||||
return True
|
||||
|
||||
|
||||
events_manager = EventsManager()
|
||||
120
src/plugin_system/core/global_announcement_manager.py
Normal file
120
src/plugin_system/core/global_announcement_manager.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from typing import List, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("global_announcement_manager")
|
||||
|
||||
|
||||
class GlobalAnnouncementManager:
|
||||
def __init__(self) -> None:
|
||||
# 用户禁用的动作,chat_id -> [action_name]
|
||||
self._user_disabled_actions: Dict[str, List[str]] = {}
|
||||
# 用户禁用的命令,chat_id -> [command_name]
|
||||
self._user_disabled_commands: Dict[str, List[str]] = {}
|
||||
# 用户禁用的事件处理器,chat_id -> [handler_name]
|
||||
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
|
||||
# 用户禁用的工具,chat_id -> [tool_name]
|
||||
self._user_disabled_tools: Dict[str, List[str]] = {}
|
||||
|
||||
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""禁用特定聊天的某个动作"""
|
||||
if chat_id not in self._user_disabled_actions:
|
||||
self._user_disabled_actions[chat_id] = []
|
||||
if action_name in self._user_disabled_actions[chat_id]:
|
||||
logger.warning(f"动作 {action_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_actions[chat_id].append(action_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
|
||||
"""启用特定聊天的某个动作"""
|
||||
if chat_id in self._user_disabled_actions:
|
||||
try:
|
||||
self._user_disabled_actions[chat_id].remove(action_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"动作 {action_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""禁用特定聊天的某个命令"""
|
||||
if chat_id not in self._user_disabled_commands:
|
||||
self._user_disabled_commands[chat_id] = []
|
||||
if command_name in self._user_disabled_commands[chat_id]:
|
||||
logger.warning(f"命令 {command_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_commands[chat_id].append(command_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_command(self, chat_id: str, command_name: str) -> bool:
|
||||
"""启用特定聊天的某个命令"""
|
||||
if chat_id in self._user_disabled_commands:
|
||||
try:
|
||||
self._user_disabled_commands[chat_id].remove(command_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"命令 {command_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""禁用特定聊天的某个事件处理器"""
|
||||
if chat_id not in self._user_disabled_event_handlers:
|
||||
self._user_disabled_event_handlers[chat_id] = []
|
||||
if handler_name in self._user_disabled_event_handlers[chat_id]:
|
||||
logger.warning(f"事件处理器 {handler_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_event_handlers[chat_id].append(handler_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_event_handler(self, chat_id: str, handler_name: str) -> bool:
|
||||
"""启用特定聊天的某个事件处理器"""
|
||||
if chat_id in self._user_disabled_event_handlers:
|
||||
try:
|
||||
self._user_disabled_event_handlers[chat_id].remove(handler_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"事件处理器 {handler_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""禁用特定聊天的某个工具"""
|
||||
if chat_id not in self._user_disabled_tools:
|
||||
self._user_disabled_tools[chat_id] = []
|
||||
if tool_name in self._user_disabled_tools[chat_id]:
|
||||
logger.warning(f"工具 {tool_name} 已经被禁用")
|
||||
return False
|
||||
self._user_disabled_tools[chat_id].append(tool_name)
|
||||
return True
|
||||
|
||||
def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
|
||||
"""启用特定聊天的某个工具"""
|
||||
if chat_id in self._user_disabled_tools:
|
||||
try:
|
||||
self._user_disabled_tools[chat_id].remove(tool_name)
|
||||
return True
|
||||
except ValueError:
|
||||
logger.warning(f"工具 {tool_name} 不在禁用列表中")
|
||||
return False
|
||||
return False
|
||||
|
||||
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有动作"""
|
||||
return self._user_disabled_actions.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_commands(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有命令"""
|
||||
return self._user_disabled_commands.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有事件处理器"""
|
||||
return self._user_disabled_event_handlers.get(chat_id, []).copy()
|
||||
|
||||
def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
|
||||
"""获取特定聊天禁用的所有工具"""
|
||||
return self._user_disabled_tools.get(chat_id, []).copy()
|
||||
|
||||
|
||||
global_announcement_manager = GlobalAnnouncementManager()
|
||||
242
src/plugin_system/core/plugin_hot_reload.py
Normal file
242
src/plugin_system/core/plugin_hot_reload.py
Normal file
@@ -0,0 +1,242 @@
|
||||
"""
|
||||
插件热重载模块
|
||||
|
||||
使用 Watchdog 监听插件目录变化,自动重载插件
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Dict, Set
|
||||
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .plugin_manager import plugin_manager
|
||||
|
||||
logger = get_logger("plugin_hot_reload")
|
||||
|
||||
|
||||
class PluginFileHandler(FileSystemEventHandler):
|
||||
"""插件文件变化处理器"""
|
||||
|
||||
def __init__(self, hot_reload_manager):
|
||||
super().__init__()
|
||||
self.hot_reload_manager = hot_reload_manager
|
||||
self.pending_reloads: Set[str] = set() # 待重载的插件名称
|
||||
self.last_reload_time: Dict[str, float] = {} # 上次重载时间
|
||||
self.debounce_delay = 1.0 # 防抖延迟(秒)
|
||||
|
||||
def on_modified(self, event):
|
||||
"""文件修改事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "modified")
|
||||
|
||||
def on_created(self, event):
|
||||
"""文件创建事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "created")
|
||||
|
||||
def on_deleted(self, event):
|
||||
"""文件删除事件"""
|
||||
if not event.is_directory and (event.src_path.endswith('.py') or event.src_path.endswith('.toml')):
|
||||
self._handle_file_change(event.src_path, "deleted")
|
||||
|
||||
def _handle_file_change(self, file_path: str, change_type: str):
|
||||
"""处理文件变化"""
|
||||
try:
|
||||
# 获取插件名称
|
||||
plugin_name = self._get_plugin_name_from_path(file_path)
|
||||
if not plugin_name:
|
||||
return
|
||||
|
||||
current_time = time.time()
|
||||
last_time = self.last_reload_time.get(plugin_name, 0)
|
||||
|
||||
# 防抖处理,避免频繁重载
|
||||
if current_time - last_time < self.debounce_delay:
|
||||
return
|
||||
|
||||
file_name = Path(file_path).name
|
||||
logger.info(f"📁 检测到插件文件变化: {file_name} ({change_type})")
|
||||
|
||||
# 如果是删除事件,处理关键文件删除
|
||||
if change_type == "deleted":
|
||||
if file_name == "plugin.py":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件主文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
elif file_name == "manifest.toml":
|
||||
if plugin_name in plugin_manager.loaded_plugins:
|
||||
logger.info(f"🗑️ 插件配置文件被删除,卸载插件: {plugin_name}")
|
||||
self.hot_reload_manager._unload_plugin(plugin_name)
|
||||
return
|
||||
|
||||
# 对于修改和创建事件,都进行重载
|
||||
# 添加到待重载列表
|
||||
self.pending_reloads.add(plugin_name)
|
||||
self.last_reload_time[plugin_name] = current_time
|
||||
|
||||
# 延迟重载,避免文件正在写入时重载
|
||||
reload_thread = Thread(
|
||||
target=self._delayed_reload,
|
||||
args=(plugin_name,),
|
||||
daemon=True
|
||||
)
|
||||
reload_thread.start()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理文件变化时发生错误: {e}")
|
||||
|
||||
def _delayed_reload(self, plugin_name: str):
|
||||
"""延迟重载插件"""
|
||||
try:
|
||||
time.sleep(self.debounce_delay)
|
||||
|
||||
if plugin_name in self.pending_reloads:
|
||||
self.pending_reloads.remove(plugin_name)
|
||||
self.hot_reload_manager._reload_plugin(plugin_name)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 延迟重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _get_plugin_name_from_path(self, file_path: str) -> str:
|
||||
"""从文件路径获取插件名称"""
|
||||
try:
|
||||
path = Path(file_path)
|
||||
|
||||
# 检查是否在监听的插件目录中
|
||||
plugin_root = Path(self.hot_reload_manager.watch_directory)
|
||||
if not path.is_relative_to(plugin_root):
|
||||
return ""
|
||||
|
||||
# 获取插件目录名(插件名)
|
||||
relative_path = path.relative_to(plugin_root)
|
||||
plugin_name = relative_path.parts[0]
|
||||
|
||||
# 确认这是一个有效的插件目录(检查是否有 plugin.py 或 manifest.toml)
|
||||
plugin_dir = plugin_root / plugin_name
|
||||
if plugin_dir.is_dir() and ((plugin_dir / "plugin.py").exists() or (plugin_dir / "manifest.toml").exists()):
|
||||
return plugin_name
|
||||
|
||||
return ""
|
||||
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
class PluginHotReloadManager:
|
||||
"""插件热重载管理器"""
|
||||
|
||||
def __init__(self, watch_directory: str = None):
|
||||
print("fuck")
|
||||
print(os.getcwd())
|
||||
self.watch_directory = os.path.join(os.getcwd(), "plugins")
|
||||
self.observer = None
|
||||
self.file_handler = None
|
||||
self.is_running = False
|
||||
|
||||
# 确保监听目录存在
|
||||
if not os.path.exists(self.watch_directory):
|
||||
os.makedirs(self.watch_directory, exist_ok=True)
|
||||
logger.info(f"创建插件监听目录: {self.watch_directory}")
|
||||
|
||||
def start(self):
|
||||
"""启动热重载监听"""
|
||||
if self.is_running:
|
||||
logger.warning("插件热重载已经在运行中")
|
||||
return
|
||||
|
||||
try:
|
||||
self.observer = Observer()
|
||||
self.file_handler = PluginFileHandler(self)
|
||||
|
||||
self.observer.schedule(
|
||||
self.file_handler,
|
||||
self.watch_directory,
|
||||
recursive=True
|
||||
)
|
||||
|
||||
self.observer.start()
|
||||
self.is_running = True
|
||||
|
||||
logger.info("🚀 插件热重载已启动,监听目录: plugins")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动插件热重载失败: {e}")
|
||||
self.is_running = False
|
||||
|
||||
def stop(self):
|
||||
"""停止热重载监听"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
if self.observer:
|
||||
self.observer.stop()
|
||||
self.observer.join()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
def _reload_plugin(self, plugin_name: str):
|
||||
"""重载指定插件"""
|
||||
try:
|
||||
logger.info(f"🔄 开始重载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件重载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def _unload_plugin(self, plugin_name: str):
|
||||
"""卸载指定插件"""
|
||||
try:
|
||||
logger.info(f"🗑️ 开始卸载插件: {plugin_name}")
|
||||
|
||||
if plugin_manager.unload_plugin(plugin_name):
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
else:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 卸载插件 {plugin_name} 时发生错误: {e}")
|
||||
|
||||
def reload_all_plugins(self):
|
||||
"""重载所有插件"""
|
||||
try:
|
||||
logger.info("🔄 开始重载所有插件...")
|
||||
|
||||
# 获取当前已加载的插件列表
|
||||
loaded_plugins = list(plugin_manager.loaded_plugins.keys())
|
||||
|
||||
success_count = 0
|
||||
fail_count = 0
|
||||
|
||||
for plugin_name in loaded_plugins:
|
||||
if plugin_manager.reload_plugin(plugin_name):
|
||||
success_count += 1
|
||||
else:
|
||||
fail_count += 1
|
||||
|
||||
logger.info(f"✅ 插件重载完成: 成功 {success_count} 个,失败 {fail_count} 个")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 重载所有插件时发生错误: {e}")
|
||||
|
||||
def get_status(self) -> dict:
|
||||
"""获取热重载状态"""
|
||||
return {
|
||||
"is_running": self.is_running,
|
||||
"watch_directory": self.watch_directory,
|
||||
"loaded_plugins": len(plugin_manager.loaded_plugins),
|
||||
"failed_plugins": len(plugin_manager.failed_plugins),
|
||||
}
|
||||
|
||||
|
||||
# 全局热重载管理器实例
|
||||
hot_reload_manager = PluginHotReloadManager()
|
||||
593
src/plugin_system/core/plugin_manager.py
Normal file
593
src/plugin_system/core/plugin_manager.py
Normal file
@@ -0,0 +1,593 @@
|
||||
import os
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
from typing import Dict, List, Optional, Tuple, Type, Any
|
||||
from importlib.util import spec_from_file_location, module_from_spec
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.plugin_base import PluginBase
|
||||
from src.plugin_system.base.component_types import ComponentType
|
||||
from src.plugin_system.utils.manifest_utils import VersionComparator
|
||||
from .component_registry import component_registry
|
||||
|
||||
logger = get_logger("plugin_manager")
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""
|
||||
插件管理器类
|
||||
|
||||
负责加载,重载和卸载插件,同时管理插件的所有组件
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.plugin_directories: List[str] = [] # 插件根目录列表
|
||||
self.plugin_classes: Dict[str, Type[PluginBase]] = {} # 全局插件类注册表,插件名 -> 插件类
|
||||
self.plugin_paths: Dict[str, str] = {} # 记录插件名到目录路径的映射,插件名 -> 目录路径
|
||||
|
||||
self.loaded_plugins: Dict[str, PluginBase] = {} # 已加载的插件类实例注册表,插件名 -> 插件类实例
|
||||
self.failed_plugins: Dict[str, str] = {} # 记录加载失败的插件文件及其错误信息,插件名 -> 错误信息
|
||||
|
||||
# 确保插件目录存在
|
||||
self._ensure_plugin_directories()
|
||||
logger.info("插件管理器初始化完成")
|
||||
|
||||
# === 插件目录管理 ===
|
||||
|
||||
def add_plugin_directory(self, directory: str) -> bool:
|
||||
"""添加插件目录"""
|
||||
if os.path.exists(directory):
|
||||
if directory not in self.plugin_directories:
|
||||
self.plugin_directories.append(directory)
|
||||
logger.debug(f"已添加插件目录: {directory}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"插件不可重复加载: {directory}")
|
||||
else:
|
||||
logger.warning(f"插件目录不存在: {directory}")
|
||||
return False
|
||||
|
||||
# === 插件加载管理 ===
|
||||
|
||||
def load_all_plugins(self) -> Tuple[int, int]:
|
||||
"""加载所有插件
|
||||
|
||||
Returns:
|
||||
tuple[int, int]: (插件数量, 组件数量)
|
||||
"""
|
||||
logger.debug("开始加载所有插件...")
|
||||
|
||||
# 第一阶段:加载所有插件模块(注册插件类)
|
||||
total_loaded_modules = 0
|
||||
total_failed_modules = 0
|
||||
|
||||
for directory in self.plugin_directories:
|
||||
loaded, failed = self._load_plugin_modules_from_directory(directory)
|
||||
total_loaded_modules += loaded
|
||||
total_failed_modules += failed
|
||||
|
||||
logger.debug(f"插件模块加载完成 - 成功: {total_loaded_modules}, 失败: {total_failed_modules}")
|
||||
|
||||
total_registered = 0
|
||||
total_failed_registration = 0
|
||||
|
||||
for plugin_name in self.plugin_classes.keys():
|
||||
load_status, count = self.load_registered_plugin_classes(plugin_name)
|
||||
if load_status:
|
||||
total_registered += 1
|
||||
else:
|
||||
total_failed_registration += count
|
||||
|
||||
self._show_stats(total_registered, total_failed_registration)
|
||||
|
||||
return total_registered, total_failed_registration
|
||||
|
||||
def load_registered_plugin_classes(self, plugin_name: str) -> Tuple[bool, int]:
|
||||
# sourcery skip: extract-duplicate-method, extract-method
|
||||
"""
|
||||
加载已经注册的插件类
|
||||
"""
|
||||
plugin_class = self.plugin_classes.get(plugin_name)
|
||||
if not plugin_class:
|
||||
logger.error(f"插件 {plugin_name} 的插件类未注册或不存在")
|
||||
return False, 1
|
||||
try:
|
||||
# 使用记录的插件目录路径
|
||||
plugin_dir = self.plugin_paths.get(plugin_name)
|
||||
|
||||
# 如果没有记录,直接返回失败
|
||||
if not plugin_dir:
|
||||
return False, 1
|
||||
|
||||
plugin_instance = plugin_class(plugin_dir=plugin_dir) # 实例化插件(可能因为缺少manifest而失败)
|
||||
if not plugin_instance:
|
||||
logger.error(f"插件 {plugin_name} 实例化失败")
|
||||
return False, 1
|
||||
# 检查插件是否启用
|
||||
if not plugin_instance.enable_plugin:
|
||||
logger.info(f"插件 {plugin_name} 已禁用,跳过加载")
|
||||
return False, 0
|
||||
|
||||
# 检查版本兼容性
|
||||
is_compatible, compatibility_error = self._check_plugin_version_compatibility(
|
||||
plugin_name, plugin_instance.manifest_data
|
||||
)
|
||||
if not is_compatible:
|
||||
self.failed_plugins[plugin_name] = compatibility_error
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {compatibility_error}")
|
||||
return False, 1
|
||||
if plugin_instance.register_plugin():
|
||||
self.loaded_plugins[plugin_name] = plugin_instance
|
||||
self._show_plugin_components(plugin_name)
|
||||
return True, 1
|
||||
else:
|
||||
self.failed_plugins[plugin_name] = "插件注册失败"
|
||||
logger.error(f"❌ 插件注册失败: {plugin_name}")
|
||||
return False, 1
|
||||
|
||||
except FileNotFoundError as e:
|
||||
# manifest文件缺失
|
||||
error_msg = f"缺少manifest文件: {str(e)}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
|
||||
except ValueError as e:
|
||||
# manifest文件格式错误或验证失败
|
||||
traceback.print_exc()
|
||||
error_msg = f"manifest验证失败: {str(e)}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
return False, 1
|
||||
|
||||
except Exception as e:
|
||||
# 其他错误
|
||||
error_msg = f"未知错误: {str(e)}"
|
||||
self.failed_plugins[plugin_name] = error_msg
|
||||
logger.error(f"❌ 插件加载失败: {plugin_name} - {error_msg}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False, 1
|
||||
|
||||
async def remove_registered_plugin(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
禁用插件模块
|
||||
"""
|
||||
if not plugin_name:
|
||||
raise ValueError("插件名称不能为空")
|
||||
if plugin_name not in self.loaded_plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未加载")
|
||||
return False
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
plugin_info = plugin_instance.plugin_info
|
||||
success = True
|
||||
for component in plugin_info.components:
|
||||
success &= await component_registry.remove_component(component.name, component.component_type, plugin_name)
|
||||
success &= component_registry.remove_plugin_registry(plugin_name)
|
||||
del self.loaded_plugins[plugin_name]
|
||||
return success
|
||||
|
||||
async def reload_registered_plugin(self, plugin_name: str) -> bool:
|
||||
"""
|
||||
重载插件模块
|
||||
"""
|
||||
if not await self.remove_registered_plugin(plugin_name):
|
||||
return False
|
||||
if not self.load_registered_plugin_classes(plugin_name)[0]:
|
||||
return False
|
||||
logger.debug(f"插件 {plugin_name} 重载成功")
|
||||
return True
|
||||
|
||||
def rescan_plugin_directory(self) -> Tuple[int, int]:
|
||||
"""
|
||||
重新扫描插件根目录
|
||||
"""
|
||||
total_success = 0
|
||||
total_fail = 0
|
||||
for directory in self.plugin_directories:
|
||||
if os.path.exists(directory):
|
||||
logger.debug(f"重新扫描插件根目录: {directory}")
|
||||
success, fail = self._load_plugin_modules_from_directory(directory)
|
||||
total_success += success
|
||||
total_fail += fail
|
||||
else:
|
||||
logger.warning(f"插件根目录不存在: {directory}")
|
||||
return total_success, total_fail
|
||||
|
||||
def get_plugin_instance(self, plugin_name: str) -> Optional["PluginBase"]:
|
||||
"""获取插件实例
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[BasePlugin]: 插件实例或None
|
||||
"""
|
||||
return self.loaded_plugins.get(plugin_name)
|
||||
|
||||
# === 查询方法 ===
|
||||
def list_loaded_plugins(self) -> List[str]:
|
||||
"""
|
||||
列出所有当前加载的插件。
|
||||
|
||||
Returns:
|
||||
list: 当前加载的插件名称列表。
|
||||
"""
|
||||
return list(self.loaded_plugins.keys())
|
||||
|
||||
def list_registered_plugins(self) -> List[str]:
|
||||
"""
|
||||
列出所有已注册的插件类。
|
||||
|
||||
Returns:
|
||||
list: 已注册的插件类名称列表。
|
||||
"""
|
||||
return list(self.plugin_classes.keys())
|
||||
|
||||
def get_plugin_path(self, plugin_name: str) -> Optional[str]:
|
||||
"""
|
||||
获取指定插件的路径。
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
Optional[str]: 插件目录的绝对路径,如果插件不存在则返回None。
|
||||
"""
|
||||
return self.plugin_paths.get(plugin_name)
|
||||
|
||||
# === 私有方法 ===
|
||||
# == 目录管理 ==
|
||||
def _ensure_plugin_directories(self) -> None:
|
||||
"""确保所有插件根目录存在,如果不存在则创建"""
|
||||
default_directories = ["src/plugins/built_in", "plugins"]
|
||||
|
||||
for directory in default_directories:
|
||||
if not os.path.exists(directory):
|
||||
os.makedirs(directory, exist_ok=True)
|
||||
logger.info(f"创建插件根目录: {directory}")
|
||||
if directory not in self.plugin_directories:
|
||||
self.plugin_directories.append(directory)
|
||||
logger.debug(f"已添加插件根目录: {directory}")
|
||||
else:
|
||||
logger.warning(f"根目录不可重复加载: {directory}")
|
||||
|
||||
# == 插件加载 ==
|
||||
|
||||
def _load_plugin_modules_from_directory(self, directory: str) -> tuple[int, int]:
|
||||
"""从指定目录加载插件模块"""
|
||||
loaded_count = 0
|
||||
failed_count = 0
|
||||
|
||||
if not os.path.exists(directory):
|
||||
logger.warning(f"插件根目录不存在: {directory}")
|
||||
return 0, 1
|
||||
|
||||
logger.debug(f"正在扫描插件根目录: {directory}")
|
||||
|
||||
# 遍历目录中的所有包
|
||||
for item in os.listdir(directory):
|
||||
item_path = os.path.join(directory, item)
|
||||
|
||||
if os.path.isdir(item_path) and not item.startswith(".") and not item.startswith("__"):
|
||||
plugin_file = os.path.join(item_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
loaded_count += 1
|
||||
else:
|
||||
failed_count += 1
|
||||
|
||||
return loaded_count, failed_count
|
||||
|
||||
def _load_plugin_module_file(self, plugin_file: str) -> bool:
|
||||
# sourcery skip: extract-method
|
||||
"""加载单个插件模块文件
|
||||
|
||||
Args:
|
||||
plugin_file: 插件文件路径
|
||||
plugin_name: 插件名称
|
||||
plugin_dir: 插件目录路径
|
||||
"""
|
||||
# 生成模块名
|
||||
plugin_path = Path(plugin_file)
|
||||
module_name = ".".join(plugin_path.parent.parts)
|
||||
|
||||
try:
|
||||
# 动态导入插件模块
|
||||
spec = spec_from_file_location(module_name, plugin_file)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error(f"无法创建模块规范: {plugin_file}")
|
||||
return False
|
||||
|
||||
module = module_from_spec(spec)
|
||||
module.__package__ = module_name # 设置模块包名
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
logger.debug(f"插件模块加载成功: {plugin_file}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"加载插件模块 {plugin_file} 失败: {e}"
|
||||
logger.error(error_msg)
|
||||
self.failed_plugins[module_name] = error_msg
|
||||
return False
|
||||
|
||||
# == 兼容性检查 ==
|
||||
|
||||
def _check_plugin_version_compatibility(self, plugin_name: str, manifest_data: Dict[str, Any]) -> Tuple[bool, str]:
|
||||
"""检查插件版本兼容性
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
manifest_data: manifest数据
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否兼容, 错误信息)
|
||||
"""
|
||||
if "host_application" not in manifest_data:
|
||||
return True, "" # 没有版本要求,默认兼容
|
||||
|
||||
host_app = manifest_data["host_application"]
|
||||
if not isinstance(host_app, dict):
|
||||
return True, ""
|
||||
|
||||
min_version = host_app.get("min_version", "")
|
||||
max_version = host_app.get("max_version", "")
|
||||
|
||||
if not min_version and not max_version:
|
||||
return True, "" # 没有版本要求,默认兼容
|
||||
|
||||
try:
|
||||
current_version = VersionComparator.get_current_host_version()
|
||||
is_compatible, error_msg = VersionComparator.is_version_in_range(current_version, min_version, max_version)
|
||||
if not is_compatible:
|
||||
return False, f"版本不兼容: {error_msg}"
|
||||
logger.debug(f"插件 {plugin_name} 版本兼容性检查通过")
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"插件 {plugin_name} 版本兼容性检查失败: {e}")
|
||||
return False, f"插件 {plugin_name} 版本兼容性检查失败: {e}" # 检查失败时默认不允许加载
|
||||
|
||||
# == 显示统计与插件信息 ==
|
||||
|
||||
def _show_stats(self, total_registered: int, total_failed_registration: int):
|
||||
# sourcery skip: low-code-quality
|
||||
# 获取组件统计信息
|
||||
stats = component_registry.get_registry_stats()
|
||||
action_count = stats.get("action_components", 0)
|
||||
command_count = stats.get("command_components", 0)
|
||||
tool_count = stats.get("tool_components", 0)
|
||||
event_handler_count = stats.get("event_handlers", 0)
|
||||
total_components = stats.get("total_components", 0)
|
||||
|
||||
# 📋 显示插件加载总览
|
||||
if total_registered > 0:
|
||||
logger.info("🎉 插件系统加载完成!")
|
||||
logger.info(
|
||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
|
||||
)
|
||||
|
||||
# 显示详细的插件列表
|
||||
logger.info("📋 已加载插件详情:")
|
||||
for plugin_name in self.loaded_plugins.keys():
|
||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
# 插件基本信息
|
||||
version_info = f"v{plugin_info.version}" if plugin_info.version else ""
|
||||
author_info = f"by {plugin_info.author}" if plugin_info.author else "unknown"
|
||||
license_info = f"[{plugin_info.license}]" if plugin_info.license else ""
|
||||
info_parts = [part for part in [version_info, author_info, license_info] if part]
|
||||
extra_info = f" ({', '.join(info_parts)})" if info_parts else ""
|
||||
|
||||
logger.info(f" 📦 {plugin_info.display_name}{extra_info}")
|
||||
|
||||
# Manifest信息
|
||||
if plugin_info.manifest_data:
|
||||
"""
|
||||
if plugin_info.keywords:
|
||||
logger.info(f" 🏷️ 关键词: {', '.join(plugin_info.keywords)}")
|
||||
if plugin_info.categories:
|
||||
logger.info(f" 📁 分类: {', '.join(plugin_info.categories)}")
|
||||
"""
|
||||
if plugin_info.homepage_url:
|
||||
logger.info(f" 🌐 主页: {plugin_info.homepage_url}")
|
||||
|
||||
# 组件列表
|
||||
if plugin_info.components:
|
||||
action_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.ACTION
|
||||
]
|
||||
command_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
|
||||
]
|
||||
tool_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.TOOL
|
||||
]
|
||||
event_handler_components = [
|
||||
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
|
||||
]
|
||||
|
||||
if action_components:
|
||||
action_names = [c.name for c in action_components]
|
||||
logger.info(f" 🎯 Action组件: {', '.join(action_names)}")
|
||||
|
||||
if command_components:
|
||||
command_names = [c.name for c in command_components]
|
||||
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
|
||||
if tool_components:
|
||||
tool_names = [c.name for c in tool_components]
|
||||
logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
|
||||
if event_handler_components:
|
||||
event_handler_names = [c.name for c in event_handler_components]
|
||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
|
||||
|
||||
# 依赖信息
|
||||
if plugin_info.dependencies:
|
||||
logger.info(f" 🔗 依赖: {', '.join(plugin_info.dependencies)}")
|
||||
|
||||
# 配置文件信息
|
||||
if plugin_info.config_file:
|
||||
config_status = "✅" if self.plugin_paths.get(plugin_name) else "❌"
|
||||
logger.info(f" ⚙️ 配置: {plugin_info.config_file} {config_status}")
|
||||
|
||||
root_path = Path(__file__)
|
||||
|
||||
# 查找项目根目录
|
||||
while not (root_path / "pyproject.toml").exists() and root_path.parent != root_path:
|
||||
root_path = root_path.parent
|
||||
|
||||
# 显示目录统计
|
||||
logger.info("📂 加载目录统计:")
|
||||
for directory in self.plugin_directories:
|
||||
if os.path.exists(directory):
|
||||
plugins_in_dir = []
|
||||
for plugin_name in self.loaded_plugins.keys():
|
||||
plugin_path = self.plugin_paths.get(plugin_name, "")
|
||||
if (
|
||||
Path(plugin_path)
|
||||
.resolve()
|
||||
.is_relative_to(Path(os.path.join(str(root_path), directory)).resolve())
|
||||
):
|
||||
plugins_in_dir.append(plugin_name)
|
||||
|
||||
if plugins_in_dir:
|
||||
logger.info(f" 📁 {directory}: {len(plugins_in_dir)}个插件 ({', '.join(plugins_in_dir)})")
|
||||
else:
|
||||
logger.info(f" 📁 {directory}: 0个插件")
|
||||
|
||||
# 失败信息
|
||||
if total_failed_registration > 0:
|
||||
logger.info(f"⚠️ 失败统计: {total_failed_registration}个插件加载失败")
|
||||
for failed_plugin, error in self.failed_plugins.items():
|
||||
logger.info(f" ❌ {failed_plugin}: {error}")
|
||||
else:
|
||||
logger.warning("😕 没有成功加载任何插件")
|
||||
|
||||
def _show_plugin_components(self, plugin_name: str) -> None:
|
||||
if plugin_info := component_registry.get_plugin_info(plugin_name):
|
||||
component_types = {}
|
||||
for comp in plugin_info.components:
|
||||
comp_type = comp.component_type.name
|
||||
component_types[comp_type] = component_types.get(comp_type, 0) + 1
|
||||
|
||||
components_str = ", ".join([f"{count}个{ctype}" for ctype, count in component_types.items()])
|
||||
|
||||
# 显示manifest信息
|
||||
manifest_info = ""
|
||||
if plugin_info.license:
|
||||
manifest_info += f" [{plugin_info.license}]"
|
||||
if plugin_info.keywords:
|
||||
manifest_info += f" 关键词: {', '.join(plugin_info.keywords[:3])}" # 只显示前3个关键词
|
||||
if len(plugin_info.keywords) > 3:
|
||||
manifest_info += "..."
|
||||
|
||||
logger.info(
|
||||
f"✅ 插件加载成功: {plugin_name} v{plugin_info.version} ({components_str}){manifest_info} - {plugin_info.description}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"✅ 插件加载成功: {plugin_name}")
|
||||
|
||||
# === 插件卸载和重载管理 ===
|
||||
|
||||
def unload_plugin(self, plugin_name: str) -> bool:
|
||||
"""卸载指定插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 卸载是否成功
|
||||
"""
|
||||
if plugin_name not in self.loaded_plugins:
|
||||
logger.warning(f"插件 {plugin_name} 未加载,无需卸载")
|
||||
return False
|
||||
|
||||
try:
|
||||
# 获取插件实例
|
||||
plugin_instance = self.loaded_plugins[plugin_name]
|
||||
|
||||
# 调用插件的清理方法(如果有的话)
|
||||
if hasattr(plugin_instance, 'on_unload'):
|
||||
plugin_instance.on_unload()
|
||||
|
||||
# 从组件注册表中移除插件的所有组件
|
||||
component_registry.unregister_plugin(plugin_name)
|
||||
|
||||
# 从已加载插件中移除
|
||||
del self.loaded_plugins[plugin_name]
|
||||
|
||||
# 从失败列表中移除(如果存在)
|
||||
if plugin_name in self.failed_plugins:
|
||||
del self.failed_plugins[plugin_name]
|
||||
|
||||
logger.info(f"✅ 插件卸载成功: {plugin_name}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件卸载失败: {plugin_name} - {str(e)}")
|
||||
return False
|
||||
|
||||
def reload_plugin(self, plugin_name: str) -> bool:
|
||||
"""重载指定插件
|
||||
|
||||
Args:
|
||||
plugin_name: 插件名称
|
||||
|
||||
Returns:
|
||||
bool: 重载是否成功
|
||||
"""
|
||||
try:
|
||||
# 先卸载插件
|
||||
if plugin_name in self.loaded_plugins:
|
||||
self.unload_plugin(plugin_name)
|
||||
|
||||
# 清除Python模块缓存
|
||||
plugin_path = self.plugin_paths.get(plugin_name)
|
||||
if plugin_path:
|
||||
plugin_file = os.path.join(plugin_path, "plugin.py")
|
||||
if os.path.exists(plugin_file):
|
||||
# 从sys.modules中移除相关模块
|
||||
modules_to_remove = []
|
||||
plugin_module_prefix = ".".join(Path(plugin_file).parent.parts)
|
||||
|
||||
for module_name in sys.modules:
|
||||
if module_name.startswith(plugin_module_prefix):
|
||||
modules_to_remove.append(module_name)
|
||||
|
||||
for module_name in modules_to_remove:
|
||||
del sys.modules[module_name]
|
||||
|
||||
# 从插件类注册表中移除
|
||||
if plugin_name in self.plugin_classes:
|
||||
del self.plugin_classes[plugin_name]
|
||||
|
||||
# 重新加载插件模块
|
||||
if self._load_plugin_module_file(plugin_file):
|
||||
# 重新加载插件实例
|
||||
success, _ = self.load_registered_plugin_classes(plugin_name)
|
||||
if success:
|
||||
logger.info(f"🔄 插件重载成功: {plugin_name}")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 实例化失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 模块加载失败")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件文件不存在")
|
||||
return False
|
||||
else:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - 插件路径未知")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 插件重载失败: {plugin_name} - {str(e)}")
|
||||
logger.debug("详细错误信息: ", exc_info=True)
|
||||
return False
|
||||
|
||||
|
||||
# 全局插件管理器实例
|
||||
plugin_manager = PluginManager()
|
||||
421
src/plugin_system/core/tool_use.py
Normal file
421
src/plugin_system/core/tool_use.py
Normal file
@@ -0,0 +1,421 @@
|
||||
import time
|
||||
from typing import List, Dict, Tuple, Optional, Any
|
||||
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
|
||||
from src.plugin_system.base.base_tool import BaseTool
|
||||
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.llm_models.payload_content import ToolCall
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_use")
|
||||
|
||||
|
||||
def init_tool_executor_prompt():
|
||||
"""初始化工具执行器的提示词"""
|
||||
tool_executor_prompt = """
|
||||
你是一个专门执行工具的助手。你的名字是{bot_name}。现在是{time_now}。
|
||||
群里正在进行的聊天内容:
|
||||
{chat_history}
|
||||
|
||||
现在,{sender}发送了内容:{target_message},你想要回复ta。
|
||||
请仔细分析聊天内容,考虑以下几点:
|
||||
1. 内容中是否包含需要查询信息的问题
|
||||
2. 是否有明确的工具使用指令
|
||||
|
||||
If you need to use a tool, please directly call the corresponding tool function. If you do not need to use any tool, simply output "No tool needed".
|
||||
"""
|
||||
Prompt(tool_executor_prompt, "tool_executor_prompt")
|
||||
|
||||
|
||||
# 初始化提示词
|
||||
init_tool_executor_prompt()
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
"""独立的工具执行器组件
|
||||
|
||||
可以直接输入聊天消息内容,自动判断并执行相应的工具,返回结构化的工具执行结果。
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, enable_cache: bool = True, cache_ttl: int = 3):
|
||||
"""初始化工具执行器
|
||||
|
||||
Args:
|
||||
executor_id: 执行器标识符,用于日志记录
|
||||
enable_cache: 是否启用缓存机制
|
||||
cache_ttl: 缓存生存时间(周期数)
|
||||
"""
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
|
||||
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
|
||||
|
||||
self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
|
||||
|
||||
# 缓存配置
|
||||
self.enable_cache = enable_cache
|
||||
self.cache_ttl = cache_ttl
|
||||
self.tool_cache = {} # 格式: {cache_key: {"result": result, "ttl": ttl, "timestamp": timestamp}}
|
||||
|
||||
logger.info(f"{self.log_prefix}工具执行器初始化完成,缓存{'启用' if enable_cache else '禁用'},TTL={cache_ttl}")
|
||||
|
||||
async def execute_from_chat_message(
|
||||
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
|
||||
) -> Tuple[List[Dict[str, Any]], List[str], str]:
|
||||
"""从聊天消息执行工具
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
return_details: 是否返回详细信息(使用的工具列表和提示词)
|
||||
|
||||
Returns:
|
||||
如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
|
||||
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
|
||||
"""
|
||||
|
||||
# 首先检查缓存
|
||||
cache_key = self._generate_cache_key(target_message, chat_history, sender)
|
||||
if cached_result := self._get_from_cache(cache_key):
|
||||
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
|
||||
if not return_details:
|
||||
return cached_result, [], ""
|
||||
|
||||
# 从缓存结果中提取工具名称
|
||||
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
|
||||
return cached_result, used_tools, ""
|
||||
|
||||
# 缓存未命中,执行工具调用
|
||||
# 获取可用工具
|
||||
tools = self._get_tool_definitions()
|
||||
|
||||
# 获取当前时间
|
||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
|
||||
# 构建工具调用提示词
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"tool_executor_prompt",
|
||||
target_message=target_message,
|
||||
chat_history=chat_history,
|
||||
sender=sender,
|
||||
bot_name=bot_name,
|
||||
time_now=time_now,
|
||||
)
|
||||
|
||||
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
|
||||
|
||||
# 调用LLM进行工具决策
|
||||
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
|
||||
prompt=prompt, tools=tools, raise_when_empty=False
|
||||
)
|
||||
|
||||
# 执行工具调用
|
||||
tool_results, used_tools = await self.execute_tool_calls(tool_calls)
|
||||
|
||||
# 缓存结果
|
||||
if tool_results:
|
||||
self._set_cache(cache_key, tool_results)
|
||||
|
||||
if used_tools:
|
||||
logger.info(f"{self.log_prefix}工具执行完成,共执行{len(used_tools)}个工具: {used_tools}")
|
||||
|
||||
if return_details:
|
||||
return tool_results, used_tools, prompt
|
||||
else:
|
||||
return tool_results, [], ""
|
||||
|
||||
def _get_tool_definitions(self) -> List[Dict[str, Any]]:
|
||||
all_tools = get_llm_available_tool_definitions()
|
||||
user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
|
||||
return [definition for name, definition in all_tools if name not in user_disabled_tools]
|
||||
|
||||
async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
|
||||
"""执行工具调用
|
||||
|
||||
Args:
|
||||
tool_calls: LLM返回的工具调用列表
|
||||
|
||||
Returns:
|
||||
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
|
||||
"""
|
||||
tool_results: List[Dict[str, Any]] = []
|
||||
used_tools = []
|
||||
|
||||
if not tool_calls:
|
||||
logger.debug(f"{self.log_prefix}无需执行工具")
|
||||
return [], []
|
||||
|
||||
# 提取tool_calls中的函数名称
|
||||
func_names = [call.func_name for call in tool_calls if call.func_name]
|
||||
|
||||
logger.info(f"{self.log_prefix}开始执行工具调用: {func_names}")
|
||||
|
||||
# 执行每个工具调用
|
||||
for tool_call in tool_calls:
|
||||
try:
|
||||
tool_name = tool_call.func_name
|
||||
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
|
||||
|
||||
# 执行工具
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"tool_exec_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
content = tool_info["content"]
|
||||
if not isinstance(content, (str, list, tuple)):
|
||||
tool_info["content"] = str(content)
|
||||
|
||||
tool_results.append(tool_info)
|
||||
used_tools.append(tool_name)
|
||||
logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
|
||||
preview = content[:200]
|
||||
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
|
||||
# 添加错误信息到结果中
|
||||
error_info = {
|
||||
"type": "tool_error",
|
||||
"id": f"tool_error_{time.time()}",
|
||||
"content": f"工具{tool_name}执行失败: {str(e)}",
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
tool_results.append(error_info)
|
||||
|
||||
return tool_results, used_tools
|
||||
|
||||
async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-assigned-variable
|
||||
"""执行单个工具调用
|
||||
|
||||
Args:
|
||||
tool_call: 工具调用对象
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具调用结果,如果失败则返回None
|
||||
"""
|
||||
try:
|
||||
function_name = tool_call.func_name
|
||||
function_args = tool_call.args or {}
|
||||
function_args["llm_called"] = True # 标记为LLM调用
|
||||
|
||||
# 获取对应工具实例
|
||||
tool_instance = tool_instance or get_tool_instance(function_name)
|
||||
if not tool_instance:
|
||||
logger.warning(f"未知工具名称: {function_name}")
|
||||
return None
|
||||
|
||||
# 执行工具
|
||||
result = await tool_instance.execute(function_args)
|
||||
if result:
|
||||
return {
|
||||
"tool_call_id": tool_call.call_id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"type": "function",
|
||||
"content": result["content"],
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"执行工具调用时发生错误: {str(e)}")
|
||||
raise e
|
||||
|
||||
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
|
||||
"""生成缓存键
|
||||
|
||||
Args:
|
||||
target_message: 目标消息内容
|
||||
chat_history: 聊天历史
|
||||
sender: 发送者
|
||||
|
||||
Returns:
|
||||
str: 缓存键
|
||||
"""
|
||||
import hashlib
|
||||
|
||||
# 使用消息内容和群聊状态生成唯一缓存键
|
||||
content = f"{target_message}_{chat_history}_{sender}"
|
||||
return hashlib.md5(content.encode()).hexdigest()
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[List[Dict]]:
|
||||
"""从缓存获取结果
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
|
||||
Returns:
|
||||
Optional[List[Dict]]: 缓存的结果,如果不存在或过期则返回None
|
||||
"""
|
||||
if not self.enable_cache or cache_key not in self.tool_cache:
|
||||
return None
|
||||
|
||||
cache_item = self.tool_cache[cache_key]
|
||||
if cache_item["ttl"] <= 0:
|
||||
# 缓存过期,删除
|
||||
del self.tool_cache[cache_key]
|
||||
logger.debug(f"{self.log_prefix}缓存过期,删除缓存键: {cache_key}")
|
||||
return None
|
||||
|
||||
# 减少TTL
|
||||
cache_item["ttl"] -= 1
|
||||
logger.debug(f"{self.log_prefix}使用缓存结果,剩余TTL: {cache_item['ttl']}")
|
||||
return cache_item["result"]
|
||||
|
||||
def _set_cache(self, cache_key: str, result: List[Dict]):
|
||||
"""设置缓存
|
||||
|
||||
Args:
|
||||
cache_key: 缓存键
|
||||
result: 要缓存的结果
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
self.tool_cache[cache_key] = {"result": result, "ttl": self.cache_ttl, "timestamp": time.time()}
|
||||
logger.debug(f"{self.log_prefix}设置缓存,TTL: {self.cache_ttl}")
|
||||
|
||||
def _cleanup_expired_cache(self):
|
||||
"""清理过期的缓存"""
|
||||
if not self.enable_cache:
|
||||
return
|
||||
|
||||
expired_keys = []
|
||||
expired_keys.extend(cache_key for cache_key, cache_item in self.tool_cache.items() if cache_item["ttl"] <= 0)
|
||||
for key in expired_keys:
|
||||
del self.tool_cache[key]
|
||||
|
||||
if expired_keys:
|
||||
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
|
||||
|
||||
async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
|
||||
"""直接执行指定工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
tool_args: 工具参数
|
||||
validate_args: 是否验证参数
|
||||
|
||||
Returns:
|
||||
Optional[Dict]: 工具执行结果,失败时返回None
|
||||
"""
|
||||
try:
|
||||
tool_call = ToolCall(
|
||||
call_id=f"direct_tool_{time.time()}",
|
||||
func_name=tool_name,
|
||||
args=tool_args,
|
||||
)
|
||||
|
||||
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
|
||||
|
||||
result = await self.execute_tool_call(tool_call)
|
||||
|
||||
if result:
|
||||
tool_info = {
|
||||
"type": result.get("type", "unknown_type"),
|
||||
"id": result.get("id", f"direct_tool_{time.time()}"),
|
||||
"content": result.get("content", ""),
|
||||
"tool_name": tool_name,
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
logger.info(f"{self.log_prefix}直接工具执行成功: {tool_name}")
|
||||
return tool_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix}直接工具执行失败 {tool_name}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空所有缓存"""
|
||||
if self.enable_cache:
|
||||
cache_count = len(self.tool_cache)
|
||||
self.tool_cache.clear()
|
||||
logger.info(f"{self.log_prefix}清空了{cache_count}个缓存项")
|
||||
|
||||
def get_cache_status(self) -> Dict:
|
||||
"""获取缓存状态信息
|
||||
|
||||
Returns:
|
||||
Dict: 包含缓存统计信息的字典
|
||||
"""
|
||||
if not self.enable_cache:
|
||||
return {"enabled": False, "cache_count": 0}
|
||||
|
||||
# 清理过期缓存
|
||||
self._cleanup_expired_cache()
|
||||
|
||||
total_count = len(self.tool_cache)
|
||||
ttl_distribution = {}
|
||||
|
||||
for cache_item in self.tool_cache.values():
|
||||
ttl = cache_item["ttl"]
|
||||
ttl_distribution[ttl] = ttl_distribution.get(ttl, 0) + 1
|
||||
|
||||
return {
|
||||
"enabled": True,
|
||||
"cache_count": total_count,
|
||||
"cache_ttl": self.cache_ttl,
|
||||
"ttl_distribution": ttl_distribution,
|
||||
}
|
||||
|
||||
def set_cache_config(self, enable_cache: Optional[bool] = None, cache_ttl: int = -1):
|
||||
"""动态修改缓存配置
|
||||
|
||||
Args:
|
||||
enable_cache: 是否启用缓存
|
||||
cache_ttl: 缓存TTL
|
||||
"""
|
||||
if enable_cache is not None:
|
||||
self.enable_cache = enable_cache
|
||||
logger.info(f"{self.log_prefix}缓存状态修改为: {'启用' if enable_cache else '禁用'}")
|
||||
|
||||
if cache_ttl > 0:
|
||||
self.cache_ttl = cache_ttl
|
||||
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
|
||||
|
||||
|
||||
"""
|
||||
ToolExecutor使用示例:
|
||||
|
||||
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
|
||||
executor = ToolExecutor(executor_id="my_executor")
|
||||
results, _, _ = await executor.execute_from_chat_message(
|
||||
talking_message_str="今天天气怎么样?现在几点了?",
|
||||
is_group_chat=False
|
||||
)
|
||||
|
||||
# 2. 禁用缓存的执行器
|
||||
no_cache_executor = ToolExecutor(executor_id="no_cache", enable_cache=False)
|
||||
|
||||
# 3. 自定义缓存TTL
|
||||
long_cache_executor = ToolExecutor(executor_id="long_cache", cache_ttl=10)
|
||||
|
||||
# 4. 获取详细信息
|
||||
results, used_tools, prompt = await executor.execute_from_chat_message(
|
||||
talking_message_str="帮我查询Python相关知识",
|
||||
is_group_chat=False,
|
||||
return_details=True
|
||||
)
|
||||
|
||||
# 5. 直接执行特定工具
|
||||
result = await executor.execute_specific_tool_simple(
|
||||
tool_name="get_knowledge",
|
||||
tool_args={"query": "机器学习"}
|
||||
)
|
||||
|
||||
# 6. 缓存管理
|
||||
cache_status = executor.get_cache_status() # 查看缓存状态
|
||||
executor.clear_cache() # 清空缓存
|
||||
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
|
||||
"""
|
||||
19
src/plugin_system/utils/__init__.py
Normal file
19
src/plugin_system/utils/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
插件系统工具模块
|
||||
|
||||
提供插件开发和管理的实用工具
|
||||
"""
|
||||
|
||||
from .manifest_utils import (
|
||||
ManifestValidator,
|
||||
# ManifestGenerator,
|
||||
# validate_plugin_manifest,
|
||||
# generate_plugin_manifest,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ManifestValidator",
|
||||
# "ManifestGenerator",
|
||||
# "validate_plugin_manifest",
|
||||
# "generate_plugin_manifest",
|
||||
]
|
||||
515
src/plugin_system/utils/manifest_utils.py
Normal file
515
src/plugin_system/utils/manifest_utils.py
Normal file
@@ -0,0 +1,515 @@
|
||||
"""
|
||||
插件Manifest工具模块
|
||||
|
||||
提供manifest文件的验证、生成和管理功能
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, Tuple
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import MMC_VERSION
|
||||
|
||||
# if TYPE_CHECKING:
|
||||
# from src.plugin_system.base.base_plugin import BasePlugin
|
||||
|
||||
logger = get_logger("manifest_utils")
|
||||
|
||||
|
||||
class VersionComparator:
|
||||
"""版本号比较器
|
||||
|
||||
支持语义化版本号比较,自动处理snapshot版本,并支持向前兼容性检查
|
||||
"""
|
||||
|
||||
# 版本兼容性映射表(硬编码)
|
||||
# 格式: {插件最大支持版本: [实际兼容的版本列表]}
|
||||
COMPATIBILITY_MAP = {
|
||||
# 0.8.x 系列向前兼容规则
|
||||
"0.8.0": ["0.8.1", "0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.1": ["0.8.2", "0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.2": ["0.8.3", "0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.3": ["0.8.4", "0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.4": ["0.8.5", "0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.5": ["0.8.6", "0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.6": ["0.8.7", "0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.7": ["0.8.8", "0.8.9", "0.8.10"],
|
||||
"0.8.8": ["0.8.9", "0.8.10"],
|
||||
"0.8.9": ["0.8.10"],
|
||||
# 可以根据需要添加更多兼容映射
|
||||
# "0.9.0": ["0.9.1", "0.9.2", "0.9.3"], # 示例:0.9.x系列兼容
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def normalize_version(version: str) -> str:
|
||||
"""标准化版本号,移除snapshot标识
|
||||
|
||||
Args:
|
||||
version: 原始版本号,如 "0.8.0-snapshot.1"
|
||||
|
||||
Returns:
|
||||
str: 标准化后的版本号,如 "0.8.0"
|
||||
"""
|
||||
if not version:
|
||||
return "0.0.0"
|
||||
|
||||
# 移除snapshot部分
|
||||
normalized = re.sub(r"-snapshot\.\d+", "", version.strip())
|
||||
|
||||
# 确保版本号格式正确
|
||||
if not re.match(r"^\d+(\.\d+){0,2}$", normalized):
|
||||
# 如果不是有效的版本号格式,返回默认版本
|
||||
return "0.0.0"
|
||||
|
||||
# 尝试补全版本号
|
||||
parts = normalized.split(".")
|
||||
while len(parts) < 3:
|
||||
parts.append("0")
|
||||
normalized = ".".join(parts[:3])
|
||||
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def parse_version(version: str) -> Tuple[int, int, int]:
|
||||
"""解析版本号为元组
|
||||
|
||||
Args:
|
||||
version: 版本号字符串
|
||||
|
||||
Returns:
|
||||
Tuple[int, int, int]: (major, minor, patch)
|
||||
"""
|
||||
normalized = VersionComparator.normalize_version(version)
|
||||
try:
|
||||
parts = normalized.split(".")
|
||||
return (int(parts[0]), int(parts[1]), int(parts[2]))
|
||||
except (ValueError, IndexError):
|
||||
logger.warning(f"无法解析版本号: {version},使用默认版本 0.0.0")
|
||||
return (0, 0, 0)
|
||||
|
||||
@staticmethod
|
||||
def compare_versions(version1: str, version2: str) -> int:
|
||||
"""比较两个版本号
|
||||
|
||||
Args:
|
||||
version1: 第一个版本号
|
||||
version2: 第二个版本号
|
||||
|
||||
Returns:
|
||||
int: -1 if version1 < version2, 0 if equal, 1 if version1 > version2
|
||||
"""
|
||||
v1_tuple = VersionComparator.parse_version(version1)
|
||||
v2_tuple = VersionComparator.parse_version(version2)
|
||||
|
||||
if v1_tuple < v2_tuple:
|
||||
return -1
|
||||
elif v1_tuple > v2_tuple:
|
||||
return 1
|
||||
else:
|
||||
return 0
|
||||
|
||||
@staticmethod
|
||||
def check_forward_compatibility(current_version: str, max_version: str) -> Tuple[bool, str]:
|
||||
"""检查向前兼容性(仅使用兼容性映射表)
|
||||
|
||||
Args:
|
||||
current_version: 当前版本
|
||||
max_version: 插件声明的最大支持版本
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否兼容, 兼容信息)
|
||||
"""
|
||||
current_normalized = VersionComparator.normalize_version(current_version)
|
||||
max_normalized = VersionComparator.normalize_version(max_version)
|
||||
|
||||
# 检查兼容性映射表
|
||||
if max_normalized in VersionComparator.COMPATIBILITY_MAP:
|
||||
compatible_versions = VersionComparator.COMPATIBILITY_MAP[max_normalized]
|
||||
if current_normalized in compatible_versions:
|
||||
return True, f"根据兼容性映射表,版本 {current_normalized} 与 {max_normalized} 兼容"
|
||||
|
||||
return False, ""
|
||||
|
||||
@staticmethod
|
||||
def is_version_in_range(version: str, min_version: str = "", max_version: str = "") -> Tuple[bool, str]:
|
||||
"""检查版本是否在指定范围内,支持兼容性检查
|
||||
|
||||
Args:
|
||||
version: 要检查的版本号
|
||||
min_version: 最小版本号(可选)
|
||||
max_version: 最大版本号(可选)
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否兼容, 错误信息或兼容信息)
|
||||
"""
|
||||
if not min_version and not max_version:
|
||||
return True, ""
|
||||
|
||||
version_normalized = VersionComparator.normalize_version(version)
|
||||
|
||||
# 检查最小版本
|
||||
if min_version:
|
||||
min_normalized = VersionComparator.normalize_version(min_version)
|
||||
if VersionComparator.compare_versions(version_normalized, min_normalized) < 0:
|
||||
return False, f"版本 {version_normalized} 低于最小要求版本 {min_normalized}"
|
||||
|
||||
# 检查最大版本
|
||||
if max_version:
|
||||
max_normalized = VersionComparator.normalize_version(max_version)
|
||||
comparison = VersionComparator.compare_versions(version_normalized, max_normalized)
|
||||
|
||||
if comparison > 0:
|
||||
# 严格版本检查失败,尝试兼容性检查
|
||||
is_compatible, compat_msg = VersionComparator.check_forward_compatibility(
|
||||
version_normalized, max_normalized
|
||||
)
|
||||
|
||||
if not is_compatible:
|
||||
return False, f"版本 {version_normalized} 高于最大支持版本 {max_normalized},且无兼容性映射"
|
||||
|
||||
logger.info(f"版本兼容性检查:{compat_msg}")
|
||||
return True, compat_msg
|
||||
return True, ""
|
||||
|
||||
@staticmethod
|
||||
def get_current_host_version() -> str:
|
||||
"""获取当前主机应用版本
|
||||
|
||||
Returns:
|
||||
str: 当前版本号
|
||||
"""
|
||||
return VersionComparator.normalize_version(MMC_VERSION)
|
||||
|
||||
@staticmethod
|
||||
def add_compatibility_mapping(base_version: str, compatible_versions: list) -> None:
|
||||
"""动态添加兼容性映射
|
||||
|
||||
Args:
|
||||
base_version: 基础版本(插件声明的最大支持版本)
|
||||
compatible_versions: 兼容的版本列表
|
||||
"""
|
||||
base_normalized = VersionComparator.normalize_version(base_version)
|
||||
VersionComparator.COMPATIBILITY_MAP[base_normalized] = [
|
||||
VersionComparator.normalize_version(v) for v in compatible_versions
|
||||
]
|
||||
logger.info(f"添加兼容性映射:{base_normalized} -> {compatible_versions}")
|
||||
|
||||
@staticmethod
|
||||
def get_compatibility_info() -> Dict[str, list]:
|
||||
"""获取当前的兼容性映射表
|
||||
|
||||
Returns:
|
||||
Dict[str, list]: 兼容性映射表的副本
|
||||
"""
|
||||
return VersionComparator.COMPATIBILITY_MAP.copy()
|
||||
|
||||
|
||||
class ManifestValidator:
|
||||
"""Manifest文件验证器"""
|
||||
|
||||
# 必需字段(必须存在且不能为空)
|
||||
REQUIRED_FIELDS = ["manifest_version", "name", "version", "description", "author"]
|
||||
|
||||
# 可选字段(可以不存在或为空)
|
||||
OPTIONAL_FIELDS = [
|
||||
"license",
|
||||
"host_application",
|
||||
"homepage_url",
|
||||
"repository_url",
|
||||
"keywords",
|
||||
"categories",
|
||||
"default_locale",
|
||||
"locales_path",
|
||||
"plugin_info",
|
||||
]
|
||||
|
||||
# 建议填写的字段(会给出警告但不会导致验证失败)
|
||||
RECOMMENDED_FIELDS = ["license", "keywords", "categories"]
|
||||
|
||||
SUPPORTED_MANIFEST_VERSIONS = [1]
|
||||
|
||||
def __init__(self):
|
||||
self.validation_errors = []
|
||||
self.validation_warnings = []
|
||||
|
||||
def validate_manifest(self, manifest_data: Dict[str, Any]) -> bool:
|
||||
"""验证manifest数据
|
||||
|
||||
Args:
|
||||
manifest_data: manifest数据字典
|
||||
|
||||
Returns:
|
||||
bool: 是否验证通过(只有错误会导致验证失败,警告不会)
|
||||
"""
|
||||
self.validation_errors.clear()
|
||||
self.validation_warnings.clear()
|
||||
|
||||
# 检查必需字段
|
||||
for field in self.REQUIRED_FIELDS:
|
||||
if field not in manifest_data:
|
||||
self.validation_errors.append(f"缺少必需字段: {field}")
|
||||
elif not manifest_data[field]:
|
||||
self.validation_errors.append(f"必需字段不能为空: {field}")
|
||||
|
||||
# 检查manifest版本
|
||||
if "manifest_version" in manifest_data:
|
||||
version = manifest_data["manifest_version"]
|
||||
if version not in self.SUPPORTED_MANIFEST_VERSIONS:
|
||||
self.validation_errors.append(
|
||||
f"不支持的manifest版本: {version},支持的版本: {self.SUPPORTED_MANIFEST_VERSIONS}"
|
||||
)
|
||||
|
||||
# 检查作者信息格式
|
||||
if "author" in manifest_data:
|
||||
author = manifest_data["author"]
|
||||
if isinstance(author, dict):
|
||||
if "name" not in author or not author["name"]:
|
||||
self.validation_errors.append("作者信息缺少name字段或为空")
|
||||
# url字段是可选的
|
||||
if "url" in author and author["url"]:
|
||||
url = author["url"]
|
||||
if not (url.startswith("http://") or url.startswith("https://")):
|
||||
self.validation_warnings.append("作者URL建议使用完整的URL格式")
|
||||
elif isinstance(author, str):
|
||||
if not author.strip():
|
||||
self.validation_errors.append("作者信息不能为空")
|
||||
else:
|
||||
self.validation_errors.append("作者信息格式错误,应为字符串或包含name字段的对象")
|
||||
# 检查主机应用版本要求(可选)
|
||||
if "host_application" in manifest_data:
|
||||
host_app = manifest_data["host_application"]
|
||||
if isinstance(host_app, dict):
|
||||
min_version = host_app.get("min_version", "")
|
||||
max_version = host_app.get("max_version", "")
|
||||
|
||||
# 验证版本字段格式
|
||||
for version_field in ["min_version", "max_version"]:
|
||||
if version_field in host_app and not host_app[version_field]:
|
||||
self.validation_warnings.append(f"host_application.{version_field}为空")
|
||||
|
||||
# 检查当前主机版本兼容性
|
||||
if min_version or max_version:
|
||||
current_version = VersionComparator.get_current_host_version()
|
||||
is_compatible, error_msg = VersionComparator.is_version_in_range(
|
||||
current_version, min_version, max_version
|
||||
)
|
||||
|
||||
if not is_compatible:
|
||||
self.validation_errors.append(f"版本兼容性检查失败: {error_msg} (当前版本: {current_version})")
|
||||
else:
|
||||
logger.debug(
|
||||
f"版本兼容性检查通过: 当前版本 {current_version} 符合要求 [{min_version}, {max_version}]"
|
||||
)
|
||||
else:
|
||||
self.validation_errors.append("host_application格式错误,应为对象")
|
||||
|
||||
# 检查URL格式(可选字段)
|
||||
for url_field in ["homepage_url", "repository_url"]:
|
||||
if url_field in manifest_data and manifest_data[url_field]:
|
||||
url: str = manifest_data[url_field]
|
||||
if not (url.startswith("http://") or url.startswith("https://")):
|
||||
self.validation_warnings.append(f"{url_field}建议使用完整的URL格式")
|
||||
|
||||
# 检查数组字段格式(可选字段)
|
||||
for list_field in ["keywords", "categories"]:
|
||||
if list_field in manifest_data:
|
||||
field_value = manifest_data[list_field]
|
||||
if field_value is not None and not isinstance(field_value, list):
|
||||
self.validation_errors.append(f"{list_field}应为数组格式")
|
||||
elif isinstance(field_value, list):
|
||||
# 检查数组元素是否为字符串
|
||||
for i, item in enumerate(field_value):
|
||||
if not isinstance(item, str):
|
||||
self.validation_warnings.append(f"{list_field}[{i}]应为字符串")
|
||||
|
||||
# 检查建议字段(给出警告)
|
||||
for field in self.RECOMMENDED_FIELDS:
|
||||
if field not in manifest_data or not manifest_data[field]:
|
||||
self.validation_warnings.append(f"建议填写字段: {field}")
|
||||
|
||||
# 检查plugin_info结构(可选)
|
||||
if "plugin_info" in manifest_data:
|
||||
plugin_info = manifest_data["plugin_info"]
|
||||
if isinstance(plugin_info, dict):
|
||||
# 检查components数组
|
||||
if "components" in plugin_info:
|
||||
components = plugin_info["components"]
|
||||
if not isinstance(components, list):
|
||||
self.validation_errors.append("plugin_info.components应为数组格式")
|
||||
else:
|
||||
for i, component in enumerate(components):
|
||||
if not isinstance(component, dict):
|
||||
self.validation_errors.append(f"plugin_info.components[{i}]应为对象")
|
||||
else:
|
||||
# 检查组件必需字段
|
||||
for comp_field in ["type", "name", "description"]:
|
||||
if comp_field not in component or not component[comp_field]:
|
||||
self.validation_errors.append(
|
||||
f"plugin_info.components[{i}]缺少必需字段: {comp_field}"
|
||||
)
|
||||
else:
|
||||
self.validation_errors.append("plugin_info应为对象格式")
|
||||
|
||||
return len(self.validation_errors) == 0
|
||||
|
||||
def get_validation_report(self) -> str:
|
||||
"""获取验证报告"""
|
||||
report = []
|
||||
|
||||
if self.validation_errors:
|
||||
report.append("❌ 验证错误:")
|
||||
report.extend(f" - {error}" for error in self.validation_errors)
|
||||
if self.validation_warnings:
|
||||
report.append("⚠️ 验证警告:")
|
||||
report.extend(f" - {warning}" for warning in self.validation_warnings)
|
||||
if not self.validation_errors and not self.validation_warnings:
|
||||
report.append("✅ Manifest文件验证通过")
|
||||
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
# class ManifestGenerator:
|
||||
# """Manifest文件生成器"""
|
||||
|
||||
# def __init__(self):
|
||||
# self.template = {
|
||||
# "manifest_version": 1,
|
||||
# "name": "",
|
||||
# "version": "1.0.0",
|
||||
# "description": "",
|
||||
# "author": {"name": "", "url": ""},
|
||||
# "license": "MIT",
|
||||
# "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
|
||||
# "homepage_url": "",
|
||||
# "repository_url": "",
|
||||
# "keywords": [],
|
||||
# "categories": [],
|
||||
# "default_locale": "zh-CN",
|
||||
# "locales_path": "_locales",
|
||||
# }
|
||||
|
||||
# def generate_from_plugin(self, plugin_instance: BasePlugin) -> Dict[str, Any]:
|
||||
# """从插件实例生成manifest
|
||||
|
||||
# Args:
|
||||
# plugin_instance: BasePlugin实例
|
||||
|
||||
# Returns:
|
||||
# Dict[str, Any]: 生成的manifest数据
|
||||
# """
|
||||
# manifest = self.template.copy()
|
||||
|
||||
# # 基本信息
|
||||
# manifest["name"] = plugin_instance.plugin_name
|
||||
# manifest["version"] = plugin_instance.plugin_version
|
||||
# manifest["description"] = plugin_instance.plugin_description
|
||||
|
||||
# # 作者信息
|
||||
# if plugin_instance.plugin_author:
|
||||
# manifest["author"]["name"] = plugin_instance.plugin_author
|
||||
|
||||
# # 组件信息
|
||||
# components = []
|
||||
# plugin_components = plugin_instance.get_plugin_components()
|
||||
|
||||
# for component_info, component_class in plugin_components:
|
||||
# component_data: Dict[str, Any] = {
|
||||
# "type": component_info.component_type.value,
|
||||
# "name": component_info.name,
|
||||
# "description": component_info.description,
|
||||
# }
|
||||
|
||||
# # 添加激活模式信息(对于Action组件)
|
||||
# if hasattr(component_class, "focus_activation_type"):
|
||||
# activation_modes = []
|
||||
# if hasattr(component_class, "focus_activation_type"):
|
||||
# activation_modes.append(component_class.focus_activation_type.value)
|
||||
# if hasattr(component_class, "normal_activation_type"):
|
||||
# activation_modes.append(component_class.normal_activation_type.value)
|
||||
# component_data["activation_modes"] = list(set(activation_modes))
|
||||
|
||||
# # 添加关键词信息
|
||||
# if hasattr(component_class, "activation_keywords"):
|
||||
# keywords = getattr(component_class, "activation_keywords", [])
|
||||
# if keywords:
|
||||
# component_data["keywords"] = keywords
|
||||
|
||||
# components.append(component_data)
|
||||
|
||||
# manifest["plugin_info"] = {"is_built_in": True, "plugin_type": "general", "components": components}
|
||||
|
||||
# return manifest
|
||||
|
||||
# def save_manifest(self, manifest_data: Dict[str, Any], plugin_dir: str) -> bool:
|
||||
# """保存manifest文件
|
||||
|
||||
# Args:
|
||||
# manifest_data: manifest数据
|
||||
# plugin_dir: 插件目录
|
||||
|
||||
# Returns:
|
||||
# bool: 是否保存成功
|
||||
# """
|
||||
# try:
|
||||
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
||||
# with open(manifest_path, "w", encoding="utf-8") as f:
|
||||
# json.dump(manifest_data, f, ensure_ascii=False, indent=2)
|
||||
# logger.info(f"Manifest文件已保存: {manifest_path}")
|
||||
# return True
|
||||
# except Exception as e:
|
||||
# logger.error(f"保存manifest文件失败: {e}")
|
||||
# return False
|
||||
|
||||
|
||||
# def validate_plugin_manifest(plugin_dir: str) -> bool:
|
||||
# """验证插件目录中的manifest文件
|
||||
|
||||
# Args:
|
||||
# plugin_dir: 插件目录路径
|
||||
|
||||
# Returns:
|
||||
# bool: 是否验证通过
|
||||
# """
|
||||
# manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
||||
|
||||
# if not os.path.exists(manifest_path):
|
||||
# logger.warning(f"未找到manifest文件: {manifest_path}")
|
||||
# return False
|
||||
|
||||
# try:
|
||||
# with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
# manifest_data = json.load(f)
|
||||
|
||||
# validator = ManifestValidator()
|
||||
# is_valid = validator.validate_manifest(manifest_data)
|
||||
|
||||
# logger.info(f"Manifest验证结果:\n{validator.get_validation_report()}")
|
||||
|
||||
# return is_valid
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"读取或验证manifest文件失败: {e}")
|
||||
# return False
|
||||
|
||||
|
||||
# def generate_plugin_manifest(plugin_instance: BasePlugin, save_to_file: bool = True) -> Optional[Dict[str, Any]]:
|
||||
# """为插件生成manifest文件
|
||||
|
||||
# Args:
|
||||
# plugin_instance: BasePlugin实例
|
||||
# save_to_file: 是否保存到文件
|
||||
|
||||
# Returns:
|
||||
# Optional[Dict[str, Any]]: 生成的manifest数据
|
||||
# """
|
||||
# try:
|
||||
# generator = ManifestGenerator()
|
||||
# manifest_data = generator.generate_from_plugin(plugin_instance)
|
||||
|
||||
# if save_to_file and plugin_instance.plugin_dir:
|
||||
# generator.save_manifest(manifest_data, plugin_instance.plugin_dir)
|
||||
|
||||
# return manifest_data
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"生成manifest文件失败: {e}")
|
||||
# return None
|
||||
Reference in New Issue
Block a user