将Prompt组件的注入逻辑从`Prompt.format`方法前置到`PromptManager.get_prompt_async`和`create_prompt_async`中。这使得注入时机更早,逻辑更清晰,并允许在获取Prompt时就能动态传入参数以影响注入内容。 主要变更: - `PromptManager`: `get_prompt_async`现在负责处理组件注入,并接收可选的`parameters`参数。`format_prompt`相应地传递参数。 - `create_prompt_async`: 现在也支持在创建时进行动态注入。 - `Prompt.format`: 移除了原有的组件注入逻辑,简化了方法实现。 - `PromptComponentManager`: 重构为直接从全局`component_registry`获取组件,移除了自身的注册和存储逻辑,减少了状态管理的复杂性。 - `plan_filter.py`: 删除了大量冗余和重复的代码块,包括主动聊天模式的独立逻辑和旧的历史消息构建方式。
109 lines
4.3 KiB
Python
109 lines
4.3 KiB
Python
import asyncio
|
||
from typing import Type
|
||
|
||
from src.chat.utils.prompt_params import PromptParameters
|
||
from src.common.logger import get_logger
|
||
from src.plugin_system.base.base_prompt import BasePrompt
|
||
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||
from src.plugin_system.core.component_registry import component_registry
|
||
|
||
logger = get_logger("prompt_component_manager")
|
||
|
||
|
||
class PromptComponentManager:
|
||
"""
|
||
管理所有 `BasePrompt` 组件的单例类。
|
||
|
||
该管理器负责:
|
||
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
|
||
2. 根据注入点(目标Prompt名称)对它们进行筛选。
|
||
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
||
"""
|
||
|
||
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
|
||
"""
|
||
获取指定注入点的所有已注册组件类。
|
||
|
||
Args:
|
||
injection_point: 目标Prompt的名称。
|
||
|
||
Returns:
|
||
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
|
||
"""
|
||
# 从组件注册中心获取所有启用的Prompt组件
|
||
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
|
||
|
||
matching_components: list[Type[BasePrompt]] = []
|
||
|
||
for prompt_name, prompt_info in enabled_prompts.items():
|
||
# 确保 prompt_info 是 PromptInfo 类型
|
||
if not isinstance(prompt_info, PromptInfo):
|
||
continue
|
||
|
||
# 获取注入点信息
|
||
injection_points = prompt_info.injection_point
|
||
if isinstance(injection_points, str):
|
||
injection_points = [injection_points]
|
||
|
||
# 检查当前注入点是否匹配
|
||
if injection_point in injection_points:
|
||
# 获取组件类
|
||
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
|
||
if component_class and issubclass(component_class, BasePrompt):
|
||
matching_components.append(component_class)
|
||
|
||
return matching_components
|
||
|
||
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
|
||
"""
|
||
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
|
||
|
||
Args:
|
||
injection_point: 目标Prompt的名称。
|
||
params: 用于初始化组件的 PromptParameters 对象。
|
||
|
||
Returns:
|
||
str: 所有相关组件生成的、用换行符连接的文本内容。
|
||
"""
|
||
component_classes = self.get_components_for(injection_point)
|
||
if not component_classes:
|
||
return ""
|
||
|
||
tasks = []
|
||
for component_class in component_classes:
|
||
try:
|
||
# 从注册中心获取组件信息
|
||
prompt_info = component_registry.get_component_info(
|
||
component_class.prompt_name, ComponentType.PROMPT
|
||
)
|
||
if not isinstance(prompt_info, PromptInfo):
|
||
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
|
||
plugin_config = {}
|
||
else:
|
||
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
|
||
|
||
instance = component_class(params=params, plugin_config=plugin_config)
|
||
tasks.append(instance.execute())
|
||
except Exception as e:
|
||
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
|
||
|
||
if not tasks:
|
||
return ""
|
||
|
||
# 并行执行所有组件
|
||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||
|
||
# 过滤掉执行失败的结果和空字符串
|
||
valid_results = []
|
||
for i, result in enumerate(results):
|
||
if isinstance(result, Exception):
|
||
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
|
||
elif result and isinstance(result, str) and result.strip():
|
||
valid_results.append(result.strip())
|
||
|
||
# 使用换行符拼接所有有效结果
|
||
return "\n".join(valid_results)
|
||
|
||
|
||
# 创建全局单例
|
||
prompt_component_manager = PromptComponentManager() |