feat(plugin): 引入Prompt组件系统以实现动态Prompt注入

引入了一个新的插件组件类型 `BasePrompt`,允许插件动态地向核心Prompt模板中注入额外的上下文信息。该系统旨在提高Prompt的可扩展性和可定制性,使得开发者可以在不修改核心代码的情况下,通过插件来丰富和调整模型的行为。

主要变更包括:
- **`BasePrompt` 基类**: 定义了Prompt组件的标准接口,包括 `execute` 方法用于生成注入内容,以及 `injection_point` 属性用于指定目标Prompt。
- **`PromptComponentManager`**: 一个新的管理器,负责注册、分类和执行所有 `BasePrompt` 组件。它会在构建核心Prompt时,自动查找并执行相关组件,将其输出拼接到主Prompt内容之前。
- **核心Prompt逻辑更新**: `src.chat.utils.prompt.Prompt` 类现在会调用 `PromptComponentManager` 来获取并注入组件内容。
- **插件系统集成**: `ComponentRegistry` 和 `PluginManager` 已更新,以支持 `BasePrompt` 组件的注册、管理和统计。
- **示例插件更新**: `hello_world_plugin` 中增加了一个 `WeatherPrompt` 示例,演示了如何创建和注册一个新的Prompt组件。
- **代码重构**: 将 `PromptParameters` 类从 `prompt.py` 移动到独立的 `prompt_params.py` 文件中,以改善模块化和解决循环依赖问题。
This commit is contained in:
minecraft1024a
2025-10-19 13:00:18 +08:00
committed by Windpicker-owo
parent baefca2115
commit 0917318cbd
12 changed files with 428 additions and 85 deletions

View File

@@ -0,0 +1,123 @@
import asyncio
from collections import defaultdict
from typing import Type
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
from src.plugin_system.base.base_prompt import BasePrompt
from src.plugin_system.base.component_types import PromptInfo
logger = get_logger("prompt_component_manager")
class PromptComponentManager:
"""
管理所有 `BasePrompt` 组件的单例类。
该管理器负责:
1. 注册由插件定义的 `BasePrompt` 子类。
2. 根据注入点目标Prompt名称对它们进行分类存储。
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
"""
def __init__(self):
self._registry: dict[str, list[Type[BasePrompt]]] = defaultdict(list)
self._prompt_infos: dict[str, PromptInfo] = {}
def register(self, component_class: Type[BasePrompt]):
"""
注册一个 `BasePrompt` 组件类。
Args:
component_class: 要注册的 `BasePrompt` 子类。
"""
if not issubclass(component_class, BasePrompt):
logger.warning(f"尝试注册一个非 BasePrompt 的类: {component_class.__name__}")
return
try:
prompt_info = component_class.get_prompt_info()
if prompt_info.name in self._prompt_infos:
logger.warning(f"重复注册 Prompt 组件: {prompt_info.name}。将覆盖旧组件。")
injection_points = prompt_info.injection_point
if isinstance(injection_points, str):
injection_points = [injection_points]
if not injection_points or not all(injection_points):
logger.debug(f"Prompt 组件 '{prompt_info.name}' 未指定有效的 injection_point将不会被自动注入。")
return
for point in injection_points:
self._registry[point].append(component_class)
self._prompt_infos[prompt_info.name] = prompt_info
logger.info(f"成功注册 Prompt 组件 '{prompt_info.name}' 到注入点: {injection_points}")
except ValueError as e:
logger.error(f"注册 Prompt 组件失败 {component_class.__name__}: {e}")
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
"""
获取指定注入点的所有已注册组件类。
Args:
injection_point: 目标Prompt的名称。
Returns:
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
"""
return self._registry.get(injection_point, [])
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
"""
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
Args:
injection_point: 目标Prompt的名称。
params: 用于初始化组件的 PromptParameters 对象。
Returns:
str: 所有相关组件生成的、用换行符连接的文本内容。
"""
component_classes = self.get_components_for(injection_point)
if not component_classes:
return ""
tasks = []
from src.plugin_system.core.component_registry import component_registry
for component_class in component_classes:
try:
prompt_info = self._prompt_infos.get(component_class.prompt_name)
if not prompt_info:
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
plugin_config = {}
else:
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
instance = component_class(params=params, plugin_config=plugin_config)
tasks.append(instance.execute())
except Exception as e:
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
if not tasks:
return ""
# 并行执行所有组件
results = await asyncio.gather(*tasks, return_exceptions=True)
# 过滤掉执行失败的结果和空字符串
valid_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
elif result and isinstance(result, str) and result.strip():
valid_results.append(result.strip())
# 使用换行符拼接所有有效结果
return "\n".join(valid_results)
# 创建全局单例
prompt_component_manager = PromptComponentManager()