feat(plugin): 引入Prompt组件系统以实现动态Prompt注入

引入了一个新的插件组件类型 `BasePrompt`,允许插件动态地向核心Prompt模板中注入额外的上下文信息。该系统旨在提高Prompt的可扩展性和可定制性,使得开发者可以在不修改核心代码的情况下,通过插件来丰富和调整模型的行为。

主要变更包括:
- **`BasePrompt` 基类**: 定义了Prompt组件的标准接口,包括 `execute` 方法用于生成注入内容,以及 `injection_point` 属性用于指定目标Prompt。
- **`PromptComponentManager`**: 一个新的管理器,负责注册、分类和执行所有 `BasePrompt` 组件。它会在构建核心Prompt时,自动查找并执行相关组件,将其输出拼接到主Prompt内容之前。
- **核心Prompt逻辑更新**: `src.chat.utils.prompt.Prompt` 类现在会调用 `PromptComponentManager` 来获取并注入组件内容。
- **插件系统集成**: `ComponentRegistry` 和 `PluginManager` 已更新,以支持 `BasePrompt` 组件的注册、管理和统计。
- **示例插件更新**: `hello_world_plugin` 中增加了一个 `WeatherPrompt` 示例,演示了如何创建和注册一个新的Prompt组件。
- **代码重构**: 将 `PromptParameters` 类从 `prompt.py` 移动到独立的 `prompt_params.py` 文件中,以改善模块化和解决循环依赖问题。
This commit is contained in:
minecraft1024a
2025-10-19 13:00:18 +08:00
committed by Windpicker-owo
parent baefca2115
commit 0917318cbd
12 changed files with 428 additions and 85 deletions

View File

