初始化
This commit is contained in:
49
src/plugin_system/base/__init__.py
Normal file
49
src/plugin_system/base/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
插件基础类模块
|
||||
|
||||
提供插件开发的基础类和类型定义
|
||||
"""
|
||||
|
||||
from .base_plugin import BasePlugin
|
||||
from .base_action import BaseAction
|
||||
from .base_tool import BaseTool
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .component_types import (
|
||||
ComponentType,
|
||||
ActionActivationType,
|
||||
ChatMode,
|
||||
ComponentInfo,
|
||||
ActionInfo,
|
||||
CommandInfo,
|
||||
ToolInfo,
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
EventHandlerInfo,
|
||||
EventType,
|
||||
MaiMessages,
|
||||
ToolParamType,
|
||||
)
|
||||
from .config_types import ConfigField
|
||||
|
||||
__all__ = [
|
||||
"BasePlugin",
|
||||
"BaseAction",
|
||||
"BaseCommand",
|
||||
"BaseTool",
|
||||
"ComponentType",
|
||||
"ActionActivationType",
|
||||
"ChatMode",
|
||||
"ComponentInfo",
|
||||
"ActionInfo",
|
||||
"CommandInfo",
|
||||
"ToolInfo",
|
||||
"PluginInfo",
|
||||
"PythonDependency",
|
||||
"ConfigField",
|
||||
"EventHandlerInfo",
|
||||
"EventType",
|
||||
"BaseEventHandler",
|
||||
"MaiMessages",
|
||||
"ToolParamType",
|
||||
]
|
||||
437
src/plugin_system/base/base_action.py
Normal file
437
src/plugin_system/base/base_action.py
Normal file
@@ -0,0 +1,437 @@
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
|
||||
from src.plugin_system.apis import send_api, database_api, message_api
|
||||
|
||||
|
||||
logger = get_logger("base_action")
|
||||
|
||||
|
||||
class BaseAction(ABC):
|
||||
"""Action组件基类
|
||||
|
||||
Action是插件的一种组件类型,用于处理聊天中的动作逻辑
|
||||
|
||||
子类可以通过类属性定义激活条件,这些会在实例化时转换为实例属性:
|
||||
- focus_activation_type: 专注模式激活类型
|
||||
- normal_activation_type: 普通模式激活类型
|
||||
- activation_keywords: 激活关键词列表
|
||||
- keyword_case_sensitive: 关键词是否区分大小写
|
||||
- mode_enable: 启用的聊天模式
|
||||
- parallel_action: 是否允许并行执行
|
||||
- random_activation_probability: 随机激活概率
|
||||
- llm_judge_prompt: LLM判断提示词
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
action_data: dict,
|
||||
reasoning: str,
|
||||
cycle_timers: dict,
|
||||
thinking_id: str,
|
||||
chat_stream: ChatStream,
|
||||
log_prefix: str = "",
|
||||
plugin_config: Optional[dict] = None,
|
||||
action_message: Optional[dict] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
|
||||
"""初始化Action组件
|
||||
|
||||
Args:
|
||||
action_data: 动作数据
|
||||
reasoning: 执行该动作的理由
|
||||
cycle_timers: 计时器字典
|
||||
thinking_id: 思考ID
|
||||
chat_stream: 聊天流对象
|
||||
log_prefix: 日志前缀
|
||||
plugin_config: 插件配置字典
|
||||
action_message: 消息数据
|
||||
**kwargs: 其他参数
|
||||
"""
|
||||
if plugin_config is None:
|
||||
plugin_config = {}
|
||||
self.action_data = action_data
|
||||
self.reasoning = reasoning
|
||||
self.cycle_timers = cycle_timers
|
||||
self.thinking_id = thinking_id
|
||||
self.log_prefix = log_prefix
|
||||
|
||||
self.plugin_config = plugin_config or {}
|
||||
"""对应的插件配置"""
|
||||
|
||||
# 设置动作基本信息实例属性
|
||||
self.action_name: str = getattr(self, "action_name", self.__class__.__name__.lower().replace("action", ""))
|
||||
"""Action的名字"""
|
||||
self.action_description: str = getattr(self, "action_description", self.__doc__ or "Action组件")
|
||||
"""Action的描述"""
|
||||
self.action_parameters: dict = getattr(self.__class__, "action_parameters", {}).copy()
|
||||
self.action_require: list[str] = getattr(self.__class__, "action_require", []).copy()
|
||||
|
||||
# 设置激活类型实例属性(从类属性复制,提供默认值)
|
||||
self.focus_activation_type = getattr(self.__class__, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
"""FOCUS模式下的激活类型"""
|
||||
self.normal_activation_type = getattr(self.__class__, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
"""NORMAL模式下的激活类型"""
|
||||
self.activation_type = getattr(self.__class__, "activation_type", self.focus_activation_type)
|
||||
"""激活类型"""
|
||||
self.random_activation_probability: float = getattr(self.__class__, "random_activation_probability", 0.0)
|
||||
"""当激活类型为RANDOM时的概率"""
|
||||
self.llm_judge_prompt: str = getattr(self.__class__, "llm_judge_prompt", "")
|
||||
"""协助LLM进行判断的Prompt"""
|
||||
self.activation_keywords: list[str] = getattr(self.__class__, "activation_keywords", []).copy()
|
||||
"""激活类型为KEYWORD时的KEYWORDS列表"""
|
||||
self.keyword_case_sensitive: bool = getattr(self.__class__, "keyword_case_sensitive", False)
|
||||
self.mode_enable: ChatMode = getattr(self.__class__, "mode_enable", ChatMode.ALL)
|
||||
self.parallel_action: bool = getattr(self.__class__, "parallel_action", True)
|
||||
self.associated_types: list[str] = getattr(self.__class__, "associated_types", []).copy()
|
||||
|
||||
# =============================================================================
|
||||
# 便捷属性 - 直接在初始化时获取常用聊天信息(带类型注解)
|
||||
# =============================================================================
|
||||
|
||||
# 获取聊天流对象
|
||||
self.chat_stream = chat_stream or kwargs.get("chat_stream")
|
||||
self.chat_id = self.chat_stream.stream_id
|
||||
self.platform = getattr(self.chat_stream, "platform", None)
|
||||
|
||||
# 初始化基础信息(带类型注解)
|
||||
self.action_message = action_message
|
||||
|
||||
self.group_id = None
|
||||
self.group_name = None
|
||||
self.user_id = None
|
||||
self.user_nickname = None
|
||||
self.is_group = False
|
||||
self.target_id = None
|
||||
self.has_action_message = False
|
||||
|
||||
if self.action_message:
|
||||
self.has_action_message = True
|
||||
else:
|
||||
self.action_message = {}
|
||||
|
||||
if self.has_action_message:
|
||||
if self.action_name != "no_reply":
|
||||
self.group_id = str(self.action_message.get("chat_info_group_id", None))
|
||||
self.group_name = self.action_message.get("chat_info_group_name", None)
|
||||
|
||||
self.user_id = str(self.action_message.get("user_id", None))
|
||||
self.user_nickname = self.action_message.get("user_nickname", None)
|
||||
if self.group_id:
|
||||
self.is_group = True
|
||||
self.target_id = self.group_id
|
||||
else:
|
||||
self.is_group = False
|
||||
self.target_id = self.user_id
|
||||
else:
|
||||
if self.chat_stream.group_info:
|
||||
self.group_id = self.chat_stream.group_info.group_id
|
||||
self.group_name = self.chat_stream.group_info.group_name
|
||||
self.is_group = True
|
||||
self.target_id = self.group_id
|
||||
else:
|
||||
self.user_id = self.chat_stream.user_info.user_id
|
||||
self.user_nickname = self.chat_stream.user_info.user_nickname
|
||||
self.is_group = False
|
||||
self.target_id = self.user_id
|
||||
|
||||
logger.debug(f"{self.log_prefix} Action组件初始化完成")
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 聊天信息: 类型={'群聊' if self.is_group else '私聊'}, 平台={self.platform}, 目标={self.target_id}"
|
||||
)
|
||||
|
||||
async def wait_for_new_message(self, timeout: int = 1200) -> Tuple[bool, str]:
|
||||
"""等待新消息或超时
|
||||
|
||||
在loop_start_time之后等待新消息,如果没有新消息且没有超时,就一直等待。
|
||||
使用message_api检查self.chat_id对应的聊天中是否有新消息。
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒),默认1200秒
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否收到新消息, 空字符串)
|
||||
"""
|
||||
try:
|
||||
# 获取循环开始时间,如果没有则使用当前时间
|
||||
loop_start_time = self.action_data.get("loop_start_time", time.time())
|
||||
logger.info(f"{self.log_prefix} 开始等待新消息... (最长等待: {timeout}秒, 从时间点: {loop_start_time})")
|
||||
|
||||
# 确保有有效的chat_id
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 等待新消息失败: 没有有效的chat_id")
|
||||
return False, "没有有效的chat_id"
|
||||
|
||||
wait_start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
# 检查关闭标志
|
||||
# shutting_down = self.get_action_context("shutting_down", False)
|
||||
# if shutting_down:
|
||||
# logger.info(f"{self.log_prefix} 等待新消息时检测到关闭信号,中断等待")
|
||||
# return False, ""
|
||||
|
||||
# 检查新消息
|
||||
current_time = time.time()
|
||||
new_message_count = message_api.count_new_messages(
|
||||
chat_id=self.chat_id, start_time=loop_start_time, end_time=current_time
|
||||
)
|
||||
|
||||
if new_message_count > 0:
|
||||
logger.info(f"{self.log_prefix} 检测到{new_message_count}条新消息,聊天ID: {self.chat_id}")
|
||||
return True, ""
|
||||
|
||||
# 检查超时
|
||||
elapsed_time = asyncio.get_event_loop().time() - wait_start_time
|
||||
if elapsed_time > timeout:
|
||||
logger.warning(f"{self.log_prefix} 等待新消息超时({timeout}秒),聊天ID: {self.chat_id}")
|
||||
return False, ""
|
||||
|
||||
# 每30秒记录一次等待状态
|
||||
if int(elapsed_time) % 15 == 0 and int(elapsed_time) > 0:
|
||||
logger.debug(f"{self.log_prefix} 已等待{int(elapsed_time)}秒,继续等待新消息...")
|
||||
|
||||
# 短暂休眠
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"{self.log_prefix} 等待新消息被中断 (CancelledError)")
|
||||
return False, ""
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 等待新消息时发生错误: {e}")
|
||||
return False, f"等待新消息失败: {str(e)}"
|
||||
|
||||
async def send_text(
|
||||
self, content: str, reply_to: str = "", typing: bool = False
|
||||
) -> bool:
|
||||
"""发送文本消息
|
||||
|
||||
Args:
|
||||
content: 文本内容
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.text_to_stream(
|
||||
text=content,
|
||||
stream_id=self.chat_id,
|
||||
reply_to=reply_to,
|
||||
typing=typing,
|
||||
)
|
||||
|
||||
async def send_emoji(self, emoji_base64: str) -> bool:
|
||||
"""发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.emoji_to_stream(emoji_base64, self.chat_id)
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.image_to_stream(image_base64, self.chat_id)
|
||||
|
||||
async def send_custom(self, message_type: str, content: str, typing: bool = False, reply_to: str = "") -> bool:
|
||||
"""发送自定义类型消息
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"video"、"file"、"audio"等
|
||||
content: 消息内容
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
return await send_api.custom_to_stream(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=self.chat_id,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
|
||||
async def store_action_info(
|
||||
self,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
) -> None:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
Args:
|
||||
action_build_into_prompt: 是否构建到提示中
|
||||
action_prompt_display: 显示的action提示信息
|
||||
action_done: action是否完成
|
||||
"""
|
||||
await database_api.store_action_info(
|
||||
chat_stream=self.chat_stream,
|
||||
action_build_into_prompt=action_build_into_prompt,
|
||||
action_prompt_display=action_prompt_display,
|
||||
action_done=action_done,
|
||||
thinking_id=self.thinking_id,
|
||||
action_data=self.action_data,
|
||||
action_name=self.action_name,
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
使用stream API发送命令
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
args: 命令参数
|
||||
display_message: 显示消息
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
if not self.chat_id:
|
||||
logger.error(f"{self.log_prefix} 缺少聊天ID")
|
||||
return False
|
||||
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
success = await send_api.command_to_stream(
|
||||
command=command_data,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=storage_message,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_action_info(cls) -> "ActionInfo":
|
||||
"""从类属性生成ActionInfo
|
||||
|
||||
所有信息都从类属性中读取,确保一致性和完整性。
|
||||
Action类必须定义所有必要的类属性。
|
||||
|
||||
Returns:
|
||||
ActionInfo: 生成的Action信息对象
|
||||
"""
|
||||
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name = getattr(cls, "action_name", cls.__name__.lower().replace("action", ""))
|
||||
if "." in name:
|
||||
logger.error(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"Action名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
# 获取focus_activation_type和normal_activation_type
|
||||
focus_activation_type = getattr(cls, "focus_activation_type", ActionActivationType.ALWAYS)
|
||||
normal_activation_type = getattr(cls, "normal_activation_type", ActionActivationType.ALWAYS)
|
||||
|
||||
# 处理activation_type:如果插件中声明了就用插件的值,否则默认使用focus_activation_type
|
||||
activation_type = getattr(cls, "activation_type", focus_activation_type)
|
||||
|
||||
return ActionInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.ACTION,
|
||||
description=getattr(cls, "action_description", "Action动作"),
|
||||
focus_activation_type=focus_activation_type,
|
||||
normal_activation_type=normal_activation_type,
|
||||
activation_type=activation_type,
|
||||
activation_keywords=getattr(cls, "activation_keywords", []).copy(),
|
||||
keyword_case_sensitive=getattr(cls, "keyword_case_sensitive", False),
|
||||
mode_enable=getattr(cls, "mode_enable", ChatMode.ALL),
|
||||
parallel_action=getattr(cls, "parallel_action", True),
|
||||
random_activation_probability=getattr(cls, "random_activation_probability", 0.0),
|
||||
llm_judge_prompt=getattr(cls, "llm_judge_prompt", ""),
|
||||
# 使用正确的字段名
|
||||
action_parameters=getattr(cls, "action_parameters", {}).copy(),
|
||||
action_require=getattr(cls, "action_require", []).copy(),
|
||||
associated_types=getattr(cls, "associated_types", []).copy(),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
"""执行Action的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
pass
|
||||
|
||||
async def handle_action(self) -> Tuple[bool, str]:
|
||||
"""兼容旧系统的handle_action接口,委托给execute方法
|
||||
|
||||
为了保持向后兼容性,旧系统的代码可能会调用handle_action方法。
|
||||
此方法将调用委托给新的execute方法。
|
||||
|
||||
Returns:
|
||||
Tuple[bool, str]: (是否执行成功, 回复文本)
|
||||
"""
|
||||
return await self.execute()
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
228
src/plugin_system/base/base_command.py
Normal file
228
src/plugin_system/base/base_command.py
Normal file
@@ -0,0 +1,228 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Tuple, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import CommandInfo, ComponentType
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("base_command")
|
||||
|
||||
|
||||
class BaseCommand(ABC):
|
||||
"""Command组件基类
|
||||
|
||||
Command是插件的一种组件类型,用于处理命令请求
|
||||
|
||||
子类可以通过类属性定义命令模式:
|
||||
- command_pattern: 命令匹配的正则表达式
|
||||
- command_help: 命令帮助信息
|
||||
- command_examples: 命令使用示例列表
|
||||
"""
|
||||
|
||||
command_name: str = ""
|
||||
"""Command组件的名称"""
|
||||
command_description: str = ""
|
||||
"""Command组件的描述"""
|
||||
# 默认命令设置
|
||||
command_pattern: str = r""
|
||||
"""命令匹配的正则表达式"""
|
||||
|
||||
def __init__(self, message: MessageRecv, plugin_config: Optional[dict] = None):
|
||||
"""初始化Command组件
|
||||
|
||||
Args:
|
||||
message: 接收到的消息对象
|
||||
plugin_config: 插件配置字典
|
||||
"""
|
||||
self.message = message
|
||||
self.matched_groups: Dict[str, str] = {} # 存储正则表达式匹配的命名组
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
self.log_prefix = "[Command]"
|
||||
|
||||
logger.debug(f"{self.log_prefix} Command组件初始化完成")
|
||||
|
||||
def set_matched_groups(self, groups: Dict[str, str]) -> None:
|
||||
"""设置正则表达式匹配的命名组
|
||||
|
||||
Args:
|
||||
groups: 正则表达式匹配的命名组
|
||||
"""
|
||||
self.matched_groups = groups
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self) -> Tuple[bool, Optional[str], bool]:
|
||||
"""执行Command的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[str], bool]: (是否执行成功, 可选的回复消息, 是否拦截消息 不进行 后续处理)
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
async def send_text(self, content: str, reply_to: str = "") -> bool:
|
||||
"""发送回复消息
|
||||
|
||||
Args:
|
||||
content: 回复内容
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, reply_to=reply_to)
|
||||
|
||||
async def send_type(
|
||||
self, message_type: str, content: str, display_message: str = "", typing: bool = False, reply_to: str = ""
|
||||
) -> bool:
|
||||
"""发送指定类型的回复消息到当前聊天环境
|
||||
|
||||
Args:
|
||||
message_type: 消息类型,如"text"、"image"、"emoji"等
|
||||
content: 消息内容
|
||||
display_message: 显示消息(可选)
|
||||
typing: 是否显示正在输入
|
||||
reply_to: 回复消息,格式为"发送者:消息内容"
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.custom_to_stream(
|
||||
message_type=message_type,
|
||||
content=content,
|
||||
stream_id=chat_stream.stream_id,
|
||||
display_message=display_message,
|
||||
typing=typing,
|
||||
reply_to=reply_to,
|
||||
)
|
||||
|
||||
async def send_command(
|
||||
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True
|
||||
) -> bool:
|
||||
"""发送命令消息
|
||||
|
||||
Args:
|
||||
command_name: 命令名称
|
||||
args: 命令参数
|
||||
display_message: 显示消息
|
||||
storage_message: 是否存储消息到数据库
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
try:
|
||||
# 获取聊天流信息
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
# 构造命令数据
|
||||
command_data = {"name": command_name, "args": args or {}}
|
||||
|
||||
success = await send_api.command_to_stream(
|
||||
command=command_data,
|
||||
stream_id=chat_stream.stream_id,
|
||||
storage_message=storage_message,
|
||||
display_message=display_message,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f"{self.log_prefix} 成功发送命令: {command_name}")
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 发送命令失败: {command_name}")
|
||||
|
||||
return success
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
|
||||
return False
|
||||
|
||||
async def send_emoji(self, emoji_base64: str) -> bool:
|
||||
"""发送表情包
|
||||
|
||||
Args:
|
||||
emoji_base64: 表情包的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id)
|
||||
|
||||
async def send_image(self, image_base64: str) -> bool:
|
||||
"""发送图片
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
|
||||
Returns:
|
||||
bool: 是否发送成功
|
||||
"""
|
||||
chat_stream = self.message.chat_stream
|
||||
if not chat_stream or not hasattr(chat_stream, "stream_id"):
|
||||
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
|
||||
return False
|
||||
|
||||
return await send_api.image_to_stream(image_base64, chat_stream.stream_id)
|
||||
|
||||
@classmethod
|
||||
def get_command_info(cls) -> "CommandInfo":
|
||||
"""从类属性生成CommandInfo
|
||||
|
||||
Args:
|
||||
name: Command名称,如果不提供则使用类名
|
||||
description: Command描述,如果不提供则使用类文档字符串
|
||||
|
||||
Returns:
|
||||
CommandInfo: 生成的Command信息对象
|
||||
"""
|
||||
if "." in cls.command_name:
|
||||
logger.error(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"Command名称 '{cls.command_name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return CommandInfo(
|
||||
name=cls.command_name,
|
||||
component_type=ComponentType.COMMAND,
|
||||
description=cls.command_description,
|
||||
command_pattern=cls.command_pattern,
|
||||
)
|
||||
101
src/plugin_system/base/base_events_handler.py
Normal file
101
src/plugin_system/base/base_events_handler.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple, Optional, Dict
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from .component_types import MaiMessages, EventType, EventHandlerInfo, ComponentType
|
||||
|
||||
logger = get_logger("base_event_handler")
|
||||
|
||||
|
||||
class BaseEventHandler(ABC):
|
||||
"""事件处理器基类
|
||||
|
||||
所有事件处理器都应该继承这个基类,提供事件处理的基本接口
|
||||
"""
|
||||
|
||||
event_type: EventType = EventType.UNKNOWN
|
||||
"""事件类型,默认为未知"""
|
||||
handler_name: str = ""
|
||||
"""处理器名称"""
|
||||
handler_description: str = ""
|
||||
"""处理器描述"""
|
||||
weight: int = 0
|
||||
"""处理器权重,越大权重越高"""
|
||||
intercept_message: bool = False
|
||||
"""是否拦截消息,默认为否"""
|
||||
|
||||
def __init__(self):
|
||||
self.log_prefix = "[EventHandler]"
|
||||
self.plugin_name = ""
|
||||
"""对应插件名"""
|
||||
self.plugin_config: Optional[Dict] = None
|
||||
"""插件配置字典"""
|
||||
if self.event_type == EventType.UNKNOWN:
|
||||
raise NotImplementedError("事件处理器必须指定 event_type")
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, message: MaiMessages) -> Tuple[bool, bool, Optional[str]]:
|
||||
"""执行事件处理的抽象方法,子类必须实现
|
||||
|
||||
Returns:
|
||||
Tuple[bool, bool, Optional[str]]: (是否执行成功, 是否需要继续处理, 可选的返回消息)
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现 execute 方法")
|
||||
|
||||
@classmethod
|
||||
def get_handler_info(cls) -> "EventHandlerInfo":
|
||||
"""获取事件处理器的信息"""
|
||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
||||
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
||||
if "." in name:
|
||||
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
raise ValueError(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||
return EventHandlerInfo(
|
||||
name=name,
|
||||
component_type=ComponentType.EVENT_HANDLER,
|
||||
description=getattr(cls, "handler_description", "events处理器"),
|
||||
event_type=cls.event_type,
|
||||
weight=cls.weight,
|
||||
intercept_message=cls.intercept_message,
|
||||
)
|
||||
|
||||
def set_plugin_config(self, plugin_config: Dict) -> None:
|
||||
"""设置插件配置
|
||||
|
||||
Args:
|
||||
plugin_config (dict): 插件配置字典
|
||||
"""
|
||||
self.plugin_config = plugin_config
|
||||
|
||||
def set_plugin_name(self, plugin_name: str) -> None:
|
||||
"""设置插件名称
|
||||
|
||||
Args:
|
||||
plugin_name (str): 插件名称
|
||||
"""
|
||||
self.plugin_name = plugin_name
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
76
src/plugin_system/base/base_plugin.py
Normal file
76
src/plugin_system/base/base_plugin.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from abc import abstractmethod
|
||||
from typing import List, Type, Tuple, Union
|
||||
from .plugin_base import PluginBase
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
|
||||
from .base_action import BaseAction
|
||||
from .base_command import BaseCommand
|
||||
from .base_events_handler import BaseEventHandler
|
||||
from .base_tool import BaseTool
|
||||
|
||||
logger = get_logger("base_plugin")
|
||||
|
||||
|
||||
class BasePlugin(PluginBase):
|
||||
"""基于Action和Command的插件基类
|
||||
|
||||
所有上述类型的插件都应该继承这个基类,一个插件可以包含多种组件:
|
||||
- Action组件:处理聊天中的动作
|
||||
- Command组件:处理命令请求
|
||||
- 未来可扩展:Scheduler、Listener等
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def get_plugin_components(
|
||||
self,
|
||||
) -> List[
|
||||
Union[
|
||||
Tuple[ActionInfo, Type[BaseAction]],
|
||||
Tuple[CommandInfo, Type[BaseCommand]],
|
||||
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
|
||||
Tuple[ToolInfo, Type[BaseTool]],
|
||||
]
|
||||
]:
|
||||
"""获取插件包含的组件列表
|
||||
|
||||
子类必须实现此方法,返回组件信息和组件类的列表
|
||||
|
||||
Returns:
|
||||
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
|
||||
def register_plugin(self) -> bool:
|
||||
"""注册插件及其所有组件"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
components = self.get_plugin_components()
|
||||
|
||||
# 检查依赖
|
||||
if not self._check_dependencies():
|
||||
logger.error(f"{self.log_prefix} 依赖检查失败,跳过注册")
|
||||
return False
|
||||
|
||||
# 注册所有组件
|
||||
registered_components = []
|
||||
for component_info, component_class in components:
|
||||
component_info.plugin_name = self.plugin_name
|
||||
if component_registry.register_component(component_info, component_class):
|
||||
registered_components.append(component_info)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 组件 {component_info.name} 注册失败")
|
||||
|
||||
# 更新插件信息中的组件列表
|
||||
self.plugin_info.components = registered_components
|
||||
|
||||
# 注册插件
|
||||
if component_registry.register_plugin(self.plugin_info):
|
||||
logger.debug(f"{self.log_prefix} 插件注册成功,包含 {len(registered_components)} 个组件")
|
||||
return True
|
||||
else:
|
||||
logger.error(f"{self.log_prefix} 插件注册失败")
|
||||
return False
|
||||
119
src/plugin_system/base/base_tool.py
Normal file
119
src/plugin_system/base/base_tool.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("base_tool")
|
||||
|
||||
|
||||
class BaseTool(ABC):
|
||||
"""所有工具的基类"""
|
||||
|
||||
name: str = ""
|
||||
"""工具的名称"""
|
||||
description: str = ""
|
||||
"""工具的描述"""
|
||||
parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
|
||||
"""工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
|
||||
param_name: 参数名称
|
||||
param_type: 参数类型
|
||||
description: 参数描述
|
||||
required: 是否必填
|
||||
enum_values: 枚举值列表
|
||||
例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
|
||||
"""
|
||||
available_for_llm: bool = False
|
||||
"""是否可供LLM使用"""
|
||||
|
||||
def __init__(self, plugin_config: Optional[dict] = None):
|
||||
self.plugin_config = plugin_config or {} # 直接存储插件配置字典
|
||||
|
||||
@classmethod
|
||||
def get_tool_definition(cls) -> dict[str, Any]:
|
||||
"""获取工具定义,用于LLM工具调用
|
||||
|
||||
Returns:
|
||||
dict: 工具定义字典
|
||||
"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
|
||||
|
||||
@classmethod
|
||||
def get_tool_info(cls) -> ToolInfo:
|
||||
"""获取工具信息"""
|
||||
if not cls.name or not cls.description or not cls.parameters:
|
||||
raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
|
||||
|
||||
return ToolInfo(
|
||||
name=cls.name,
|
||||
tool_description=cls.description,
|
||||
enabled=cls.available_for_llm,
|
||||
tool_parameters=cls.parameters,
|
||||
component_type=ComponentType.TOOL,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""执行工具函数(供llm调用)
|
||||
通过该方法,maicore会通过llm的tool call来调用工具
|
||||
传入的是json格式的参数,符合parameters定义的格式
|
||||
|
||||
Args:
|
||||
function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
raise NotImplementedError("子类必须实现execute方法")
|
||||
|
||||
async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""直接执行工具函数(供插件调用)
|
||||
通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
|
||||
插件可以直接调用此方法,用更加明了的方式传入参数
|
||||
示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
|
||||
|
||||
工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
|
||||
|
||||
Args:
|
||||
**function_args: 工具调用参数
|
||||
|
||||
Returns:
|
||||
dict: 工具执行结果
|
||||
"""
|
||||
parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
|
||||
for param_name in parameter_required:
|
||||
if param_name not in function_args:
|
||||
raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
|
||||
|
||||
return await self.execute(function_args)
|
||||
|
||||
def get_config(self, key: str, default=None):
|
||||
"""获取插件配置值,使用嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,使用嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
if not self.plugin_config:
|
||||
return default
|
||||
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.plugin_config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
283
src/plugin_system/base/component_types.py
Normal file
283
src/plugin_system/base/component_types.py
Normal file
@@ -0,0 +1,283 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from maim_message import Seg
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||
|
||||
# 组件类型枚举
|
||||
class ComponentType(Enum):
|
||||
"""组件类型枚举"""
|
||||
|
||||
ACTION = "action" # 动作组件
|
||||
COMMAND = "command" # 命令组件
|
||||
TOOL = "tool" # 服务组件(预留)
|
||||
SCHEDULER = "scheduler" # 定时任务组件(预留)
|
||||
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
# 动作激活类型枚举
|
||||
class ActionActivationType(Enum):
|
||||
"""动作激活类型枚举"""
|
||||
|
||||
NEVER = "never" # 从不激活(默认关闭)
|
||||
ALWAYS = "always" # 默认参与到planner
|
||||
LLM_JUDGE = "llm_judge" # LLM判定是否启动该action到planner
|
||||
RANDOM = "random" # 随机启用action到planner
|
||||
KEYWORD = "keyword" # 关键词触发启用action到planner
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
# 聊天模式枚举
|
||||
class ChatMode(Enum):
|
||||
"""聊天模式枚举"""
|
||||
|
||||
FOCUS = "focus" # Focus聊天模式
|
||||
NORMAL = "normal" # Normal聊天模式
|
||||
PRIORITY = "priority" # 优先级聊天模式
|
||||
ALL = "all" # 所有聊天模式
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
# 事件类型枚举
|
||||
class EventType(Enum):
|
||||
"""
|
||||
事件类型枚举类
|
||||
"""
|
||||
|
||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||
ON_MESSAGE = "on_message"
|
||||
ON_PLAN = "on_plan"
|
||||
POST_LLM = "post_llm"
|
||||
AFTER_LLM = "after_llm"
|
||||
POST_SEND = "post_send"
|
||||
AFTER_SEND = "after_send"
|
||||
UNKNOWN = "unknown" # 未知事件类型
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class PythonDependency:
|
||||
"""Python包依赖信息"""
|
||||
|
||||
package_name: str # 包名称
|
||||
version: str = "" # 版本要求,例如: ">=1.0.0", "==2.1.3", ""表示任意版本
|
||||
optional: bool = False # 是否为可选依赖
|
||||
description: str = "" # 依赖描述
|
||||
install_name: str = "" # 安装时的包名(如果与import名不同)
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.install_name:
|
||||
self.install_name = self.package_name
|
||||
|
||||
def get_pip_requirement(self) -> str:
|
||||
"""获取pip安装格式的依赖字符串"""
|
||||
if self.version:
|
||||
return f"{self.install_name}{self.version}"
|
||||
return self.install_name
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComponentInfo:
|
||||
"""组件信息"""
|
||||
|
||||
name: str # 组件名称
|
||||
component_type: ComponentType # 组件类型
|
||||
description: str = "" # 组件描述
|
||||
enabled: bool = True # 是否启用
|
||||
plugin_name: str = "" # 所属插件名称
|
||||
is_built_in: bool = False # 是否为内置组件
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
|
||||
def __post_init__(self):
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionInfo(ComponentInfo):
|
||||
"""动作组件信息"""
|
||||
|
||||
action_parameters: Dict[str, str] = field(
|
||||
default_factory=dict
|
||||
) # 动作参数与描述,例如 {"param1": "描述1", "param2": "描述2"}
|
||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||
# 激活类型相关
|
||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||
random_activation_probability: float = 0.0
|
||||
llm_judge_prompt: str = ""
|
||||
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
|
||||
keyword_case_sensitive: bool = False
|
||||
# 模式和并行设置
|
||||
mode_enable: ChatMode = ChatMode.ALL
|
||||
parallel_action: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.activation_keywords is None:
|
||||
self.activation_keywords = []
|
||||
if self.action_parameters is None:
|
||||
self.action_parameters = {}
|
||||
if self.action_require is None:
|
||||
self.action_require = []
|
||||
if self.associated_types is None:
|
||||
self.associated_types = []
|
||||
self.component_type = ComponentType.ACTION
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandInfo(ComponentInfo):
|
||||
"""命令组件信息"""
|
||||
|
||||
command_pattern: str = "" # 命令匹配模式(正则表达式)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.COMMAND
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolInfo(ComponentInfo):
|
||||
"""工具组件信息"""
|
||||
|
||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
|
||||
tool_description: str = "" # 工具描述
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.TOOL
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandlerInfo(ComponentInfo):
|
||||
"""事件处理器组件信息"""
|
||||
|
||||
event_type: EventType = EventType.ON_MESSAGE # 监听事件类型
|
||||
intercept_message: bool = False # 是否拦截消息处理(默认不拦截)
|
||||
weight: int = 0 # 事件处理器权重,决定执行顺序
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.component_type = ComponentType.EVENT_HANDLER
|
||||
|
||||
|
||||
@dataclass
|
||||
class PluginInfo:
|
||||
"""插件信息"""
|
||||
|
||||
display_name: str # 插件显示名称
|
||||
name: str # 插件名称
|
||||
description: str # 插件描述
|
||||
version: str = "1.0.0" # 插件版本
|
||||
author: str = "" # 插件作者
|
||||
enabled: bool = True # 是否启用
|
||||
is_built_in: bool = False # 是否为内置插件
|
||||
components: List[ComponentInfo] = field(default_factory=list) # 包含的组件列表
|
||||
dependencies: List[str] = field(default_factory=list) # 依赖的其他插件
|
||||
python_dependencies: List[PythonDependency] = field(default_factory=list) # Python包依赖
|
||||
config_file: str = "" # 配置文件路径
|
||||
metadata: Dict[str, Any] = field(default_factory=dict) # 额外元数据
|
||||
# 新增:manifest相关信息
|
||||
manifest_data: Dict[str, Any] = field(default_factory=dict) # manifest文件数据
|
||||
license: str = "" # 插件许可证
|
||||
homepage_url: str = "" # 插件主页
|
||||
repository_url: str = "" # 插件仓库地址
|
||||
keywords: List[str] = field(default_factory=list) # 插件关键词
|
||||
categories: List[str] = field(default_factory=list) # 插件分类
|
||||
min_host_version: str = "" # 最低主机版本要求
|
||||
max_host_version: str = "" # 最高主机版本要求
|
||||
|
||||
def __post_init__(self):
|
||||
if self.components is None:
|
||||
self.components = []
|
||||
if self.dependencies is None:
|
||||
self.dependencies = []
|
||||
if self.python_dependencies is None:
|
||||
self.python_dependencies = []
|
||||
if self.metadata is None:
|
||||
self.metadata = {}
|
||||
if self.manifest_data is None:
|
||||
self.manifest_data = {}
|
||||
if self.keywords is None:
|
||||
self.keywords = []
|
||||
if self.categories is None:
|
||||
self.categories = []
|
||||
|
||||
def get_missing_packages(self) -> List[PythonDependency]:
|
||||
"""检查缺失的Python包"""
|
||||
missing = []
|
||||
for dep in self.python_dependencies:
|
||||
try:
|
||||
__import__(dep.package_name)
|
||||
except ImportError:
|
||||
if not dep.optional:
|
||||
missing.append(dep)
|
||||
return missing
|
||||
|
||||
def get_pip_requirements(self) -> List[str]:
|
||||
"""获取所有pip安装格式的依赖"""
|
||||
return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MaiMessages:
|
||||
"""MaiM插件消息"""
|
||||
|
||||
message_segments: List[Seg] = field(default_factory=list)
|
||||
"""消息段列表,支持多段消息"""
|
||||
|
||||
message_base_info: Dict[str, Any] = field(default_factory=dict)
|
||||
"""消息基本信息,包含平台,用户信息等数据"""
|
||||
|
||||
plain_text: str = ""
|
||||
"""纯文本消息内容"""
|
||||
|
||||
raw_message: Optional[str] = None
|
||||
"""原始消息内容"""
|
||||
|
||||
is_group_message: bool = False
|
||||
"""是否为群组消息"""
|
||||
|
||||
is_private_message: bool = False
|
||||
"""是否为私聊消息"""
|
||||
|
||||
stream_id: Optional[str] = None
|
||||
"""流ID,用于标识消息流"""
|
||||
|
||||
llm_prompt: Optional[str] = None
|
||||
"""LLM提示词"""
|
||||
|
||||
llm_response_content: Optional[str] = None
|
||||
"""LLM响应内容"""
|
||||
|
||||
llm_response_reasoning: Optional[str] = None
|
||||
"""LLM响应推理内容"""
|
||||
|
||||
llm_response_model: Optional[str] = None
|
||||
"""LLM响应模型名称"""
|
||||
|
||||
llm_response_tool_call: Optional[List[ToolCall]] = None
|
||||
"""LLM使用的工具调用"""
|
||||
|
||||
action_usage: Optional[List[str]] = None
|
||||
"""使用的Action"""
|
||||
|
||||
additional_data: Dict[Any, Any] = field(default_factory=dict)
|
||||
"""附加数据,可以存储额外信息"""
|
||||
|
||||
def __post_init__(self):
|
||||
if self.message_segments is None:
|
||||
self.message_segments = []
|
||||
18
src/plugin_system/base/config_types.py
Normal file
18
src/plugin_system/base/config_types.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""
|
||||
插件系统配置类型定义
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConfigField:
|
||||
"""配置字段定义"""
|
||||
|
||||
type: type # 字段类型
|
||||
default: Any # 默认值
|
||||
description: str # 字段描述
|
||||
example: Optional[str] = None # 示例值
|
||||
required: bool = False # 是否必需
|
||||
choices: Optional[List[Any]] = field(default_factory=list) # 可选值列表
|
||||
577
src/plugin_system/base/plugin_base.py
Normal file
577
src/plugin_system/base/plugin_base.py
Normal file
@@ -0,0 +1,577 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any, Union
|
||||
import os
|
||||
import inspect
|
||||
import toml
|
||||
import json
|
||||
import shutil
|
||||
import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.component_types import (
|
||||
PluginInfo,
|
||||
PythonDependency,
|
||||
)
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from src.plugin_system.utils.manifest_utils import ManifestValidator
|
||||
|
||||
logger = get_logger("plugin_base")
|
||||
|
||||
|
||||
class PluginBase(ABC):
|
||||
"""插件总基类
|
||||
|
||||
所有衍生插件基类都应该继承自此类,这个类定义了插件的基本结构和行为。
|
||||
"""
|
||||
|
||||
# 插件基本信息(子类必须定义)
|
||||
@property
|
||||
@abstractmethod
|
||||
def plugin_name(self) -> str:
|
||||
return "" # 插件内部标识符(如 "hello_world_plugin")
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def enable_plugin(self) -> bool:
|
||||
return True # 是否启用插件
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def dependencies(self) -> List[str]:
|
||||
return [] # 依赖的其他插件
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def python_dependencies(self) -> List[PythonDependency]:
|
||||
return [] # Python包依赖
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config_file_name(self) -> str:
|
||||
return "" # 配置文件名
|
||||
|
||||
# manifest文件相关
|
||||
manifest_file_name: str = "_manifest.json" # manifest文件名
|
||||
manifest_data: Dict[str, Any] = {} # manifest数据
|
||||
|
||||
# 配置定义
|
||||
@property
|
||||
@abstractmethod
|
||||
def config_schema(self) -> Dict[str, Union[Dict[str, ConfigField], str]]:
|
||||
return {}
|
||||
|
||||
config_section_descriptions: Dict[str, str] = {}
|
||||
|
||||
def __init__(self, plugin_dir: str):
|
||||
"""初始化插件
|
||||
|
||||
Args:
|
||||
plugin_dir: 插件目录路径,由插件管理器传递
|
||||
"""
|
||||
self.config: Dict[str, Any] = {} # 插件配置
|
||||
self.plugin_dir = plugin_dir # 插件目录路径
|
||||
self.log_prefix = f"[Plugin:{self.plugin_name}]"
|
||||
|
||||
# 加载manifest文件
|
||||
self._load_manifest()
|
||||
|
||||
# 验证插件信息
|
||||
self._validate_plugin_info()
|
||||
|
||||
# 加载插件配置
|
||||
self._load_plugin_config()
|
||||
|
||||
# 从manifest获取显示信息
|
||||
self.display_name = self.get_manifest_info("name", self.plugin_name)
|
||||
self.plugin_version = self.get_manifest_info("version", "1.0.0")
|
||||
self.plugin_description = self.get_manifest_info("description", "")
|
||||
self.plugin_author = self._get_author_name()
|
||||
|
||||
# 创建插件信息对象
|
||||
self.plugin_info = PluginInfo(
|
||||
name=self.plugin_name,
|
||||
display_name=self.display_name,
|
||||
description=self.plugin_description,
|
||||
version=self.plugin_version,
|
||||
author=self.plugin_author,
|
||||
enabled=self.enable_plugin,
|
||||
is_built_in=False,
|
||||
config_file=self.config_file_name or "",
|
||||
dependencies=self.dependencies.copy(),
|
||||
python_dependencies=self.python_dependencies.copy(),
|
||||
# manifest相关信息
|
||||
manifest_data=self.manifest_data.copy(),
|
||||
license=self.get_manifest_info("license", ""),
|
||||
homepage_url=self.get_manifest_info("homepage_url", ""),
|
||||
repository_url=self.get_manifest_info("repository_url", ""),
|
||||
keywords=self.get_manifest_info("keywords", []).copy() if self.get_manifest_info("keywords") else [],
|
||||
categories=self.get_manifest_info("categories", []).copy() if self.get_manifest_info("categories") else [],
|
||||
min_host_version=self.get_manifest_info("host_application.min_version", ""),
|
||||
max_host_version=self.get_manifest_info("host_application.max_version", ""),
|
||||
)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 插件基类初始化完成")
|
||||
|
||||
def _validate_plugin_info(self):
|
||||
"""验证插件基本信息"""
|
||||
if not self.plugin_name:
|
||||
raise ValueError(f"插件类 {self.__class__.__name__} 必须定义 plugin_name")
|
||||
|
||||
# 验证manifest中的必需信息
|
||||
if not self.get_manifest_info("name"):
|
||||
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少name字段")
|
||||
if not self.get_manifest_info("description"):
|
||||
raise ValueError(f"插件 {self.plugin_name} 的manifest中缺少description字段")
|
||||
|
||||
def _load_manifest(self): # sourcery skip: raise-from-previous-error
|
||||
"""加载manifest文件(强制要求)"""
|
||||
if not self.plugin_dir:
|
||||
raise ValueError(f"{self.log_prefix} 没有插件目录路径,无法加载manifest")
|
||||
|
||||
manifest_path = os.path.join(self.plugin_dir, self.manifest_file_name)
|
||||
|
||||
if not os.path.exists(manifest_path):
|
||||
error_msg = f"{self.log_prefix} 缺少必需的manifest文件: {manifest_path}"
|
||||
logger.error(error_msg)
|
||||
raise FileNotFoundError(error_msg)
|
||||
|
||||
try:
|
||||
with open(manifest_path, "r", encoding="utf-8") as f:
|
||||
self.manifest_data = json.load(f)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 成功加载manifest文件: {manifest_path}")
|
||||
|
||||
# 验证manifest格式
|
||||
self._validate_manifest()
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
error_msg = f"{self.log_prefix} manifest文件格式错误: {e}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg) # noqa
|
||||
except IOError as e:
|
||||
error_msg = f"{self.log_prefix} 读取manifest文件失败: {e}"
|
||||
logger.error(error_msg)
|
||||
raise IOError(error_msg) # noqa
|
||||
|
||||
def _get_author_name(self) -> str:
|
||||
"""从manifest获取作者名称"""
|
||||
author_info = self.get_manifest_info("author", {})
|
||||
if isinstance(author_info, dict):
|
||||
return author_info.get("name", "")
|
||||
else:
|
||||
return str(author_info) if author_info else ""
|
||||
|
||||
def _validate_manifest(self):
|
||||
"""验证manifest文件格式(使用强化的验证器)"""
|
||||
if not self.manifest_data:
|
||||
raise ValueError(f"{self.log_prefix} manifest数据为空,验证失败")
|
||||
|
||||
validator = ManifestValidator()
|
||||
is_valid = validator.validate_manifest(self.manifest_data)
|
||||
|
||||
# 记录验证结果
|
||||
if validator.validation_errors or validator.validation_warnings:
|
||||
report = validator.get_validation_report()
|
||||
logger.info(f"{self.log_prefix} Manifest验证结果:\n{report}")
|
||||
|
||||
# 如果有验证错误,抛出异常
|
||||
if not is_valid:
|
||||
error_msg = f"{self.log_prefix} Manifest文件验证失败"
|
||||
if validator.validation_errors:
|
||||
error_msg += f": {'; '.join(validator.validation_errors)}"
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def get_manifest_info(self, key: str, default: Any = None) -> Any:
|
||||
"""获取manifest信息
|
||||
|
||||
Args:
|
||||
key: 信息键,支持点分割的嵌套键(如 "author.name")
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 对应的值
|
||||
"""
|
||||
if not self.manifest_data:
|
||||
return default
|
||||
|
||||
keys = key.split(".")
|
||||
value = self.manifest_data
|
||||
|
||||
for k in keys:
|
||||
if isinstance(value, dict) and k in value:
|
||||
value = value[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return value
|
||||
|
||||
def _generate_and_save_default_config(self, config_file_path: str):
|
||||
"""根据插件的Schema生成并保存默认配置文件"""
|
||||
if not self.config_schema:
|
||||
logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件")
|
||||
return
|
||||
|
||||
toml_str = f"# {self.plugin_name} - 自动生成的配置文件\n"
|
||||
plugin_description = self.get_manifest_info("description", "插件配置文件")
|
||||
toml_str += f"# {plugin_description}\n\n"
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
# 添加节描述
|
||||
if section in self.config_section_descriptions:
|
||||
toml_str += f"# {self.config_section_descriptions[section]}\n"
|
||||
|
||||
toml_str += f"[{section}]\n\n"
|
||||
|
||||
# 遍历节内的字段
|
||||
if isinstance(fields, dict):
|
||||
for field_name, field in fields.items():
|
||||
if isinstance(field, ConfigField):
|
||||
# 添加字段描述
|
||||
toml_str += f"# {field.description}"
|
||||
if field.required:
|
||||
toml_str += " (必需)"
|
||||
toml_str += "\n"
|
||||
|
||||
# 如果有示例值,添加示例
|
||||
if field.example:
|
||||
toml_str += f"# 示例: {field.example}\n"
|
||||
|
||||
# 如果有可选值,添加说明
|
||||
if field.choices:
|
||||
choices_str = ", ".join(map(str, field.choices))
|
||||
toml_str += f"# 可选值: {choices_str}\n"
|
||||
|
||||
# 添加字段值
|
||||
value = field.default
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
|
||||
try:
|
||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(toml_str)
|
||||
logger.info(f"{self.log_prefix} 已生成默认配置文件: {config_file_path}")
|
||||
except IOError as e:
|
||||
logger.error(f"{self.log_prefix} 保存默认配置文件失败: {e}", exc_info=True)
|
||||
|
||||
def _get_expected_config_version(self) -> str:
|
||||
"""获取插件期望的配置版本号"""
|
||||
# 从config_schema的plugin.config_version字段获取
|
||||
if "plugin" in self.config_schema and isinstance(self.config_schema["plugin"], dict):
|
||||
config_version_field = self.config_schema["plugin"].get("config_version")
|
||||
if isinstance(config_version_field, ConfigField):
|
||||
return config_version_field.default
|
||||
return "1.0.0"
|
||||
|
||||
def _get_current_config_version(self, config: Dict[str, Any]) -> str:
|
||||
"""从配置文件中获取当前版本号"""
|
||||
if "plugin" in config and "config_version" in config["plugin"]:
|
||||
return str(config["plugin"]["config_version"])
|
||||
# 如果没有config_version字段,视为最早的版本
|
||||
return "0.0.0"
|
||||
|
||||
def _backup_config_file(self, config_file_path: str) -> str:
|
||||
"""备份配置文件"""
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
backup_path = f"{config_file_path}.backup_{timestamp}"
|
||||
|
||||
try:
|
||||
shutil.copy2(config_file_path, backup_path)
|
||||
logger.info(f"{self.log_prefix} 配置文件已备份到: {backup_path}")
|
||||
return backup_path
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 备份配置文件失败: {e}")
|
||||
return ""
|
||||
|
||||
def _migrate_config_values(self, old_config: Dict[str, Any], new_config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""将旧配置值迁移到新配置结构中
|
||||
|
||||
Args:
|
||||
old_config: 旧配置数据
|
||||
new_config: 基于新schema生成的默认配置
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 迁移后的配置
|
||||
"""
|
||||
|
||||
def migrate_section(
|
||||
old_section: Dict[str, Any], new_section: Dict[str, Any], section_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""迁移单个配置节"""
|
||||
result = new_section.copy()
|
||||
|
||||
for key, value in old_section.items():
|
||||
if key in new_section:
|
||||
# 特殊处理:config_version字段总是使用新版本
|
||||
if section_name == "plugin" and key == "config_version":
|
||||
# 保持新的版本号,不迁移旧值
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 更新配置版本: {section_name}.{key} = {result[key]} (旧值: {value})"
|
||||
)
|
||||
continue
|
||||
|
||||
# 键存在于新配置中,复制值
|
||||
if isinstance(value, dict) and isinstance(new_section[key], dict):
|
||||
# 递归处理嵌套字典
|
||||
result[key] = migrate_section(value, new_section[key], f"{section_name}.{key}")
|
||||
else:
|
||||
result[key] = value
|
||||
logger.debug(f"{self.log_prefix} 迁移配置: {section_name}.{key} = {value}")
|
||||
else:
|
||||
# 键在新配置中不存在,记录警告
|
||||
logger.warning(f"{self.log_prefix} 配置项 {section_name}.{key} 在新版本中已被移除")
|
||||
|
||||
return result
|
||||
|
||||
migrated_config = {}
|
||||
|
||||
# 迁移每个配置节
|
||||
for section_name, new_section_data in new_config.items():
|
||||
if (
|
||||
section_name in old_config
|
||||
and isinstance(old_config[section_name], dict)
|
||||
and isinstance(new_section_data, dict)
|
||||
):
|
||||
migrated_config[section_name] = migrate_section(
|
||||
old_config[section_name], new_section_data, section_name
|
||||
)
|
||||
else:
|
||||
# 新增的节或类型不匹配,使用默认值
|
||||
migrated_config[section_name] = new_section_data
|
||||
if section_name in old_config:
|
||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 结构已改变,使用默认值")
|
||||
|
||||
# 检查旧配置中是否有新配置没有的节
|
||||
for section_name in old_config:
|
||||
if section_name not in migrated_config:
|
||||
logger.warning(f"{self.log_prefix} 配置节 {section_name} 在新版本中已被移除")
|
||||
|
||||
return migrated_config
|
||||
|
||||
def _generate_config_from_schema(self) -> Dict[str, Any]:
|
||||
# sourcery skip: dict-comprehension
|
||||
"""根据schema生成配置数据结构(不写入文件)"""
|
||||
if not self.config_schema:
|
||||
return {}
|
||||
|
||||
config_data = {}
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
if isinstance(fields, dict):
|
||||
section_data = {}
|
||||
|
||||
# 遍历节内的字段
|
||||
for field_name, field in fields.items():
|
||||
if isinstance(field, ConfigField):
|
||||
section_data[field_name] = field.default
|
||||
|
||||
config_data[section] = section_data
|
||||
|
||||
return config_data
|
||||
|
||||
def _save_config_to_file(self, config_data: Dict[str, Any], config_file_path: str):
|
||||
"""将配置数据保存为TOML文件(包含注释)"""
|
||||
if not self.config_schema:
|
||||
logger.debug(f"{self.log_prefix} 插件未定义config_schema,不生成配置文件")
|
||||
return
|
||||
|
||||
toml_str = f"# {self.plugin_name} - 配置文件\n"
|
||||
plugin_description = self.get_manifest_info("description", "插件配置文件")
|
||||
toml_str += f"# {plugin_description}\n"
|
||||
|
||||
# 获取当前期望的配置版本
|
||||
expected_version = self._get_expected_config_version()
|
||||
toml_str += f"# 配置版本: {expected_version}\n\n"
|
||||
|
||||
# 遍历每个配置节
|
||||
for section, fields in self.config_schema.items():
|
||||
# 添加节描述
|
||||
if section in self.config_section_descriptions:
|
||||
toml_str += f"# {self.config_section_descriptions[section]}\n"
|
||||
|
||||
toml_str += f"[{section}]\n\n"
|
||||
|
||||
# 遍历节内的字段
|
||||
if isinstance(fields, dict) and section in config_data:
|
||||
section_data = config_data[section]
|
||||
|
||||
for field_name, field in fields.items():
|
||||
if isinstance(field, ConfigField):
|
||||
# 添加字段描述
|
||||
toml_str += f"# {field.description}"
|
||||
if field.required:
|
||||
toml_str += " (必需)"
|
||||
toml_str += "\n"
|
||||
|
||||
# 如果有示例值,添加示例
|
||||
if field.example:
|
||||
toml_str += f"# 示例: {field.example}\n"
|
||||
|
||||
# 如果有可选值,添加说明
|
||||
if field.choices:
|
||||
choices_str = ", ".join(map(str, field.choices))
|
||||
toml_str += f"# 可选值: {choices_str}\n"
|
||||
|
||||
# 添加字段值(使用迁移后的值)
|
||||
value = section_data.get(field_name, field.default)
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
elif isinstance(value, list):
|
||||
# 格式化列表
|
||||
if all(isinstance(item, str) for item in value):
|
||||
formatted_list = "[" + ", ".join(f'"{item}"' for item in value) + "]"
|
||||
else:
|
||||
formatted_list = str(value)
|
||||
toml_str += f"{field_name} = {formatted_list}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
|
||||
try:
|
||||
with open(config_file_path, "w", encoding="utf-8") as f:
|
||||
f.write(toml_str)
|
||||
logger.info(f"{self.log_prefix} 配置文件已保存: {config_file_path}")
|
||||
except IOError as e:
|
||||
logger.error(f"{self.log_prefix} 保存配置文件失败: {e}", exc_info=True)
|
||||
|
||||
def _load_plugin_config(self): # sourcery skip: extract-method
|
||||
"""加载插件配置文件,支持版本检查和自动迁移"""
|
||||
if not self.config_file_name:
|
||||
logger.debug(f"{self.log_prefix} 未指定配置文件,跳过加载")
|
||||
return
|
||||
|
||||
# 优先使用传入的插件目录路径
|
||||
if self.plugin_dir:
|
||||
plugin_dir = self.plugin_dir
|
||||
else:
|
||||
# fallback:尝试从类的模块信息获取路径
|
||||
try:
|
||||
plugin_module_path = inspect.getfile(self.__class__)
|
||||
plugin_dir = os.path.dirname(plugin_module_path)
|
||||
except (TypeError, OSError):
|
||||
# 最后的fallback:从模块的__file__属性获取
|
||||
module = inspect.getmodule(self.__class__)
|
||||
if module and hasattr(module, "__file__") and module.__file__:
|
||||
plugin_dir = os.path.dirname(module.__file__)
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 无法获取插件目录路径,跳过配置加载")
|
||||
return
|
||||
|
||||
config_file_path = os.path.join(plugin_dir, self.config_file_name)
|
||||
|
||||
# 如果配置文件不存在,生成默认配置
|
||||
if not os.path.exists(config_file_path):
|
||||
logger.info(f"{self.log_prefix} 配置文件 {config_file_path} 不存在,将生成默认配置。")
|
||||
self._generate_and_save_default_config(config_file_path)
|
||||
|
||||
if not os.path.exists(config_file_path):
|
||||
logger.warning(f"{self.log_prefix} 配置文件 {config_file_path} 不存在且无法生成。")
|
||||
return
|
||||
|
||||
file_ext = os.path.splitext(self.config_file_name)[1].lower()
|
||||
|
||||
if file_ext == ".toml":
|
||||
# 加载现有配置
|
||||
with open(config_file_path, "r", encoding="utf-8") as f:
|
||||
existing_config = toml.load(f) or {}
|
||||
|
||||
# 检查配置版本
|
||||
current_version = self._get_current_config_version(existing_config)
|
||||
|
||||
# 如果配置文件没有版本信息,跳过版本检查
|
||||
if current_version == "0.0.0":
|
||||
logger.debug(f"{self.log_prefix} 配置文件无版本信息,跳过版本检查")
|
||||
self.config = existing_config
|
||||
else:
|
||||
expected_version = self._get_expected_config_version()
|
||||
|
||||
if current_version != expected_version:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 检测到配置版本需要更新: 当前=v{current_version}, 期望=v{expected_version}"
|
||||
)
|
||||
|
||||
# 生成新的默认配置结构
|
||||
new_config_structure = self._generate_config_from_schema()
|
||||
|
||||
# 迁移旧配置值到新结构
|
||||
migrated_config = self._migrate_config_values(existing_config, new_config_structure)
|
||||
|
||||
# 保存迁移后的配置
|
||||
self._save_config_to_file(migrated_config, config_file_path)
|
||||
|
||||
logger.info(f"{self.log_prefix} 配置文件已从 v{current_version} 更新到 v{expected_version}")
|
||||
|
||||
self.config = migrated_config
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 配置版本匹配 (v{current_version}),直接加载")
|
||||
self.config = existing_config
|
||||
|
||||
logger.debug(f"{self.log_prefix} 配置已从 {config_file_path} 加载")
|
||||
|
||||
# 从配置中更新 enable_plugin
|
||||
if "plugin" in self.config and "enabled" in self.config["plugin"]:
|
||||
self.enable_plugin = self.config["plugin"]["enabled"] # type: ignore
|
||||
logger.debug(f"{self.log_prefix} 从配置更新插件启用状态: {self.enable_plugin}")
|
||||
else:
|
||||
logger.warning(f"{self.log_prefix} 不支持的配置文件格式: {file_ext},仅支持 .toml")
|
||||
self.config = {}
|
||||
|
||||
def _check_dependencies(self) -> bool:
|
||||
"""检查插件依赖"""
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
if not self.dependencies:
|
||||
return True
|
||||
|
||||
for dep in self.dependencies:
|
||||
if not component_registry.get_plugin_info(dep):
|
||||
logger.error(f"{self.log_prefix} 缺少依赖插件: {dep}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""获取插件配置值,支持嵌套键访问
|
||||
|
||||
Args:
|
||||
key: 配置键名,支持嵌套访问如 "section.subsection.key"
|
||||
default: 默认值
|
||||
|
||||
Returns:
|
||||
Any: 配置值或默认值
|
||||
"""
|
||||
# 支持嵌套键访问
|
||||
keys = key.split(".")
|
||||
current = self.config
|
||||
|
||||
for k in keys:
|
||||
if isinstance(current, dict) and k in current:
|
||||
current = current[k]
|
||||
else:
|
||||
return default
|
||||
|
||||
return current
|
||||
|
||||
@abstractmethod
|
||||
def register_plugin(self) -> bool:
|
||||
"""
|
||||
注册插件到插件管理器
|
||||
|
||||
子类必须实现此方法,返回注册是否成功
|
||||
|
||||
Returns:
|
||||
bool: 是否成功注册插件
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement this method")
|
||||
Reference in New Issue
Block a user