refactor(plugin_system): 重构 Prompt 注入逻辑以实现动态化

本次重构的核心目标是将 Prompt 注入规则的处理方式从系统启动时的一次性加载,转变为在每次需要注入时实时、动态地构建。这解决了之前静态加载机制下,运行时启用/禁用 Prompt 组件无法影响其注入行为的问题。

主要变更包括:

- **PromptComponentManager 动态化**:
    - 移除了 `load_static_rules` 和 `_initialized` 标志,规则不再在启动时预加载到 `_dynamic_rules` 中。
    - `_dynamic_rules` 现在只存储通过 API 动态添加的纯运行时规则。
    - 新增 `_build_rules_for_target` 方法,该方法在 `apply_injections` 时被调用,实时从 `component_registry` 获取所有已启用的静态组件规则,并与 `_dynamic_rules` 中的运行时规则合并,确保规则集始终反映当前系统状态。

- **依赖 ComponentRegistry**:
    - `PromptComponentManager` 现在直接依赖 `component_registry` 来获取组件的最新启用状态和信息,而不是依赖自己预加载的缓存。
    - `get_registered_prompt_component_info`, `get_injection_info`, `get_injection_rules` 等多个 API 方法被修改为 `async`,并重写了内部逻辑,以动态查询和构建信息,确保返回的数据准确反映了当前所有可用组件(包括静态和纯动态)的注入配置。

- **ComponentRegistry 增强**:
    - 增加了对 Prompt 组件在禁用时从内部启用的注册表中移除的逻辑。
    - 扩展了 `is_component_available` 的逻辑,使其能正确处理不支持局部(stream-specific)状态的组件类型。
This commit is contained in:
minecraft1024a
2025-11-22 11:15:45 +08:00
parent affd70b165
commit 30bf1f68b1
3 changed files with 196 additions and 204 deletions

View File

