From 50a6c2de587cc7b23290137912f10d58617a46fc Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Sun, 19 Oct 2025 13:00:18 +0800 Subject: [PATCH] =?UTF-8?q?feat(plugin):=20=E5=BC=95=E5=85=A5Prompt?= =?UTF-8?q?=E7=BB=84=E4=BB=B6=E7=B3=BB=E7=BB=9F=E4=BB=A5=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E5=8A=A8=E6=80=81Prompt=E6=B3=A8=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 引入了一个新的插件组件类型 `BasePrompt`,允许插件动态地向核心Prompt模板中注入额外的上下文信息。该系统旨在提高Prompt的可扩展性和可定制性,使得开发者可以在不修改核心代码的情况下,通过插件来丰富和调整模型的行为。 主要变更包括: - **`BasePrompt` 基类**: 定义了Prompt组件的标准接口,包括 `execute` 方法用于生成注入内容,以及 `injection_point` 属性用于指定目标Prompt。 - **`PromptComponentManager`**: 一个新的管理器,负责注册、分类和执行所有 `BasePrompt` 组件。它会在构建核心Prompt时,自动查找并执行相关组件,将其输出拼接到主Prompt内容之前。 - **核心Prompt逻辑更新**: `src.chat.utils.prompt.Prompt` 类现在会调用 `PromptComponentManager` 来获取并注入组件内容。 - **插件系统集成**: `ComponentRegistry` 和 `PluginManager` 已更新,以支持 `BasePrompt` 组件的注册、管理和统计。 - **示例插件更新**: `hello_world_plugin` 中增加了一个 `WeatherPrompt` 示例,演示了如何创建和注册一个新的Prompt组件。 - **代码重构**: 将 `PromptParameters` 类从 `prompt.py` 移动到独立的 `prompt_params.py` 文件中,以改善模块化和解决循环依赖问题。 --- plugins/hello_world_plugin/plugin.py | 32 ++++- src/chat/replyer/default_generator.py | 3 +- src/chat/utils/prompt.py | 97 +++------------ src/chat/utils/prompt_component_manager.py | 123 +++++++++++++++++++ src/chat/utils/prompt_params.py | 79 ++++++++++++ src/plugin_system/__init__.py | 2 + src/plugin_system/base/__init__.py | 2 + src/plugin_system/base/base_plugin.py | 10 ++ src/plugin_system/base/base_prompt.py | 95 ++++++++++++++ src/plugin_system/base/component_types.py | 13 ++ src/plugin_system/core/component_registry.py | 48 +++++++- src/plugin_system/core/plugin_manager.py | 9 +- 12 files changed, 428 insertions(+), 85 deletions(-) create mode 100644 src/chat/utils/prompt_component_manager.py create mode 100644 src/chat/utils/prompt_params.py create mode 100644 src/plugin_system/base/base_prompt.py diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 5d1026d52..fbb4fcab8 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -6,6 +6,8 @@ from src.plugin_system import ( BaseAction, BaseEventHandler, BasePlugin, + BasePrompt, + ToolParamType, BaseTool, ChatType, CommandArgs, @@ -36,7 +38,17 @@ class GetSystemInfoTool(BaseTool): name = "get_system_info" description = "获取当前系统的模拟版本和状态信息。" available_for_llm = True - parameters = [] + parameters = [ + ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), + ("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None), + ( + "time_range", + ToolParamType.STRING, + "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", + False, + ["any", "week", "month"], + ), + ] # type: ignore async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"} @@ -99,7 +111,6 @@ class LLMJudgeExampleAction(BaseAction): async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool: """LLM 判断激活:判断用户是否情绪低落""" return await self._llm_judge_activation( - chat_content=chat_content, judge_prompt=""" 判断用户是否表达了以下情绪或需求: 1. 感到难过、沮丧或失落 @@ -169,6 +180,19 @@ class RandomEmojiAction(BaseAction): return True, "成功发送了一个随机表情" +class WeatherPrompt(BasePrompt): + """一个简单的Prompt组件,用于向Planner注入天气信息。""" + + prompt_name = "weather_info_prompt" + prompt_description = "向Planner注入当前天气信息,以丰富对话上下文。" + injection_point = "planner_prompt" + + async def execute(self) -> str: + # 在实际应用中,这里可以调用天气API + # 为了演示,我们返回一个固定的天气信息 + return "当前天气:晴朗,温度25°C。" + + @register_plugin class HelloWorldPlugin(BasePlugin): """一个包含四大核心组件和高级配置功能的入门示例插件。""" @@ -178,7 +202,6 @@ class HelloWorldPlugin(BasePlugin): dependencies = [] python_dependencies = [] config_file_name = "config.toml" - enable_plugin = False config_schema = { "meta": { @@ -208,4 +231,7 @@ class HelloWorldPlugin(BasePlugin): if self.get_config("components.random_emoji_action_enabled", True): components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction)) + # 注册新的Prompt组件 + components.append((WeatherPrompt.get_prompt_info(), WeatherPrompt)) + return components diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 9b25fd110..ec7e14787 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import ( from src.chat.utils.memory_mappings import get_memory_type_chinese_label # 导入新的统一Prompt系统 -from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.utils.prompt_params import PromptParameters from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info from src.common.logger import get_logger diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 543da2d2b..356255fd5 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -8,13 +8,14 @@ import contextvars import re import time from contextlib import asynccontextmanager -from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Optional from rich.traceback import install from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import build_readable_messages +from src.chat.utils.prompt_component_manager import prompt_component_manager +from src.chat.utils.prompt_params import PromptParameters from src.common.logger import get_logger from src.config.config import global_config from src.person_info.person_info import get_person_info_manager @@ -23,80 +24,6 @@ install(extra_lines=3) logger = get_logger("unified_prompt") -@dataclass -class PromptParameters: - """统一提示词参数系统""" - - # 基础参数 - chat_id: str = "" - is_group_chat: bool = False - sender: str = "" - target: str = "" - reply_to: str = "" - extra_info: str = "" - prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" - bot_name: str = "" - bot_nickname: str = "" - - # 功能开关 - enable_tool: bool = True - enable_memory: bool = True - enable_expression: bool = True - enable_relation: bool = True - enable_cross_context: bool = True - enable_knowledge: bool = True - - # 性能控制 - max_context_messages: int = 50 - - # 调试选项 - debug_mode: bool = False - - # 聊天历史和上下文 - chat_target_info: dict[str, Any] | None = None - message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list) - message_list_before_short: list[dict[str, Any]] = field(default_factory=list) - chat_talking_prompt_short: str = "" - target_user_info: dict[str, Any] | None = None - - # 已构建的内容块 - expression_habits_block: str = "" - relation_info_block: str = "" - memory_block: str = "" - tool_info_block: str = "" - knowledge_prompt: str = "" - cross_context_block: str = "" - - # 其他内容块 - keywords_reaction_prompt: str = "" - extra_info_block: str = "" - time_block: str = "" - identity_block: str = "" - schedule_block: str = "" - moderation_prompt_block: str = "" - safety_guidelines_block: str = "" - reply_target_block: str = "" - mood_prompt: str = "" - action_descriptions: str = "" - - # 可用动作信息 - available_actions: dict[str, Any] | None = None - - # 动态生成的聊天场景提示 - chat_scene: str = "" - - def validate(self) -> list[str]: - """参数验证""" - errors = [] - if not self.chat_id: - errors.append("chat_id不能为空") - if self.prompt_mode not in ["s4u", "normal", "minimal"]: - errors.append("prompt_mode必须是's4u'、'normal'或'minimal'") - if self.max_context_messages <= 0: - errors.append("max_context_messages必须大于0") - return errors - - class PromptContext: """提示词上下文管理器""" @@ -303,11 +230,23 @@ class Prompt: start_time = time.time() try: - # 构建上下文数据 + # 1. 从组件管理器获取注入内容 + components_prefix = "" + if self.name: + components_prefix = await prompt_component_manager.execute_components_for( + injection_point=self.name, params=self.parameters + ) + + # 2. 构建核心上下文数据 context_data = await self._build_context_data() - # 格式化模板 - result = await self._format_with_context(context_data) + # 3. 格式化主模板 + main_formatted_prompt = await self._format_with_context(context_data) + + # 4. 拼接组件内容和主模板内容 + result = main_formatted_prompt + if components_prefix: + result = f"{components_prefix}\n\n{main_formatted_prompt}" total_time = time.time() - start_time logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") diff --git a/src/chat/utils/prompt_component_manager.py b/src/chat/utils/prompt_component_manager.py new file mode 100644 index 000000000..380e6dc77 --- /dev/null +++ b/src/chat/utils/prompt_component_manager.py @@ -0,0 +1,123 @@ +import asyncio +from collections import defaultdict +from typing import Type + +from src.chat.utils.prompt_params import PromptParameters +from src.common.logger import get_logger +from src.plugin_system.base.base_prompt import BasePrompt +from src.plugin_system.base.component_types import PromptInfo + +logger = get_logger("prompt_component_manager") + + +class PromptComponentManager: + """ + 管理所有 `BasePrompt` 组件的单例类。 + + 该管理器负责: + 1. 注册由插件定义的 `BasePrompt` 子类。 + 2. 根据注入点(目标Prompt名称)对它们进行分类存储。 + 3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。 + """ + + def __init__(self): + self._registry: dict[str, list[Type[BasePrompt]]] = defaultdict(list) + self._prompt_infos: dict[str, PromptInfo] = {} + + def register(self, component_class: Type[BasePrompt]): + """ + 注册一个 `BasePrompt` 组件类。 + + Args: + component_class: 要注册的 `BasePrompt` 子类。 + """ + if not issubclass(component_class, BasePrompt): + logger.warning(f"尝试注册一个非 BasePrompt 的类: {component_class.__name__}") + return + + try: + prompt_info = component_class.get_prompt_info() + if prompt_info.name in self._prompt_infos: + logger.warning(f"重复注册 Prompt 组件: {prompt_info.name}。将覆盖旧组件。") + + injection_points = prompt_info.injection_point + if isinstance(injection_points, str): + injection_points = [injection_points] + + if not injection_points or not all(injection_points): + logger.debug(f"Prompt 组件 '{prompt_info.name}' 未指定有效的 injection_point,将不会被自动注入。") + return + + for point in injection_points: + self._registry[point].append(component_class) + + self._prompt_infos[prompt_info.name] = prompt_info + logger.info(f"成功注册 Prompt 组件 '{prompt_info.name}' 到注入点: {injection_points}") + + except ValueError as e: + logger.error(f"注册 Prompt 组件失败 {component_class.__name__}: {e}") + + def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]: + """ + 获取指定注入点的所有已注册组件类。 + + Args: + injection_point: 目标Prompt的名称。 + + Returns: + list[Type[BasePrompt]]: 与该注入点关联的组件类列表。 + """ + return self._registry.get(injection_point, []) + + async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str: + """ + 实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。 + + Args: + injection_point: 目标Prompt的名称。 + params: 用于初始化组件的 PromptParameters 对象。 + + Returns: + str: 所有相关组件生成的、用换行符连接的文本内容。 + """ + component_classes = self.get_components_for(injection_point) + if not component_classes: + return "" + + tasks = [] + from src.plugin_system.core.component_registry import component_registry + + for component_class in component_classes: + try: + prompt_info = self._prompt_infos.get(component_class.prompt_name) + if not prompt_info: + logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置") + plugin_config = {} + else: + plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name) + + instance = component_class(params=params, plugin_config=plugin_config) + tasks.append(instance.execute()) + except Exception as e: + logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}") + + if not tasks: + return "" + + # 并行执行所有组件 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 过滤掉执行失败的结果和空字符串 + valid_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}") + elif result and isinstance(result, str) and result.strip(): + valid_results.append(result.strip()) + + # 使用换行符拼接所有有效结果 + return "\n".join(valid_results) + + +# 创建全局单例 +prompt_component_manager = PromptComponentManager() \ No newline at end of file diff --git a/src/chat/utils/prompt_params.py b/src/chat/utils/prompt_params.py new file mode 100644 index 000000000..2722c3605 --- /dev/null +++ b/src/chat/utils/prompt_params.py @@ -0,0 +1,79 @@ +""" +This module contains the PromptParameters class, which is used to define the parameters for a prompt. +""" +from dataclasses import dataclass, field +from typing import Any, Literal + + +@dataclass +class PromptParameters: + """统一提示词参数系统""" + + # 基础参数 + chat_id: str = "" + is_group_chat: bool = False + sender: str = "" + target: str = "" + reply_to: str = "" + extra_info: str = "" + prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" + bot_name: str = "" + bot_nickname: str = "" + + # 功能开关 + enable_tool: bool = True + enable_memory: bool = True + enable_expression: bool = True + enable_relation: bool = True + enable_cross_context: bool = True + enable_knowledge: bool = True + + # 性能控制 + max_context_messages: int = 50 + + # 调试选项 + debug_mode: bool = False + + # 聊天历史和上下文 + chat_target_info: dict[str, Any] | None = None + message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list) + message_list_before_short: list[dict[str, Any]] = field(default_factory=list) + chat_talking_prompt_short: str = "" + target_user_info: dict[str, Any] | None = None + + # 已构建的内容块 + expression_habits_block: str = "" + relation_info_block: str = "" + memory_block: str = "" + tool_info_block: str = "" + knowledge_prompt: str = "" + cross_context_block: str = "" + + # 其他内容块 + keywords_reaction_prompt: str = "" + extra_info_block: str = "" + time_block: str = "" + identity_block: str = "" + schedule_block: str = "" + moderation_prompt_block: str = "" + safety_guidelines_block: str = "" + reply_target_block: str = "" + mood_prompt: str = "" + action_descriptions: str = "" + + # 可用动作信息 + available_actions: dict[str, Any] | None = None + + # 动态生成的聊天场景提示 + chat_scene: str = "" + + def validate(self) -> list[str]: + """参数验证""" + errors = [] + if not self.chat_id: + errors.append("chat_id不能为空") + if self.prompt_mode not in ["s4u", "normal", "minimal"]: + errors.append("prompt_mode必须是's4u'、'normal'或'minimal'") + if self.max_context_messages <= 0: + errors.append("max_context_messages必须大于0") + return errors \ No newline at end of file diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 6f9d14714..3f19de3a1 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -26,6 +26,7 @@ from .base import ( ActionInfo, BaseAction, BaseCommand, + BasePrompt, BaseEventHandler, BasePlugin, BaseTool, @@ -64,6 +65,7 @@ __all__ = [ "BaseEventHandler", # 基础类 "BasePlugin", + "BasePrompt", "BaseTool", "ChatMode", "ChatType", diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index 0e69b1206..1b62d2a78 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -8,6 +8,7 @@ from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .base_plugin import BasePlugin +from .base_prompt import BasePrompt from .base_tool import BaseTool from .command_args import CommandArgs from .component_types import ( @@ -37,6 +38,7 @@ __all__ = [ "BaseCommand", "BaseEventHandler", "BasePlugin", + "BasePrompt", "BaseTool", "ChatMode", "ChatType", diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index df48a3164..662af3a5e 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -8,6 +8,7 @@ from src.plugin_system.base.component_types import ( EventHandlerInfo, InterestCalculatorInfo, PlusCommandInfo, + PromptInfo, ToolInfo, ) @@ -15,6 +16,7 @@ from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .base_interest_calculator import BaseInterestCalculator +from .base_prompt import BasePrompt from .base_tool import BaseTool from .plugin_base import PluginBase from .plus_command import PlusCommand @@ -80,6 +82,13 @@ class BasePlugin(PluginBase): logger.warning("EventHandler的get_info逻辑尚未实现") return None + elif component_type == ComponentType.PROMPT: + if hasattr(component_class, "get_prompt_info"): + return component_class.get_prompt_info() + else: + logger.warning(f"Prompt类 {component_class.__name__} 缺少 get_prompt_info 方法") + return None + else: logger.error(f"不支持的组件类型: {component_type}") return None @@ -109,6 +118,7 @@ class BasePlugin(PluginBase): | tuple[EventHandlerInfo, type[BaseEventHandler]] | tuple[ToolInfo, type[BaseTool]] | tuple[InterestCalculatorInfo, type[BaseInterestCalculator]] + | tuple[PromptInfo, type[BasePrompt]] ]: """获取插件包含的组件列表 diff --git a/src/plugin_system/base/base_prompt.py b/src/plugin_system/base/base_prompt.py new file mode 100644 index 000000000..8947ea2f5 --- /dev/null +++ b/src/plugin_system/base/base_prompt.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod +from typing import Any + +from src.chat.utils.prompt_params import PromptParameters +from src.common.logger import get_logger +from src.plugin_system.base.component_types import ComponentType, PromptInfo + +logger = get_logger("base_prompt") + + +class BasePrompt(ABC): + """Prompt组件基类 + + Prompt是插件的一种组件类型,用于动态地向现有的核心Prompt模板中注入额外的上下文信息。 + 它的主要作用是在不修改核心代码的情况下,扩展和定制模型的行为。 + + 子类可以通过类属性定义其行为: + - prompt_name: Prompt组件的唯一名称。 + - injection_point: 指定要注入的目标Prompt名称(或名称列表)。 + """ + + prompt_name: str = "" + """Prompt组件的名称""" + prompt_description: str = "" + """Prompt组件的描述""" + + # 定义此组件希望注入到哪个或哪些核心Prompt中 + # 可以是一个字符串(单个目标)或字符串列表(多个目标) + # 例如: "planner_prompt" 或 ["s4u_style_prompt", "normal_style_prompt"] + injection_point: str | list[str] = "" + """要注入的目标Prompt名称或列表""" + + def __init__(self, params: PromptParameters, plugin_config: dict | None = None): + """初始化Prompt组件 + + Args: + params: 统一提示词参数,包含所有构建提示词所需的上下文信息。 + plugin_config: 插件配置字典。 + """ + self.params = params + self.plugin_config = plugin_config or {} + self.log_prefix = "[PromptComponent]" + + logger.debug(f"{self.log_prefix} Prompt组件 '{self.prompt_name}' 初始化完成") + + @abstractmethod + async def execute(self) -> str: + """执行Prompt生成的抽象方法,子类必须实现。 + + 此方法应根据初始化时传入的 `self.params` 来构建并返回一个字符串。 + 返回的字符串将被拼接到目标Prompt的最前面。 + + Returns: + str: 生成的文本内容。 + """ + pass + + def get_config(self, key: str, default: Any = None) -> Any: + """获取插件配置值,支持嵌套键访问。 + + Args: + key: 配置键名,使用点号进行嵌套访问,如 "section.subsection.key"。 + default: 未找到键时返回的默认值。 + + Returns: + Any: 配置值或默认值。 + """ + if not self.plugin_config: + return default + + keys = key.split(".") + current = self.plugin_config + for k in keys: + if isinstance(current, dict) and k in current: + current = current[k] + else: + return default + return current + + @classmethod + def get_prompt_info(cls) -> "PromptInfo": + """从类属性生成PromptInfo,用于组件注册和管理。 + + Returns: + PromptInfo: 生成的Prompt信息对象。 + """ + if not cls.prompt_name: + raise ValueError("Prompt组件必须定义 'prompt_name' 类属性。") + + return PromptInfo( + name=cls.prompt_name, + component_type=ComponentType.PROMPT, + description=cls.prompt_description, + injection_point=cls.injection_point, + ) \ No newline at end of file diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 53952319e..5db5fdeb3 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -20,6 +20,7 @@ class ComponentType(Enum): EVENT_HANDLER = "event_handler" # 事件处理组件 CHATTER = "chatter" # 聊天处理器组件 INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件 + PROMPT = "prompt" # Prompt组件 def __str__(self) -> str: return self.value @@ -266,6 +267,18 @@ class EventInfo(ComponentInfo): self.component_type = ComponentType.EVENT_HANDLER +@dataclass +class PromptInfo(ComponentInfo): + """Prompt组件信息""" + + injection_point: str | list[str] = "" + """要注入的目标Prompt名称或列表""" + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.PROMPT + + @dataclass class PluginInfo: """插件信息""" diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 51da61743..310f95997 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -11,6 +11,7 @@ from src.plugin_system.base.base_chatter import BaseChatter from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator +from src.plugin_system.base.base_prompt import BasePrompt from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ( ActionInfo, @@ -22,6 +23,7 @@ from src.plugin_system.base.component_types import ( InterestCalculatorInfo, PluginInfo, PlusCommandInfo, + PromptInfo, ToolInfo, ) from src.plugin_system.base.plus_command import PlusCommand @@ -37,6 +39,7 @@ ComponentClassType = ( | type[PlusCommand] | type[BaseChatter] | type[BaseInterestCalculator] + | type[BasePrompt] ) @@ -183,6 +186,10 @@ class ComponentRegistry: assert isinstance(component_info, InterestCalculatorInfo) assert issubclass(component_class, BaseInterestCalculator) ret = self._register_interest_calculator_component(component_info, component_class) + case ComponentType.PROMPT: + assert isinstance(component_info, PromptInfo) + assert issubclass(component_class, BasePrompt) + ret = self._register_prompt_component(component_info, component_class) case _: logger.warning(f"未知组件类型: {component_type}") ret = False @@ -346,6 +353,31 @@ class ComponentRegistry: logger.debug(f"已注册InterestCalculator组件: {calculator_name}") return True + def _register_prompt_component( + self, prompt_info: PromptInfo, prompt_class: "ComponentClassType" + ) -> bool: + """注册Prompt组件到Prompt特定注册表""" + prompt_name = prompt_info.name + if not prompt_name: + logger.error(f"Prompt组件 {prompt_class.__name__} 必须指定名称") + return False + + if not hasattr(self, "_prompt_registry"): + self._prompt_registry: dict[str, type[BasePrompt]] = {} + if not hasattr(self, "_enabled_prompt_registry"): + self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {} + + _assign_plugin_attrs( + prompt_class, prompt_info.plugin_name, self.get_plugin_config(prompt_info.plugin_name) or {} + ) + self._prompt_registry[prompt_name] = prompt_class # type: ignore + + if prompt_info.enabled: + self._enabled_prompt_registry[prompt_name] = prompt_class # type: ignore + + logger.debug(f"已注册Prompt组件: {prompt_name}") + return True + # === 组件移除相关 === async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool: @@ -580,7 +612,17 @@ class ComponentRegistry: component_name: str, component_type: ComponentType | None = None, ) -> ( - type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None + type[ + BaseCommand + | BaseAction + | BaseEventHandler + | BaseTool + | PlusCommand + | BaseChatter + | BaseInterestCalculator + | BasePrompt + ] + | None ): """获取组件类,支持自动命名空间解析 @@ -829,6 +871,7 @@ class ComponentRegistry: events_handlers: int = 0 plus_command_components: int = 0 chatter_components: int = 0 + prompt_components: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 @@ -842,6 +885,8 @@ class ComponentRegistry: plus_command_components += 1 elif component.component_type == ComponentType.CHATTER: chatter_components += 1 + elif component.component_type == ComponentType.PROMPT: + prompt_components += 1 return { "action_components": action_components, "command_components": command_components, @@ -849,6 +894,7 @@ class ComponentRegistry: "event_handlers": events_handlers, "plus_command_components": plus_command_components, "chatter_components": chatter_components, + "prompt_components": prompt_components, "total_components": len(self._components), "total_plugins": len(self._plugins), "components_by_type": { diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 6cd89e5f4..3a59efeda 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -358,13 +358,14 @@ class PluginManager: event_handler_count = stats.get("event_handlers", 0) plus_command_count = stats.get("plus_command_components", 0) chatter_count = stats.get("chatter_components", 0) + prompt_count = stats.get("prompt_components", 0) total_components = stats.get("total_components", 0) # 📋 显示插件加载总览 if total_registered > 0: logger.info("🎉 插件系统加载完成!") logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})" + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})" ) # 显示详细的插件列表 @@ -402,6 +403,9 @@ class PluginManager: plus_command_components = [ c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND ] + prompt_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.PROMPT + ] if action_components: action_details = [format_component(c) for c in action_components] @@ -425,6 +429,9 @@ class PluginManager: if event_handler_components: event_handler_details = [format_component(c) for c in event_handler_components] logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}") + if prompt_components: + prompt_details = [format_component(c) for c in prompt_components] + logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}") # 权限节点信息 if plugin_instance := self.loaded_plugins.get(plugin_name):