@@ -8,13 +8,14 @@ import contextvars
import re
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, Literal, Optional
from typing import Any, Optional
from rich.traceback import install
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.prompt_component_manager import prompt_component_manager
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
from src.config.config import global_config
from src.person_info.person_info import get_person_info_manager
@@ -23,80 +24,6 @@ install(extra_lines=3)
logger = get_logger("unified_prompt")
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
bot_name: str = ""
bot_nickname: str = ""
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: dict[str, Any] | None = None
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: dict[str, Any] | None = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
safety_guidelines_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: dict[str, Any] | None = None
# 动态生成的聊天场景提示
chat_scene: str = ""
def validate(self) -> list[str]:
"""参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors
class PromptContext:
"""提示词上下文管理器"""
@@ -307,11 +234,23 @@ class Prompt:
start_time = time.time()
try:
# 构建上下文数据
# 1. 从组件管理器获取注入内容
components_prefix = ""
if self.name:
components_prefix = await prompt_component_manager.execute_components_for(
injection_point=self.name, params=self.parameters
)
# 2. 构建核心上下文数据
context_data = await self._build_context_data()
# 格式化模板
result = await self._format_with_context(context_data)
# 3. 格式化模板
main_formatted_prompt = await self._format_with_context(context_data)
# 4. 拼接组件内容和主模板内容
result = main_formatted_prompt
if components_prefix:
result = f"{components_prefix}\n\n{main_formatted_prompt}"
total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")

View File

@@ -0,0 +1,123 @@
import asyncio
from collections import defaultdict
from typing import Type
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
from src.plugin_system.base.base_prompt import BasePrompt
from src.plugin_system.base.component_types import PromptInfo
logger = get_logger("prompt_component_manager")
class PromptComponentManager:
"""
管理所有 `BasePrompt` 组件的单例类。
该管理器负责:
1. 注册由插件定义的 `BasePrompt` 子类。
2. 根据注入点目标Prompt名称对它们进行分类存储。
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
"""
def __init__(self):
self._registry: dict[str, list[Type[BasePrompt]]] = defaultdict(list)
self._prompt_infos: dict[str, PromptInfo] = {}
def register(self, component_class: Type[BasePrompt]):
"""
注册一个 `BasePrompt` 组件类。
Args:
component_class: 要注册的 `BasePrompt` 子类。
"""
if not issubclass(component_class, BasePrompt):
logger.warning(f"尝试注册一个非 BasePrompt 的类: {component_class.__name__}")
return
try:
prompt_info = component_class.get_prompt_info()
if prompt_info.name in self._prompt_infos:
logger.warning(f"重复注册 Prompt 组件: {prompt_info.name}。将覆盖旧组件。")
injection_points = prompt_info.injection_point
if isinstance(injection_points, str):
injection_points = [injection_points]
if not injection_points or not all(injection_points):
logger.debug(f"Prompt 组件 '{prompt_info.name}' 未指定有效的 injection_point将不会被自动注入。")
return
for point in injection_points:
self._registry[point].append(component_class)
self._prompt_infos[prompt_info.name] = prompt_info
logger.info(f"成功注册 Prompt 组件 '{prompt_info.name}' 到注入点: {injection_points}")
except ValueError as e:
logger.error(f"注册 Prompt 组件失败 {component_class.__name__}: {e}")
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
"""
获取指定注入点的所有已注册组件类。
Args:
injection_point: 目标Prompt的名称。
Returns:
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
"""
return self._registry.get(injection_point, [])
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
"""
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
Args:
injection_point: 目标Prompt的名称。
params: 用于初始化组件的 PromptParameters 对象。
Returns:
str: 所有相关组件生成的、用换行符连接的文本内容。
"""
component_classes = self.get_components_for(injection_point)
if not component_classes:
return ""
tasks = []
from src.plugin_system.core.component_registry import component_registry
for component_class in component_classes:
try:
prompt_info = self._prompt_infos.get(component_class.prompt_name)
if not prompt_info:
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
plugin_config = {}
else:
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
instance = component_class(params=params, plugin_config=plugin_config)
tasks.append(instance.execute())
except Exception as e:
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
if not tasks:
return ""
# 并行执行所有组件
results = await asyncio.gather(*tasks, return_exceptions=True)
# 过滤掉执行失败的结果和空字符串
valid_results = []
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
elif result and isinstance(result, str) and result.strip():
valid_results.append(result.strip())
# 使用换行符拼接所有有效结果
return "\n".join(valid_results)
# 创建全局单例
prompt_component_manager = PromptComponentManager()

View File

@@ -0,0 +1,79 @@
"""
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
"""
from dataclasses import dataclass, field
from typing import Any, Literal
@dataclass
class PromptParameters:
"""统一提示词参数系统"""
# 基础参数
chat_id: str = ""
is_group_chat: bool = False
sender: str = ""
target: str = ""
reply_to: str = ""
extra_info: str = ""
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
bot_name: str = ""
bot_nickname: str = ""
# 功能开关
enable_tool: bool = True
enable_memory: bool = True
enable_expression: bool = True
enable_relation: bool = True
enable_cross_context: bool = True
enable_knowledge: bool = True
# 性能控制
max_context_messages: int = 50
# 调试选项
debug_mode: bool = False
# 聊天历史和上下文
chat_target_info: dict[str, Any] | None = None
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
chat_talking_prompt_short: str = ""
target_user_info: dict[str, Any] | None = None
# 已构建的内容块
expression_habits_block: str = ""
relation_info_block: str = ""
memory_block: str = ""
tool_info_block: str = ""
knowledge_prompt: str = ""
cross_context_block: str = ""
# 其他内容块
keywords_reaction_prompt: str = ""
extra_info_block: str = ""
time_block: str = ""
identity_block: str = ""
schedule_block: str = ""
moderation_prompt_block: str = ""
safety_guidelines_block: str = ""
reply_target_block: str = ""
mood_prompt: str = ""
action_descriptions: str = ""
# 可用动作信息
available_actions: dict[str, Any] | None = None
# 动态生成的聊天场景提示
chat_scene: str = ""
def validate(self) -> list[str]:
"""参数验证"""
errors = []
if not self.chat_id:
errors.append("chat_id不能为空")
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
errors.append("prompt_mode必须是's4u''normal''minimal'")
if self.max_context_messages <= 0:
errors.append("max_context_messages必须大于0")
return errors