feat(plugin): 引入Prompt组件系统以实现动态Prompt注入
引入了一个新的插件组件类型 `BasePrompt`,允许插件动态地向核心Prompt模板中注入额外的上下文信息。该系统旨在提高Prompt的可扩展性和可定制性,使得开发者可以在不修改核心代码的情况下,通过插件来丰富和调整模型的行为。 主要变更包括: - **`BasePrompt` 基类**: 定义了Prompt组件的标准接口,包括 `execute` 方法用于生成注入内容,以及 `injection_point` 属性用于指定目标Prompt。 - **`PromptComponentManager`**: 一个新的管理器,负责注册、分类和执行所有 `BasePrompt` 组件。它会在构建核心Prompt时,自动查找并执行相关组件,将其输出拼接到主Prompt内容之前。 - **核心Prompt逻辑更新**: `src.chat.utils.prompt.Prompt` 类现在会调用 `PromptComponentManager` 来获取并注入组件内容。 - **插件系统集成**: `ComponentRegistry` 和 `PluginManager` 已更新,以支持 `BasePrompt` 组件的注册、管理和统计。 - **示例插件更新**: `hello_world_plugin` 中增加了一个 `WeatherPrompt` 示例,演示了如何创建和注册一个新的Prompt组件。 - **代码重构**: 将 `PromptParameters` 类从 `prompt.py` 移动到独立的 `prompt_params.py` 文件中,以改善模块化和解决循环依赖问题。
This commit is contained in:
@@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||
|
||||
# 导入新的统一Prompt系统
|
||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -8,13 +8,14 @@ import contextvars
|
||||
import re
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||
from src.chat.utils.prompt_component_manager import prompt_component_manager
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
@@ -23,80 +24,6 @@ install(extra_lines=3)
|
||||
logger = get_logger("unified_prompt")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
|
||||
|
||||
class PromptContext:
|
||||
"""提示词上下文管理器"""
|
||||
|
||||
@@ -303,11 +230,23 @@ class Prompt:
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
# 构建上下文数据
|
||||
# 1. 从组件管理器获取注入内容
|
||||
components_prefix = ""
|
||||
if self.name:
|
||||
components_prefix = await prompt_component_manager.execute_components_for(
|
||||
injection_point=self.name, params=self.parameters
|
||||
)
|
||||
|
||||
# 2. 构建核心上下文数据
|
||||
context_data = await self._build_context_data()
|
||||
|
||||
# 格式化模板
|
||||
result = await self._format_with_context(context_data)
|
||||
# 3. 格式化主模板
|
||||
main_formatted_prompt = await self._format_with_context(context_data)
|
||||
|
||||
# 4. 拼接组件内容和主模板内容
|
||||
result = main_formatted_prompt
|
||||
if components_prefix:
|
||||
result = f"{components_prefix}\n\n{main_formatted_prompt}"
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||
|
||||
123
src/chat/utils/prompt_component_manager.py
Normal file
123
src/chat/utils/prompt_component_manager.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
from typing import Type
|
||||
|
||||
from src.chat.utils.prompt_params import PromptParameters
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_prompt import BasePrompt
|
||||
from src.plugin_system.base.component_types import PromptInfo
|
||||
|
||||
logger = get_logger("prompt_component_manager")
|
||||
|
||||
|
||||
class PromptComponentManager:
|
||||
"""
|
||||
管理所有 `BasePrompt` 组件的单例类。
|
||||
|
||||
该管理器负责:
|
||||
1. 注册由插件定义的 `BasePrompt` 子类。
|
||||
2. 根据注入点(目标Prompt名称)对它们进行分类存储。
|
||||
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._registry: dict[str, list[Type[BasePrompt]]] = defaultdict(list)
|
||||
self._prompt_infos: dict[str, PromptInfo] = {}
|
||||
|
||||
def register(self, component_class: Type[BasePrompt]):
|
||||
"""
|
||||
注册一个 `BasePrompt` 组件类。
|
||||
|
||||
Args:
|
||||
component_class: 要注册的 `BasePrompt` 子类。
|
||||
"""
|
||||
if not issubclass(component_class, BasePrompt):
|
||||
logger.warning(f"尝试注册一个非 BasePrompt 的类: {component_class.__name__}")
|
||||
return
|
||||
|
||||
try:
|
||||
prompt_info = component_class.get_prompt_info()
|
||||
if prompt_info.name in self._prompt_infos:
|
||||
logger.warning(f"重复注册 Prompt 组件: {prompt_info.name}。将覆盖旧组件。")
|
||||
|
||||
injection_points = prompt_info.injection_point
|
||||
if isinstance(injection_points, str):
|
||||
injection_points = [injection_points]
|
||||
|
||||
if not injection_points or not all(injection_points):
|
||||
logger.debug(f"Prompt 组件 '{prompt_info.name}' 未指定有效的 injection_point,将不会被自动注入。")
|
||||
return
|
||||
|
||||
for point in injection_points:
|
||||
self._registry[point].append(component_class)
|
||||
|
||||
self._prompt_infos[prompt_info.name] = prompt_info
|
||||
logger.info(f"成功注册 Prompt 组件 '{prompt_info.name}' 到注入点: {injection_points}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"注册 Prompt 组件失败 {component_class.__name__}: {e}")
|
||||
|
||||
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
|
||||
"""
|
||||
获取指定注入点的所有已注册组件类。
|
||||
|
||||
Args:
|
||||
injection_point: 目标Prompt的名称。
|
||||
|
||||
Returns:
|
||||
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
|
||||
"""
|
||||
return self._registry.get(injection_point, [])
|
||||
|
||||
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
|
||||
"""
|
||||
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
|
||||
|
||||
Args:
|
||||
injection_point: 目标Prompt的名称。
|
||||
params: 用于初始化组件的 PromptParameters 对象。
|
||||
|
||||
Returns:
|
||||
str: 所有相关组件生成的、用换行符连接的文本内容。
|
||||
"""
|
||||
component_classes = self.get_components_for(injection_point)
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
tasks = []
|
||||
from src.plugin_system.core.component_registry import component_registry
|
||||
|
||||
for component_class in component_classes:
|
||||
try:
|
||||
prompt_info = self._prompt_infos.get(component_class.prompt_name)
|
||||
if not prompt_info:
|
||||
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
|
||||
plugin_config = {}
|
||||
else:
|
||||
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
|
||||
|
||||
instance = component_class(params=params, plugin_config=plugin_config)
|
||||
tasks.append(instance.execute())
|
||||
except Exception as e:
|
||||
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
|
||||
|
||||
if not tasks:
|
||||
return ""
|
||||
|
||||
# 并行执行所有组件
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 过滤掉执行失败的结果和空字符串
|
||||
valid_results = []
|
||||
for i, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
|
||||
elif result and isinstance(result, str) and result.strip():
|
||||
valid_results.append(result.strip())
|
||||
|
||||
# 使用换行符拼接所有有效结果
|
||||
return "\n".join(valid_results)
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
prompt_component_manager = PromptComponentManager()
|
||||
79
src/chat/utils/prompt_params.py
Normal file
79
src/chat/utils/prompt_params.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Literal
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptParameters:
|
||||
"""统一提示词参数系统"""
|
||||
|
||||
# 基础参数
|
||||
chat_id: str = ""
|
||||
is_group_chat: bool = False
|
||||
sender: str = ""
|
||||
target: str = ""
|
||||
reply_to: str = ""
|
||||
extra_info: str = ""
|
||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||
bot_name: str = ""
|
||||
bot_nickname: str = ""
|
||||
|
||||
# 功能开关
|
||||
enable_tool: bool = True
|
||||
enable_memory: bool = True
|
||||
enable_expression: bool = True
|
||||
enable_relation: bool = True
|
||||
enable_cross_context: bool = True
|
||||
enable_knowledge: bool = True
|
||||
|
||||
# 性能控制
|
||||
max_context_messages: int = 50
|
||||
|
||||
# 调试选项
|
||||
debug_mode: bool = False
|
||||
|
||||
# 聊天历史和上下文
|
||||
chat_target_info: dict[str, Any] | None = None
|
||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||
chat_talking_prompt_short: str = ""
|
||||
target_user_info: dict[str, Any] | None = None
|
||||
|
||||
# 已构建的内容块
|
||||
expression_habits_block: str = ""
|
||||
relation_info_block: str = ""
|
||||
memory_block: str = ""
|
||||
tool_info_block: str = ""
|
||||
knowledge_prompt: str = ""
|
||||
cross_context_block: str = ""
|
||||
|
||||
# 其他内容块
|
||||
keywords_reaction_prompt: str = ""
|
||||
extra_info_block: str = ""
|
||||
time_block: str = ""
|
||||
identity_block: str = ""
|
||||
schedule_block: str = ""
|
||||
moderation_prompt_block: str = ""
|
||||
safety_guidelines_block: str = ""
|
||||
reply_target_block: str = ""
|
||||
mood_prompt: str = ""
|
||||
action_descriptions: str = ""
|
||||
|
||||
# 可用动作信息
|
||||
available_actions: dict[str, Any] | None = None
|
||||
|
||||
# 动态生成的聊天场景提示
|
||||
chat_scene: str = ""
|
||||
|
||||
def validate(self) -> list[str]:
|
||||
"""参数验证"""
|
||||
errors = []
|
||||
if not self.chat_id:
|
||||
errors.append("chat_id不能为空")
|
||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||
if self.max_context_messages <= 0:
|
||||
errors.append("max_context_messages必须大于0")
|
||||
return errors
|
||||
Reference in New Issue
Block a user