feat(plugin): 引入Prompt组件系统以实现动态Prompt注入
引入了一个新的插件组件类型 `BasePrompt`,允许插件动态地向核心Prompt模板中注入额外的上下文信息。该系统旨在提高Prompt的可扩展性和可定制性,使得开发者可以在不修改核心代码的情况下,通过插件来丰富和调整模型的行为。 主要变更包括: - **`BasePrompt` 基类**: 定义了Prompt组件的标准接口,包括 `execute` 方法用于生成注入内容,以及 `injection_point` 属性用于指定目标Prompt。 - **`PromptComponentManager`**: 一个新的管理器,负责注册、分类和执行所有 `BasePrompt` 组件。它会在构建核心Prompt时,自动查找并执行相关组件,将其输出拼接到主Prompt内容之前。 - **核心Prompt逻辑更新**: `src.chat.utils.prompt.Prompt` 类现在会调用 `PromptComponentManager` 来获取并注入组件内容。 - **插件系统集成**: `ComponentRegistry` 和 `PluginManager` 已更新,以支持 `BasePrompt` 组件的注册、管理和统计。 - **示例插件更新**: `hello_world_plugin` 中增加了一个 `WeatherPrompt` 示例,演示了如何创建和注册一个新的Prompt组件。 - **代码重构**: 将 `PromptParameters` 类从 `prompt.py` 移动到独立的 `prompt_params.py` 文件中,以改善模块化和解决循环依赖问题。
This commit is contained in:
@@ -6,6 +6,8 @@ from src.plugin_system import (
|
|||||||
BaseAction,
|
BaseAction,
|
||||||
BaseEventHandler,
|
BaseEventHandler,
|
||||||
BasePlugin,
|
BasePlugin,
|
||||||
|
BasePrompt,
|
||||||
|
ToolParamType,
|
||||||
BaseTool,
|
BaseTool,
|
||||||
ChatType,
|
ChatType,
|
||||||
CommandArgs,
|
CommandArgs,
|
||||||
@@ -36,7 +38,17 @@ class GetSystemInfoTool(BaseTool):
|
|||||||
name = "get_system_info"
|
name = "get_system_info"
|
||||||
description = "获取当前系统的模拟版本和状态信息。"
|
description = "获取当前系统的模拟版本和状态信息。"
|
||||||
available_for_llm = True
|
available_for_llm = True
|
||||||
parameters = []
|
parameters = [
|
||||||
|
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||||
|
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||||
|
(
|
||||||
|
"time_range",
|
||||||
|
ToolParamType.STRING,
|
||||||
|
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。",
|
||||||
|
False,
|
||||||
|
["any", "week", "month"],
|
||||||
|
),
|
||||||
|
] # type: ignore
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"}
|
return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"}
|
||||||
@@ -99,7 +111,6 @@ class LLMJudgeExampleAction(BaseAction):
|
|||||||
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool:
|
||||||
"""LLM 判断激活:判断用户是否情绪低落"""
|
"""LLM 判断激活:判断用户是否情绪低落"""
|
||||||
return await self._llm_judge_activation(
|
return await self._llm_judge_activation(
|
||||||
chat_content=chat_content,
|
|
||||||
judge_prompt="""
|
judge_prompt="""
|
||||||
判断用户是否表达了以下情绪或需求:
|
判断用户是否表达了以下情绪或需求:
|
||||||
1. 感到难过、沮丧或失落
|
1. 感到难过、沮丧或失落
|
||||||
@@ -169,6 +180,19 @@ class RandomEmojiAction(BaseAction):
|
|||||||
return True, "成功发送了一个随机表情"
|
return True, "成功发送了一个随机表情"
|
||||||
|
|
||||||
|
|
||||||
|
class WeatherPrompt(BasePrompt):
|
||||||
|
"""一个简单的Prompt组件,用于向Planner注入天气信息。"""
|
||||||
|
|
||||||
|
prompt_name = "weather_info_prompt"
|
||||||
|
prompt_description = "向Planner注入当前天气信息,以丰富对话上下文。"
|
||||||
|
injection_point = "planner_prompt"
|
||||||
|
|
||||||
|
async def execute(self) -> str:
|
||||||
|
# 在实际应用中,这里可以调用天气API
|
||||||
|
# 为了演示,我们返回一个固定的天气信息
|
||||||
|
return "当前天气:晴朗,温度25°C。"
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
@register_plugin
|
||||||
class HelloWorldPlugin(BasePlugin):
|
class HelloWorldPlugin(BasePlugin):
|
||||||
"""一个包含四大核心组件和高级配置功能的入门示例插件。"""
|
"""一个包含四大核心组件和高级配置功能的入门示例插件。"""
|
||||||
@@ -178,7 +202,6 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
dependencies = []
|
dependencies = []
|
||||||
python_dependencies = []
|
python_dependencies = []
|
||||||
config_file_name = "config.toml"
|
config_file_name = "config.toml"
|
||||||
enable_plugin = False
|
|
||||||
|
|
||||||
config_schema = {
|
config_schema = {
|
||||||
"meta": {
|
"meta": {
|
||||||
@@ -208,4 +231,7 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
if self.get_config("components.random_emoji_action_enabled", True):
|
if self.get_config("components.random_emoji_action_enabled", True):
|
||||||
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))
|
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))
|
||||||
|
|
||||||
|
# 注册新的Prompt组件
|
||||||
|
components.append((WeatherPrompt.get_prompt_info(), WeatherPrompt))
|
||||||
|
|
||||||
return components
|
return components
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
from src.chat.utils.memory_mappings import get_memory_type_chinese_label
|
||||||
|
|
||||||
# 导入新的统一Prompt系统
|
# 导入新的统一Prompt系统
|
||||||
from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager
|
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||||
|
from src.chat.utils.prompt_params import PromptParameters
|
||||||
from src.chat.utils.timer_calculator import Timer
|
from src.chat.utils.timer_calculator import Timer
|
||||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -8,13 +8,14 @@ import contextvars
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
from dataclasses import dataclass, field
|
from typing import Any, Optional
|
||||||
from typing import Any, Literal, Optional
|
|
||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages
|
from src.chat.utils.chat_message_builder import build_readable_messages
|
||||||
|
from src.chat.utils.prompt_component_manager import prompt_component_manager
|
||||||
|
from src.chat.utils.prompt_params import PromptParameters
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
@@ -23,80 +24,6 @@ install(extra_lines=3)
|
|||||||
logger = get_logger("unified_prompt")
|
logger = get_logger("unified_prompt")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PromptParameters:
|
|
||||||
"""统一提示词参数系统"""
|
|
||||||
|
|
||||||
# 基础参数
|
|
||||||
chat_id: str = ""
|
|
||||||
is_group_chat: bool = False
|
|
||||||
sender: str = ""
|
|
||||||
target: str = ""
|
|
||||||
reply_to: str = ""
|
|
||||||
extra_info: str = ""
|
|
||||||
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
|
||||||
bot_name: str = ""
|
|
||||||
bot_nickname: str = ""
|
|
||||||
|
|
||||||
# 功能开关
|
|
||||||
enable_tool: bool = True
|
|
||||||
enable_memory: bool = True
|
|
||||||
enable_expression: bool = True
|
|
||||||
enable_relation: bool = True
|
|
||||||
enable_cross_context: bool = True
|
|
||||||
enable_knowledge: bool = True
|
|
||||||
|
|
||||||
# 性能控制
|
|
||||||
max_context_messages: int = 50
|
|
||||||
|
|
||||||
# 调试选项
|
|
||||||
debug_mode: bool = False
|
|
||||||
|
|
||||||
# 聊天历史和上下文
|
|
||||||
chat_target_info: dict[str, Any] | None = None
|
|
||||||
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
chat_talking_prompt_short: str = ""
|
|
||||||
target_user_info: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
# 已构建的内容块
|
|
||||||
expression_habits_block: str = ""
|
|
||||||
relation_info_block: str = ""
|
|
||||||
memory_block: str = ""
|
|
||||||
tool_info_block: str = ""
|
|
||||||
knowledge_prompt: str = ""
|
|
||||||
cross_context_block: str = ""
|
|
||||||
|
|
||||||
# 其他内容块
|
|
||||||
keywords_reaction_prompt: str = ""
|
|
||||||
extra_info_block: str = ""
|
|
||||||
time_block: str = ""
|
|
||||||
identity_block: str = ""
|
|
||||||
schedule_block: str = ""
|
|
||||||
moderation_prompt_block: str = ""
|
|
||||||
safety_guidelines_block: str = ""
|
|
||||||
reply_target_block: str = ""
|
|
||||||
mood_prompt: str = ""
|
|
||||||
action_descriptions: str = ""
|
|
||||||
|
|
||||||
# 可用动作信息
|
|
||||||
available_actions: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
# 动态生成的聊天场景提示
|
|
||||||
chat_scene: str = ""
|
|
||||||
|
|
||||||
def validate(self) -> list[str]:
|
|
||||||
"""参数验证"""
|
|
||||||
errors = []
|
|
||||||
if not self.chat_id:
|
|
||||||
errors.append("chat_id不能为空")
|
|
||||||
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
|
||||||
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
|
||||||
if self.max_context_messages <= 0:
|
|
||||||
errors.append("max_context_messages必须大于0")
|
|
||||||
return errors
|
|
||||||
|
|
||||||
|
|
||||||
class PromptContext:
|
class PromptContext:
|
||||||
"""提示词上下文管理器"""
|
"""提示词上下文管理器"""
|
||||||
|
|
||||||
@@ -303,11 +230,23 @@ class Prompt:
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
# 构建上下文数据
|
# 1. 从组件管理器获取注入内容
|
||||||
|
components_prefix = ""
|
||||||
|
if self.name:
|
||||||
|
components_prefix = await prompt_component_manager.execute_components_for(
|
||||||
|
injection_point=self.name, params=self.parameters
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 构建核心上下文数据
|
||||||
context_data = await self._build_context_data()
|
context_data = await self._build_context_data()
|
||||||
|
|
||||||
# 格式化模板
|
# 3. 格式化主模板
|
||||||
result = await self._format_with_context(context_data)
|
main_formatted_prompt = await self._format_with_context(context_data)
|
||||||
|
|
||||||
|
# 4. 拼接组件内容和主模板内容
|
||||||
|
result = main_formatted_prompt
|
||||||
|
if components_prefix:
|
||||||
|
result = f"{components_prefix}\n\n{main_formatted_prompt}"
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||||
|
|||||||
123
src/chat/utils/prompt_component_manager.py
Normal file
123
src/chat/utils/prompt_component_manager.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
import asyncio
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Type
|
||||||
|
|
||||||
|
from src.chat.utils.prompt_params import PromptParameters
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system.base.base_prompt import BasePrompt
|
||||||
|
from src.plugin_system.base.component_types import PromptInfo
|
||||||
|
|
||||||
|
logger = get_logger("prompt_component_manager")
|
||||||
|
|
||||||
|
|
||||||
|
class PromptComponentManager:
|
||||||
|
"""
|
||||||
|
管理所有 `BasePrompt` 组件的单例类。
|
||||||
|
|
||||||
|
该管理器负责:
|
||||||
|
1. 注册由插件定义的 `BasePrompt` 子类。
|
||||||
|
2. 根据注入点(目标Prompt名称)对它们进行分类存储。
|
||||||
|
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._registry: dict[str, list[Type[BasePrompt]]] = defaultdict(list)
|
||||||
|
self._prompt_infos: dict[str, PromptInfo] = {}
|
||||||
|
|
||||||
|
def register(self, component_class: Type[BasePrompt]):
|
||||||
|
"""
|
||||||
|
注册一个 `BasePrompt` 组件类。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_class: 要注册的 `BasePrompt` 子类。
|
||||||
|
"""
|
||||||
|
if not issubclass(component_class, BasePrompt):
|
||||||
|
logger.warning(f"尝试注册一个非 BasePrompt 的类: {component_class.__name__}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt_info = component_class.get_prompt_info()
|
||||||
|
if prompt_info.name in self._prompt_infos:
|
||||||
|
logger.warning(f"重复注册 Prompt 组件: {prompt_info.name}。将覆盖旧组件。")
|
||||||
|
|
||||||
|
injection_points = prompt_info.injection_point
|
||||||
|
if isinstance(injection_points, str):
|
||||||
|
injection_points = [injection_points]
|
||||||
|
|
||||||
|
if not injection_points or not all(injection_points):
|
||||||
|
logger.debug(f"Prompt 组件 '{prompt_info.name}' 未指定有效的 injection_point,将不会被自动注入。")
|
||||||
|
return
|
||||||
|
|
||||||
|
for point in injection_points:
|
||||||
|
self._registry[point].append(component_class)
|
||||||
|
|
||||||
|
self._prompt_infos[prompt_info.name] = prompt_info
|
||||||
|
logger.info(f"成功注册 Prompt 组件 '{prompt_info.name}' 到注入点: {injection_points}")
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"注册 Prompt 组件失败 {component_class.__name__}: {e}")
|
||||||
|
|
||||||
|
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
|
||||||
|
"""
|
||||||
|
获取指定注入点的所有已注册组件类。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
injection_point: 目标Prompt的名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
|
||||||
|
"""
|
||||||
|
return self._registry.get(injection_point, [])
|
||||||
|
|
||||||
|
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
|
||||||
|
"""
|
||||||
|
实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
injection_point: 目标Prompt的名称。
|
||||||
|
params: 用于初始化组件的 PromptParameters 对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 所有相关组件生成的、用换行符连接的文本内容。
|
||||||
|
"""
|
||||||
|
component_classes = self.get_components_for(injection_point)
|
||||||
|
if not component_classes:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
tasks = []
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
|
for component_class in component_classes:
|
||||||
|
try:
|
||||||
|
prompt_info = self._prompt_infos.get(component_class.prompt_name)
|
||||||
|
if not prompt_info:
|
||||||
|
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
|
||||||
|
plugin_config = {}
|
||||||
|
else:
|
||||||
|
plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name)
|
||||||
|
|
||||||
|
instance = component_class(params=params, plugin_config=plugin_config)
|
||||||
|
tasks.append(instance.execute())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}")
|
||||||
|
|
||||||
|
if not tasks:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 并行执行所有组件
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
# 过滤掉执行失败的结果和空字符串
|
||||||
|
valid_results = []
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}")
|
||||||
|
elif result and isinstance(result, str) and result.strip():
|
||||||
|
valid_results.append(result.strip())
|
||||||
|
|
||||||
|
# 使用换行符拼接所有有效结果
|
||||||
|
return "\n".join(valid_results)
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局单例
|
||||||
|
prompt_component_manager = PromptComponentManager()
|
||||||
79
src/chat/utils/prompt_params.py
Normal file
79
src/chat/utils/prompt_params.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""
|
||||||
|
This module contains the PromptParameters class, which is used to define the parameters for a prompt.
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptParameters:
|
||||||
|
"""统一提示词参数系统"""
|
||||||
|
|
||||||
|
# 基础参数
|
||||||
|
chat_id: str = ""
|
||||||
|
is_group_chat: bool = False
|
||||||
|
sender: str = ""
|
||||||
|
target: str = ""
|
||||||
|
reply_to: str = ""
|
||||||
|
extra_info: str = ""
|
||||||
|
prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u"
|
||||||
|
bot_name: str = ""
|
||||||
|
bot_nickname: str = ""
|
||||||
|
|
||||||
|
# 功能开关
|
||||||
|
enable_tool: bool = True
|
||||||
|
enable_memory: bool = True
|
||||||
|
enable_expression: bool = True
|
||||||
|
enable_relation: bool = True
|
||||||
|
enable_cross_context: bool = True
|
||||||
|
enable_knowledge: bool = True
|
||||||
|
|
||||||
|
# 性能控制
|
||||||
|
max_context_messages: int = 50
|
||||||
|
|
||||||
|
# 调试选项
|
||||||
|
debug_mode: bool = False
|
||||||
|
|
||||||
|
# 聊天历史和上下文
|
||||||
|
chat_target_info: dict[str, Any] | None = None
|
||||||
|
message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
message_list_before_short: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
chat_talking_prompt_short: str = ""
|
||||||
|
target_user_info: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
# 已构建的内容块
|
||||||
|
expression_habits_block: str = ""
|
||||||
|
relation_info_block: str = ""
|
||||||
|
memory_block: str = ""
|
||||||
|
tool_info_block: str = ""
|
||||||
|
knowledge_prompt: str = ""
|
||||||
|
cross_context_block: str = ""
|
||||||
|
|
||||||
|
# 其他内容块
|
||||||
|
keywords_reaction_prompt: str = ""
|
||||||
|
extra_info_block: str = ""
|
||||||
|
time_block: str = ""
|
||||||
|
identity_block: str = ""
|
||||||
|
schedule_block: str = ""
|
||||||
|
moderation_prompt_block: str = ""
|
||||||
|
safety_guidelines_block: str = ""
|
||||||
|
reply_target_block: str = ""
|
||||||
|
mood_prompt: str = ""
|
||||||
|
action_descriptions: str = ""
|
||||||
|
|
||||||
|
# 可用动作信息
|
||||||
|
available_actions: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
# 动态生成的聊天场景提示
|
||||||
|
chat_scene: str = ""
|
||||||
|
|
||||||
|
def validate(self) -> list[str]:
|
||||||
|
"""参数验证"""
|
||||||
|
errors = []
|
||||||
|
if not self.chat_id:
|
||||||
|
errors.append("chat_id不能为空")
|
||||||
|
if self.prompt_mode not in ["s4u", "normal", "minimal"]:
|
||||||
|
errors.append("prompt_mode必须是's4u'、'normal'或'minimal'")
|
||||||
|
if self.max_context_messages <= 0:
|
||||||
|
errors.append("max_context_messages必须大于0")
|
||||||
|
return errors
|
||||||
@@ -26,6 +26,7 @@ from .base import (
|
|||||||
ActionInfo,
|
ActionInfo,
|
||||||
BaseAction,
|
BaseAction,
|
||||||
BaseCommand,
|
BaseCommand,
|
||||||
|
BasePrompt,
|
||||||
BaseEventHandler,
|
BaseEventHandler,
|
||||||
BasePlugin,
|
BasePlugin,
|
||||||
BaseTool,
|
BaseTool,
|
||||||
@@ -64,6 +65,7 @@ __all__ = [
|
|||||||
"BaseEventHandler",
|
"BaseEventHandler",
|
||||||
# 基础类
|
# 基础类
|
||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
|
"BasePrompt",
|
||||||
"BaseTool",
|
"BaseTool",
|
||||||
"ChatMode",
|
"ChatMode",
|
||||||
"ChatType",
|
"ChatType",
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from .base_action import BaseAction
|
|||||||
from .base_command import BaseCommand
|
from .base_command import BaseCommand
|
||||||
from .base_events_handler import BaseEventHandler
|
from .base_events_handler import BaseEventHandler
|
||||||
from .base_plugin import BasePlugin
|
from .base_plugin import BasePlugin
|
||||||
|
from .base_prompt import BasePrompt
|
||||||
from .base_tool import BaseTool
|
from .base_tool import BaseTool
|
||||||
from .command_args import CommandArgs
|
from .command_args import CommandArgs
|
||||||
from .component_types import (
|
from .component_types import (
|
||||||
@@ -37,6 +38,7 @@ __all__ = [
|
|||||||
"BaseCommand",
|
"BaseCommand",
|
||||||
"BaseEventHandler",
|
"BaseEventHandler",
|
||||||
"BasePlugin",
|
"BasePlugin",
|
||||||
|
"BasePrompt",
|
||||||
"BaseTool",
|
"BaseTool",
|
||||||
"ChatMode",
|
"ChatMode",
|
||||||
"ChatType",
|
"ChatType",
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from src.plugin_system.base.component_types import (
|
|||||||
EventHandlerInfo,
|
EventHandlerInfo,
|
||||||
InterestCalculatorInfo,
|
InterestCalculatorInfo,
|
||||||
PlusCommandInfo,
|
PlusCommandInfo,
|
||||||
|
PromptInfo,
|
||||||
ToolInfo,
|
ToolInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -15,6 +16,7 @@ from .base_action import BaseAction
|
|||||||
from .base_command import BaseCommand
|
from .base_command import BaseCommand
|
||||||
from .base_events_handler import BaseEventHandler
|
from .base_events_handler import BaseEventHandler
|
||||||
from .base_interest_calculator import BaseInterestCalculator
|
from .base_interest_calculator import BaseInterestCalculator
|
||||||
|
from .base_prompt import BasePrompt
|
||||||
from .base_tool import BaseTool
|
from .base_tool import BaseTool
|
||||||
from .plugin_base import PluginBase
|
from .plugin_base import PluginBase
|
||||||
from .plus_command import PlusCommand
|
from .plus_command import PlusCommand
|
||||||
@@ -80,6 +82,13 @@ class BasePlugin(PluginBase):
|
|||||||
logger.warning("EventHandler的get_info逻辑尚未实现")
|
logger.warning("EventHandler的get_info逻辑尚未实现")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
elif component_type == ComponentType.PROMPT:
|
||||||
|
if hasattr(component_class, "get_prompt_info"):
|
||||||
|
return component_class.get_prompt_info()
|
||||||
|
else:
|
||||||
|
logger.warning(f"Prompt类 {component_class.__name__} 缺少 get_prompt_info 方法")
|
||||||
|
return None
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.error(f"不支持的组件类型: {component_type}")
|
logger.error(f"不支持的组件类型: {component_type}")
|
||||||
return None
|
return None
|
||||||
@@ -109,6 +118,7 @@ class BasePlugin(PluginBase):
|
|||||||
| tuple[EventHandlerInfo, type[BaseEventHandler]]
|
| tuple[EventHandlerInfo, type[BaseEventHandler]]
|
||||||
| tuple[ToolInfo, type[BaseTool]]
|
| tuple[ToolInfo, type[BaseTool]]
|
||||||
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
|
| tuple[InterestCalculatorInfo, type[BaseInterestCalculator]]
|
||||||
|
| tuple[PromptInfo, type[BasePrompt]]
|
||||||
]:
|
]:
|
||||||
"""获取插件包含的组件列表
|
"""获取插件包含的组件列表
|
||||||
|
|
||||||
|
|||||||
95
src/plugin_system/base/base_prompt.py
Normal file
95
src/plugin_system/base/base_prompt.py
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from src.chat.utils.prompt_params import PromptParameters
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||||
|
|
||||||
|
logger = get_logger("base_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
class BasePrompt(ABC):
|
||||||
|
"""Prompt组件基类
|
||||||
|
|
||||||
|
Prompt是插件的一种组件类型,用于动态地向现有的核心Prompt模板中注入额外的上下文信息。
|
||||||
|
它的主要作用是在不修改核心代码的情况下,扩展和定制模型的行为。
|
||||||
|
|
||||||
|
子类可以通过类属性定义其行为:
|
||||||
|
- prompt_name: Prompt组件的唯一名称。
|
||||||
|
- injection_point: 指定要注入的目标Prompt名称(或名称列表)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt_name: str = ""
|
||||||
|
"""Prompt组件的名称"""
|
||||||
|
prompt_description: str = ""
|
||||||
|
"""Prompt组件的描述"""
|
||||||
|
|
||||||
|
# 定义此组件希望注入到哪个或哪些核心Prompt中
|
||||||
|
# 可以是一个字符串(单个目标)或字符串列表(多个目标)
|
||||||
|
# 例如: "planner_prompt" 或 ["s4u_style_prompt", "normal_style_prompt"]
|
||||||
|
injection_point: str | list[str] = ""
|
||||||
|
"""要注入的目标Prompt名称或列表"""
|
||||||
|
|
||||||
|
def __init__(self, params: PromptParameters, plugin_config: dict | None = None):
|
||||||
|
"""初始化Prompt组件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: 统一提示词参数,包含所有构建提示词所需的上下文信息。
|
||||||
|
plugin_config: 插件配置字典。
|
||||||
|
"""
|
||||||
|
self.params = params
|
||||||
|
self.plugin_config = plugin_config or {}
|
||||||
|
self.log_prefix = "[PromptComponent]"
|
||||||
|
|
||||||
|
logger.debug(f"{self.log_prefix} Prompt组件 '{self.prompt_name}' 初始化完成")
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def execute(self) -> str:
|
||||||
|
"""执行Prompt生成的抽象方法,子类必须实现。
|
||||||
|
|
||||||
|
此方法应根据初始化时传入的 `self.params` 来构建并返回一个字符串。
|
||||||
|
返回的字符串将被拼接到目标Prompt的最前面。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 生成的文本内容。
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_config(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""获取插件配置值,支持嵌套键访问。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: 配置键名,使用点号进行嵌套访问,如 "section.subsection.key"。
|
||||||
|
default: 未找到键时返回的默认值。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Any: 配置值或默认值。
|
||||||
|
"""
|
||||||
|
if not self.plugin_config:
|
||||||
|
return default
|
||||||
|
|
||||||
|
keys = key.split(".")
|
||||||
|
current = self.plugin_config
|
||||||
|
for k in keys:
|
||||||
|
if isinstance(current, dict) and k in current:
|
||||||
|
current = current[k]
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
return current
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_prompt_info(cls) -> "PromptInfo":
|
||||||
|
"""从类属性生成PromptInfo,用于组件注册和管理。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PromptInfo: 生成的Prompt信息对象。
|
||||||
|
"""
|
||||||
|
if not cls.prompt_name:
|
||||||
|
raise ValueError("Prompt组件必须定义 'prompt_name' 类属性。")
|
||||||
|
|
||||||
|
return PromptInfo(
|
||||||
|
name=cls.prompt_name,
|
||||||
|
component_type=ComponentType.PROMPT,
|
||||||
|
description=cls.prompt_description,
|
||||||
|
injection_point=cls.injection_point,
|
||||||
|
)
|
||||||
@@ -20,6 +20,7 @@ class ComponentType(Enum):
|
|||||||
EVENT_HANDLER = "event_handler" # 事件处理组件
|
EVENT_HANDLER = "event_handler" # 事件处理组件
|
||||||
CHATTER = "chatter" # 聊天处理器组件
|
CHATTER = "chatter" # 聊天处理器组件
|
||||||
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
|
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
|
||||||
|
PROMPT = "prompt" # Prompt组件
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.value
|
return self.value
|
||||||
@@ -266,6 +267,18 @@ class EventInfo(ComponentInfo):
|
|||||||
self.component_type = ComponentType.EVENT_HANDLER
|
self.component_type = ComponentType.EVENT_HANDLER
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PromptInfo(ComponentInfo):
|
||||||
|
"""Prompt组件信息"""
|
||||||
|
|
||||||
|
injection_point: str | list[str] = ""
|
||||||
|
"""要注入的目标Prompt名称或列表"""
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
self.component_type = ComponentType.PROMPT
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PluginInfo:
|
class PluginInfo:
|
||||||
"""插件信息"""
|
"""插件信息"""
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from src.plugin_system.base.base_chatter import BaseChatter
|
|||||||
from src.plugin_system.base.base_command import BaseCommand
|
from src.plugin_system.base.base_command import BaseCommand
|
||||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||||
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
|
||||||
|
from src.plugin_system.base.base_prompt import BasePrompt
|
||||||
from src.plugin_system.base.base_tool import BaseTool
|
from src.plugin_system.base.base_tool import BaseTool
|
||||||
from src.plugin_system.base.component_types import (
|
from src.plugin_system.base.component_types import (
|
||||||
ActionInfo,
|
ActionInfo,
|
||||||
@@ -22,6 +23,7 @@ from src.plugin_system.base.component_types import (
|
|||||||
InterestCalculatorInfo,
|
InterestCalculatorInfo,
|
||||||
PluginInfo,
|
PluginInfo,
|
||||||
PlusCommandInfo,
|
PlusCommandInfo,
|
||||||
|
PromptInfo,
|
||||||
ToolInfo,
|
ToolInfo,
|
||||||
)
|
)
|
||||||
from src.plugin_system.base.plus_command import PlusCommand
|
from src.plugin_system.base.plus_command import PlusCommand
|
||||||
@@ -37,6 +39,7 @@ ComponentClassType = (
|
|||||||
| type[PlusCommand]
|
| type[PlusCommand]
|
||||||
| type[BaseChatter]
|
| type[BaseChatter]
|
||||||
| type[BaseInterestCalculator]
|
| type[BaseInterestCalculator]
|
||||||
|
| type[BasePrompt]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -183,6 +186,10 @@ class ComponentRegistry:
|
|||||||
assert isinstance(component_info, InterestCalculatorInfo)
|
assert isinstance(component_info, InterestCalculatorInfo)
|
||||||
assert issubclass(component_class, BaseInterestCalculator)
|
assert issubclass(component_class, BaseInterestCalculator)
|
||||||
ret = self._register_interest_calculator_component(component_info, component_class)
|
ret = self._register_interest_calculator_component(component_info, component_class)
|
||||||
|
case ComponentType.PROMPT:
|
||||||
|
assert isinstance(component_info, PromptInfo)
|
||||||
|
assert issubclass(component_class, BasePrompt)
|
||||||
|
ret = self._register_prompt_component(component_info, component_class)
|
||||||
case _:
|
case _:
|
||||||
logger.warning(f"未知组件类型: {component_type}")
|
logger.warning(f"未知组件类型: {component_type}")
|
||||||
ret = False
|
ret = False
|
||||||
@@ -346,6 +353,31 @@ class ComponentRegistry:
|
|||||||
logger.debug(f"已注册InterestCalculator组件: {calculator_name}")
|
logger.debug(f"已注册InterestCalculator组件: {calculator_name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _register_prompt_component(
|
||||||
|
self, prompt_info: PromptInfo, prompt_class: "ComponentClassType"
|
||||||
|
) -> bool:
|
||||||
|
"""注册Prompt组件到Prompt特定注册表"""
|
||||||
|
prompt_name = prompt_info.name
|
||||||
|
if not prompt_name:
|
||||||
|
logger.error(f"Prompt组件 {prompt_class.__name__} 必须指定名称")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not hasattr(self, "_prompt_registry"):
|
||||||
|
self._prompt_registry: dict[str, type[BasePrompt]] = {}
|
||||||
|
if not hasattr(self, "_enabled_prompt_registry"):
|
||||||
|
self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {}
|
||||||
|
|
||||||
|
_assign_plugin_attrs(
|
||||||
|
prompt_class, prompt_info.plugin_name, self.get_plugin_config(prompt_info.plugin_name) or {}
|
||||||
|
)
|
||||||
|
self._prompt_registry[prompt_name] = prompt_class # type: ignore
|
||||||
|
|
||||||
|
if prompt_info.enabled:
|
||||||
|
self._enabled_prompt_registry[prompt_name] = prompt_class # type: ignore
|
||||||
|
|
||||||
|
logger.debug(f"已注册Prompt组件: {prompt_name}")
|
||||||
|
return True
|
||||||
|
|
||||||
# === 组件移除相关 ===
|
# === 组件移除相关 ===
|
||||||
|
|
||||||
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
|
||||||
@@ -580,7 +612,17 @@ class ComponentRegistry:
|
|||||||
component_name: str,
|
component_name: str,
|
||||||
component_type: ComponentType | None = None,
|
component_type: ComponentType | None = None,
|
||||||
) -> (
|
) -> (
|
||||||
type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None
|
type[
|
||||||
|
BaseCommand
|
||||||
|
| BaseAction
|
||||||
|
| BaseEventHandler
|
||||||
|
| BaseTool
|
||||||
|
| PlusCommand
|
||||||
|
| BaseChatter
|
||||||
|
| BaseInterestCalculator
|
||||||
|
| BasePrompt
|
||||||
|
]
|
||||||
|
| None
|
||||||
):
|
):
|
||||||
"""获取组件类,支持自动命名空间解析
|
"""获取组件类,支持自动命名空间解析
|
||||||
|
|
||||||
@@ -829,6 +871,7 @@ class ComponentRegistry:
|
|||||||
events_handlers: int = 0
|
events_handlers: int = 0
|
||||||
plus_command_components: int = 0
|
plus_command_components: int = 0
|
||||||
chatter_components: int = 0
|
chatter_components: int = 0
|
||||||
|
prompt_components: int = 0
|
||||||
for component in self._components.values():
|
for component in self._components.values():
|
||||||
if component.component_type == ComponentType.ACTION:
|
if component.component_type == ComponentType.ACTION:
|
||||||
action_components += 1
|
action_components += 1
|
||||||
@@ -842,6 +885,8 @@ class ComponentRegistry:
|
|||||||
plus_command_components += 1
|
plus_command_components += 1
|
||||||
elif component.component_type == ComponentType.CHATTER:
|
elif component.component_type == ComponentType.CHATTER:
|
||||||
chatter_components += 1
|
chatter_components += 1
|
||||||
|
elif component.component_type == ComponentType.PROMPT:
|
||||||
|
prompt_components += 1
|
||||||
return {
|
return {
|
||||||
"action_components": action_components,
|
"action_components": action_components,
|
||||||
"command_components": command_components,
|
"command_components": command_components,
|
||||||
@@ -849,6 +894,7 @@ class ComponentRegistry:
|
|||||||
"event_handlers": events_handlers,
|
"event_handlers": events_handlers,
|
||||||
"plus_command_components": plus_command_components,
|
"plus_command_components": plus_command_components,
|
||||||
"chatter_components": chatter_components,
|
"chatter_components": chatter_components,
|
||||||
|
"prompt_components": prompt_components,
|
||||||
"total_components": len(self._components),
|
"total_components": len(self._components),
|
||||||
"total_plugins": len(self._plugins),
|
"total_plugins": len(self._plugins),
|
||||||
"components_by_type": {
|
"components_by_type": {
|
||||||
|
|||||||
@@ -358,13 +358,14 @@ class PluginManager:
|
|||||||
event_handler_count = stats.get("event_handlers", 0)
|
event_handler_count = stats.get("event_handlers", 0)
|
||||||
plus_command_count = stats.get("plus_command_components", 0)
|
plus_command_count = stats.get("plus_command_components", 0)
|
||||||
chatter_count = stats.get("chatter_components", 0)
|
chatter_count = stats.get("chatter_components", 0)
|
||||||
|
prompt_count = stats.get("prompt_components", 0)
|
||||||
total_components = stats.get("total_components", 0)
|
total_components = stats.get("total_components", 0)
|
||||||
|
|
||||||
# 📋 显示插件加载总览
|
# 📋 显示插件加载总览
|
||||||
if total_registered > 0:
|
if total_registered > 0:
|
||||||
logger.info("🎉 插件系统加载完成!")
|
logger.info("🎉 插件系统加载完成!")
|
||||||
logger.info(
|
logger.info(
|
||||||
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})"
|
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 显示详细的插件列表
|
# 显示详细的插件列表
|
||||||
@@ -402,6 +403,9 @@ class PluginManager:
|
|||||||
plus_command_components = [
|
plus_command_components = [
|
||||||
c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND
|
c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND
|
||||||
]
|
]
|
||||||
|
prompt_components = [
|
||||||
|
c for c in plugin_info.components if c.component_type == ComponentType.PROMPT
|
||||||
|
]
|
||||||
|
|
||||||
if action_components:
|
if action_components:
|
||||||
action_details = [format_component(c) for c in action_components]
|
action_details = [format_component(c) for c in action_components]
|
||||||
@@ -425,6 +429,9 @@ class PluginManager:
|
|||||||
if event_handler_components:
|
if event_handler_components:
|
||||||
event_handler_details = [format_component(c) for c in event_handler_components]
|
event_handler_details = [format_component(c) for c in event_handler_components]
|
||||||
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}")
|
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}")
|
||||||
|
if prompt_components:
|
||||||
|
prompt_details = [format_component(c) for c in prompt_components]
|
||||||
|
logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}")
|
||||||
|
|
||||||
# 权限节点信息
|
# 权限节点信息
|
||||||
if plugin_instance := self.loaded_plugins.get(plugin_name):
|
if plugin_instance := self.loaded_plugins.get(plugin_name):
|
||||||
|
|||||||
Reference in New Issue
Block a user