@@ -29,89 +29,14 @@ class PromptComponentManager:
def __init__(self):
"""初始化管理器实例。"""
# _dynamic_rules 是管理器的核心状态,存储所有注入规则。
# _dynamic_rules 仅用于存储通过 API 动态添加的、非静态组件的规则。
# 结构: {
# "target_prompt_name": {
# "prompt_component_name": (InjectionRule, content_provider, source)
# }
# }
# content_provider 是一个异步函数,用于在应用规则时动态生成注入内容。
# source 记录了规则的来源(例如 "static_default" 或 "runtime")。
self._dynamic_rules: dict[str, dict[str, tuple[InjectionRule, Callable[..., Awaitable[str]], str]]] = {}
self._lock = asyncio.Lock() # 使用异步锁确保对 _dynamic_rules 的并发访问安全。
self._initialized = False # 标记静态规则是否已加载,防止重复加载。
# --- 核心生命周期与初始化 ---
def load_static_rules(self):
"""
在系统启动时加载所有静态注入规则。
该方法会扫描所有已在 `component_registry` 中注册并启用的 Prompt 组件,
将其类变量 `injection_rules` 转换为管理器的动态规则。
这确保了所有插件定义的默认注入行为在系统启动时就能生效。
此操作是幂等的,一旦初始化完成就不会重复执行。
"""
if self._initialized:
return
logger.info("正在加载静态 Prompt 注入规则...")
# 从组件注册表中获取所有已启用的 Prompt 组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
for prompt_name, prompt_info in enabled_prompts.items():
if not isinstance(prompt_info, PromptInfo):
continue
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
if not (component_class and issubclass(component_class, BasePrompt)):
logger.warning(f"无法为 '{prompt_name}' 加载静态规则,因为它不是一个有效的 Prompt 组件。")
continue
def create_provider(
cls: type[BasePrompt],
) -> Callable[[PromptParameters, str], Awaitable[str]]:
"""
为静态组件创建一个内容提供者闭包 (Content Provider Closure)。
这个闭包捕获了组件的类 `cls`,并返回一个标准的 `content_provider` 异步函数。
当 `apply_injections` 需要内容时,它会调用这个函数。
函数内部会实例化组件,并执行其 `execute` 方法来获取注入内容。
Args:
cls (type[BasePrompt]): 需要为其创建提供者的 Prompt 组件类。
Returns:
Callable[[PromptParameters, str], Awaitable[str]]: 一个符合管理器标准的异步内容提供者。
"""
async def content_provider(params: PromptParameters, target_prompt_name: str) -> str:
"""实际执行内容生成的异步函数。"""
try:
# 从注册表获取最新的组件信息,包括插件配置
p_info = component_registry.get_component_info(cls.prompt_name, ComponentType.PROMPT)
plugin_config = {}
if isinstance(p_info, PromptInfo):
plugin_config = component_registry.get_plugin_config(p_info.plugin_name)
# 实例化组件并执行,传入 target_prompt_name
instance = cls(params=params, plugin_config=plugin_config, target_prompt_name=target_prompt_name)
result = await instance.execute()
return str(result) if result is not None else ""
except Exception as e:
logger.error(f"执行静态规则提供者 '{cls.prompt_name}' 时出错: {e}", exc_info=True)
return "" # 出错时返回空字符串,避免影响主流程
return content_provider
# 为该组件的每条静态注入规则创建并注册一个动态规则
for rule in prompt_info.injection_rules:
provider = create_provider(component_class)
target_rules = self._dynamic_rules.setdefault(rule.target_prompt, {})
target_rules[prompt_name] = (rule, provider, "static_default")
self._initialized = True
logger.info(f"静态 Prompt 注入规则加载完成,共处理 {len(enabled_prompts)} 个组件。")
self._lock = asyncio.Lock() # 锁现在保护 _dynamic_rules
# --- 运行时规则管理 API ---
@@ -243,6 +168,65 @@ class PromptComponentManager:
return removed
# --- 核心注入逻辑 ---
def _create_content_provider(
self, component_name: str, component_class: type[BasePrompt]
) -> Callable[[PromptParameters, str], Awaitable[str]]:
"""为指定的组件类创建一个标准化的内容提供者闭包。"""
async def content_provider(params: PromptParameters, target_prompt_name: str) -> str:
"""实际执行内容生成的异步函数。"""
try:
p_info = component_registry.get_component_info(component_name, ComponentType.PROMPT)
plugin_config = {}
if isinstance(p_info, PromptInfo):
plugin_config = component_registry.get_plugin_config(p_info.plugin_name)
instance = component_class(
params=params, plugin_config=plugin_config, target_prompt_name=target_prompt_name
)
result = await instance.execute()
return str(result) if result is not None else ""
except Exception as e:
logger.error(f"执行规则提供者 '{component_name}' 时出错: {e}", exc_info=True)
return ""
return content_provider
async def _build_rules_for_target(self, target_prompt_name: str) -> list:
"""在注入时动态构建目标的所有有效规则列表。"""
all_rules = []
# 1. 从 component_registry 获取所有静态组件的规则
static_components = component_registry.get_components_by_type(ComponentType.PROMPT)
for name, info in static_components.items():
if not isinstance(info, PromptInfo):
continue
# 实时检查组件是否启用
if not component_registry.is_component_available(name, ComponentType.PROMPT):
continue
component_class = component_registry.get_component_class(name, ComponentType.PROMPT)
if not (component_class and issubclass(component_class, BasePrompt)):
continue
provider = self._create_content_provider(name, component_class)
for rule in info.injection_rules:
if rule.target_prompt == target_prompt_name:
all_rules.append((rule, provider, "static"))
# 2. 从 _dynamic_rules 获取所有纯运行时规则
async with self._lock:
runtime_rules = self._dynamic_rules.get(target_prompt_name, {})
for name, (rule, provider, source) in runtime_rules.items():
# 确保运行时组件不会与禁用的静态组件冲突
static_info = component_registry.get_component_info(name, ComponentType.PROMPT)
if static_info and not component_registry.is_component_available(name, ComponentType.PROMPT):
logger.debug(f"跳过运行时规则 '{name}',因为它关联的静态组件当前已禁用。")
continue
all_rules.append((rule, provider, source))
return all_rules
async def apply_injections(
self, target_prompt_name: str, original_template: str, params: PromptParameters
@@ -268,10 +252,7 @@ class PromptComponentManager:
Returns:
str: 应用了所有注入规则后,最终生成的提示词模板字符串。
"""
if not self._initialized:
self.load_static_rules()
rules_for_target = list(self._dynamic_rules.get(target_prompt_name, {}).values())
rules_for_target = await self._build_rules_for_target(target_prompt_name)
if not rules_for_target:
return original_template
@@ -405,55 +386,41 @@ class PromptComponentManager:
return [[name, prompt.template] for name, prompt in global_prompt_manager._prompts.items()]
def get_registered_prompt_component_info(self) -> list[PromptInfo]:
async def get_registered_prompt_component_info(self) -> list[PromptInfo]:
"""
获取所有已注册和动态添加的Prompt组件信息并反映当前的注入规则状态。
该方法会合并静态注册的组件信息和运行时的动态注入规则,
确保返回的 `PromptInfo` 列表能够准确地反映系统当前的完整状态。
Returns:
list[PromptInfo]: 一个包含所有静态和动态Prompt组件信息的列表。
每个组件的 `injection_rules` 都会被更新为当前实际生效的规则。
此方法现在直接从 component_registry 获取静态组件信息,并合并纯运行时的组件信息。
"""
# 步骤 1: 获取所有静态注册的组件信息,并使用深拷贝以避免修改原始数据
static_components = component_registry.get_components_by_type(ComponentType.PROMPT)
# 使用深拷贝以避免修改原始注册表数据
info_dict: dict[str, PromptInfo] = {
name: copy.deepcopy(info) for name, info in static_components.items() if isinstance(info, PromptInfo)
}
# 该方法现在直接从 component_registry 获取信息,因为它总是有最新的数据
all_components = component_registry.get_components_by_type(ComponentType.PROMPT)
info_list = [info for info in all_components.values() if isinstance(info, PromptInfo)]
# 步骤 2: 遍历动态规则,识别并创建纯动态组件的 PromptInfo
all_dynamic_component_names = set()
for target, rules in self._dynamic_rules.items():
for prompt_name, (rule, _, source) in rules.items():
all_dynamic_component_names.add(prompt_name)
# 检查是否有纯动态组件需要添加
async with self._lock:
runtime_component_names = set()
for rules in self._dynamic_rules.values():
runtime_component_names.update(rules.keys())
for name in all_dynamic_component_names:
if name not in info_dict:
# 这是一个纯动态组件,为其创建一个新的 PromptInfo
info_dict[name] = PromptInfo(
static_component_names = {info.name for info in info_list}
pure_dynamic_names = runtime_component_names - static_component_names
for name in pure_dynamic_names:
# 为纯动态组件创建临时的 PromptInfo
dynamic_info = PromptInfo(
name=name,
component_type=ComponentType.PROMPT,
description="Dynamically added component",
plugin_name="runtime", # 动态组件通常没有插件归属
description="Dynamically added runtime component",
plugin_name="runtime",
is_built_in=False,
)
# 从 _dynamic_rules 中收集其所有规则
for target, rules_in_target in self._dynamic_rules.items():
if name in rules_in_target:
rule, _, _ = rules_in_target[name]
dynamic_info.injection_rules.append(rule)
info_list.append(dynamic_info)
# 步骤 3: 清空所有组件的注入规则,准备用当前状态重新填充
for info in info_dict.values():
info.injection_rules = []
# 步骤 4: 再次遍历动态规则,为每个组件重建其 injection_rules 列表
for target, rules in self._dynamic_rules.items():
for prompt_name, (rule, _, _) in rules.items():
if prompt_name in info_dict:
# 确保规则是 InjectionRule 的实例
if isinstance(rule, InjectionRule):
info_dict[prompt_name].injection_rules.append(rule)
# 步骤 5: 返回最终的 PromptInfo 对象列表
return list(info_dict.values())
return info_list
async def get_injection_info(
self,
@@ -462,36 +429,23 @@ class PromptComponentManager:
) -> dict[str, list[dict]]:
"""
获取注入信息的映射图,可按目标筛选,并可控制信息的详细程度。
- `get_injection_info()` 返回所有目标的摘要注入信息。
- `get_injection_info(target_prompt="...")` 返回指定目标的摘要注入信息。
- `get_injection_info(detailed=True)` 返回所有目标的详细注入信息。
- `get_injection_info(target_prompt="...", detailed=True)` 返回指定目标的详细注入信息。
Args:
target_prompt (str, optional): 如果指定,仅返回该目标的注入信息。
detailed (bool, optional): 如果为 True则返回包含注入类型和内容的详细信息。
默认为 False返回摘要信息。
Returns:
dict[str, list[dict]]: 一个字典,键是目标提示词名称,
值是按优先级排序的注入信息列表。
此方法现在动态构建信息,以反映当前启用的组件和规则。
"""
info_map = {}
async with self._lock:
all_targets = set(self._dynamic_rules.keys()) | set(self.get_core_prompts())
# 如果指定了目标,则只处理该目标
targets_to_process = [target_prompt] if target_prompt and target_prompt in all_targets else sorted(all_targets)
all_core_prompts = self.get_core_prompts()
targets_to_process = [target_prompt] if target_prompt and target_prompt in all_core_prompts else all_core_prompts
for target in targets_to_process:
rules = self._dynamic_rules.get(target, {})
if not rules:
# 动态构建规则列表
rules_for_target = await self._build_rules_for_target(target)
if not rules_for_target:
info_map[target] = []
continue
info_list = []
for prompt_name, (rule, _, source) in rules.items():
for rule, _, source in rules_for_target:
# 从规则本身获取组件名
prompt_name = rule.owner_component
if detailed:
info_list.append(
{
@@ -509,13 +463,13 @@ class PromptComponentManager:
info_map[target] = info_list
return info_map
def get_injection_rules(
async def get_injection_rules(
self,
target_prompt: str | None = None,
component_name: str | None = None,
) -> dict[str, dict[str, "InjectionRule"]]:
"""
获取动态注入规则,可通过目标或组件名称进行筛选。
获取所有(包括静态和运行时)注入规则,可通过目标或组件名称进行筛选。
- 不提供任何参数时,返回所有规则。
- 提供 `target_prompt` 时,仅返回注入到该目标的规则。
@@ -527,44 +481,42 @@ class PromptComponentManager:
component_name (str, optional): 按注入组件名称筛选。
Returns:
dict[str, dict[str, InjectionRule]]: 一个深拷贝的规则字典。
dict[str, dict[str, InjectionRule]]: 一个包含所有匹配规则的深拷贝字典。
结构: { "target_prompt": { "component_name": InjectionRule } }
"""
rules_copy = {}
# 筛选目标
targets_to_check = [target_prompt] if target_prompt else self._dynamic_rules.keys()
all_rules: dict[str, dict[str, InjectionRule]] = {}
for target in targets_to_check:
if target not in self._dynamic_rules:
# 1. 收集所有静态组件的规则
static_components = component_registry.get_components_by_type(ComponentType.PROMPT)
for name, info in static_components.items():
if not isinstance(info, PromptInfo):
continue
# 应用 component_name 筛选
if component_name and name != component_name:
continue
rules_for_target = self._dynamic_rules[target]
target_copy = {}
for rule in info.injection_rules:
# 应用 target_prompt 筛选
if target_prompt and rule.target_prompt != target_prompt:
continue
target_dict = all_rules.setdefault(rule.target_prompt, {})
target_dict[name] = rule
# 筛选组件
if component_name:
if component_name in rules_for_target:
rule, _, _ = rules_for_target[component_name]
target_copy[component_name] = rule
else:
for name, (rule, _, _) in rules_for_target.items():
target_copy[name] = rule
# 2. 收集并合并所有纯运行时规则
async with self._lock:
for target, rules_in_target in self._dynamic_rules.items():
# 应用 target_prompt 筛选
if target_prompt and target != target_prompt:
continue
if target_copy:
rules_copy[target] = target_copy
for name, (rule, _, _) in rules_in_target.items():
# 应用 component_name 筛选
if component_name and name != component_name:
continue
target_dict = all_rules.setdefault(target, {})
target_dict[name] = rule
# 如果是按组件筛选且未指定目标,则需遍历所有目标
if component_name and not target_prompt:
found_rules = {}
for target, rules in self._dynamic_rules.items():
if component_name in rules:
rule, _, _ = rules[component_name]
if target not in found_rules:
found_rules[target] = {}
found_rules[target][component_name] = rule
return copy.deepcopy(found_rules)
return copy.deepcopy(rules_copy)
return copy.deepcopy(all_rules)
# 创建全局单例 (Singleton)

View File

@@ -110,7 +110,7 @@ class BaseAction(ABC):
**kwargs: 其他参数
"""
if plugin_config is None:
plugin_config: ClassVar = {}
plugin_config = {}
self.action_data = action_data
self.reasoning = reasoning
self.cycle_timers = cycle_timers

View File

@@ -605,6 +605,9 @@ class ComponentRegistry:
result = event_manager.unsubscribe_handler_from_event(event, component_name)
if hasattr(result, "__await__"):
await result # type: ignore[func-returns-value]
case ComponentType.PROMPT:
if hasattr(self, "_enabled_prompt_registry"):
self._enabled_prompt_registry.pop(component_name, None)
# 组件主注册表使用命名空间 key
namespaced_name = f"{component_type.value}.{component_name}"
@@ -915,6 +918,27 @@ class ComponentRegistry:
info = self.get_component_info(chatter_name, ComponentType.CHATTER)
return info if isinstance(info, ChatterInfo) else None
# === Prompt 特定查询方法 ===
def get_prompt_registry(self) -> dict[str, type[BasePrompt]]:
"""获取Prompt注册表"""
if not hasattr(self, "_prompt_registry"):
self._prompt_registry: dict[str, type[BasePrompt]] = {}
return self._prompt_registry.copy()
def get_enabled_prompt_registry(self, stream_id: str | None = None) -> dict[str, type[BasePrompt]]:
"""获取启用的Prompt注册表, 可选地根据 stream_id 考虑局部状态"""
all_prompts = self.get_prompt_registry()
available_prompts = {}
for name, prompt_class in all_prompts.items():
if self.is_component_available(name, ComponentType.PROMPT, stream_id):
available_prompts[name] = prompt_class
return available_prompts
def get_registered_prompt_info(self, prompt_name: str) -> PromptInfo | None:
"""获取Prompt信息"""
info = self.get_component_info(prompt_name, ComponentType.PROMPT)
return info if isinstance(info, PromptInfo) else None
# === 插件查询方法 ===
def get_plugin_info(self, plugin_name: str) -> PluginInfo | None:
@@ -1020,15 +1044,27 @@ class ComponentRegistry:
self, stream_id: str, component_name: str, component_type: ComponentType, enabled: bool
) -> bool:
"""为指定的 stream_id 设置组件的局部(临时)状态"""
# 如果组件类型不需要局部状态管理,则记录警告并返回
if component_type in self._no_local_state_types:
logger.warning(
f"组件类型 {component_type.value} 不支持局部状态管理。 "
f"尝试为 '{component_name}' 设置局部状态的操作将被忽略。"
)
return False
if stream_id not in self._local_component_states:
self._local_component_states[stream_id] = {}
state_key = (component_name, component_type)
self._local_component_states[stream_id][state_key] = enabled
logger.debug(f"已为 stream '{stream_id}' 设置局部状态: {component_name} ({component_type}) -> {'启用' if enabled else '禁用'}")
logger.debug(
f"已为 stream '{stream_id}' 设置局部状态: {component_name} ({component_type}) -> {'启用' if enabled else '禁用'}"
)
return True
def is_component_available(self, component_name: str, component_type: ComponentType, stream_id: str | None = None) -> bool:
def is_component_available(
self, component_name: str, component_type: ComponentType, stream_id: str | None = None
) -> bool:
"""检查组件在给定上下文中是否可用(同时考虑全局和局部状态)"""
component_info = self.get_component_info(component_name, component_type)
@@ -1036,7 +1072,11 @@ class ComponentRegistry:
if not component_info:
return False
# 2. 如果提供了 stream_id检查局部状态
# 2. 如果组件类型不需要局部状态,则直接返回其全局状态
if component_type in self._no_local_state_types:
return component_info.enabled
# 3. 如果提供了 stream_id检查局部状态
if stream_id and stream_id in self._local_component_states:
state_key = (component_name, component_type)
local_state = self._local_component_states[stream_id].get(state_key)
@@ -1044,7 +1084,7 @@ class ComponentRegistry:
if local_state is not None:
return local_state # 局部状态存在,覆盖全局状态
# 3. 如果没有局部状态覆盖,则返回全局状态
# 4. 如果没有局部状态覆盖,则返回全局状态
return component_info.enabled
# === MCP 工具相关方法 ===