diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 5d1026d52..fbb4fcab8 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -6,6 +6,8 @@ from src.plugin_system import ( BaseAction, BaseEventHandler, BasePlugin, + BasePrompt, + ToolParamType, BaseTool, ChatType, CommandArgs, @@ -36,7 +38,17 @@ class GetSystemInfoTool(BaseTool): name = "get_system_info" description = "获取当前系统的模拟版本和状态信息。" available_for_llm = True - parameters = [] + parameters = [ + ("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None), + ("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None), + ( + "time_range", + ToolParamType.STRING, + "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", + False, + ["any", "week", "month"], + ), + ] # type: ignore async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: return {"name": self.name, "content": "系统版本: 1.0.1, 状态: 运行正常"} @@ -99,7 +111,6 @@ class LLMJudgeExampleAction(BaseAction): async def go_activate(self, chat_content: str = "", llm_judge_model=None) -> bool: """LLM 判断激活:判断用户是否情绪低落""" return await self._llm_judge_activation( - chat_content=chat_content, judge_prompt=""" 判断用户是否表达了以下情绪或需求: 1. 感到难过、沮丧或失落 @@ -169,6 +180,19 @@ class RandomEmojiAction(BaseAction): return True, "成功发送了一个随机表情" +class WeatherPrompt(BasePrompt): + """一个简单的Prompt组件,用于向Planner注入天气信息。""" + + prompt_name = "weather_info_prompt" + prompt_description = "向Planner注入当前天气信息,以丰富对话上下文。" + injection_point = "planner_prompt" + + async def execute(self) -> str: + # 在实际应用中,这里可以调用天气API + # 为了演示,我们返回一个固定的天气信息 + return "当前天气:晴朗,温度25°C。" + + @register_plugin class HelloWorldPlugin(BasePlugin): """一个包含四大核心组件和高级配置功能的入门示例插件。""" @@ -178,7 +202,6 @@ class HelloWorldPlugin(BasePlugin): dependencies = [] python_dependencies = [] config_file_name = "config.toml" - enable_plugin = False config_schema = { "meta": { @@ -208,4 +231,7 @@ class HelloWorldPlugin(BasePlugin): if self.get_config("components.random_emoji_action_enabled", True): components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction)) + # 注册新的Prompt组件 + components.append((WeatherPrompt.get_prompt_info(), WeatherPrompt)) + return components diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 9b25fd110..ec7e14787 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -23,7 +23,8 @@ from src.chat.utils.chat_message_builder import ( from src.chat.utils.memory_mappings import get_memory_type_chinese_label # 导入新的统一Prompt系统 -from src.chat.utils.prompt import Prompt, PromptParameters, global_prompt_manager +from src.chat.utils.prompt import Prompt, global_prompt_manager +from src.chat.utils.prompt_params import PromptParameters from src.chat.utils.timer_calculator import Timer from src.chat.utils.utils import get_chat_type_and_target_info from src.common.logger import get_logger diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 543da2d2b..6f46a99c3 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -8,13 +8,14 @@ import contextvars import re import time from contextlib import asynccontextmanager -from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Any, Optional from rich.traceback import install from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import build_readable_messages +from src.chat.utils.prompt_component_manager import prompt_component_manager +from src.chat.utils.prompt_params import PromptParameters from src.common.logger import get_logger from src.config.config import global_config from src.person_info.person_info import get_person_info_manager @@ -23,80 +24,6 @@ install(extra_lines=3) logger = get_logger("unified_prompt") -@dataclass -class PromptParameters: - """统一提示词参数系统""" - - # 基础参数 - chat_id: str = "" - is_group_chat: bool = False - sender: str = "" - target: str = "" - reply_to: str = "" - extra_info: str = "" - prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" - bot_name: str = "" - bot_nickname: str = "" - - # 功能开关 - enable_tool: bool = True - enable_memory: bool = True - enable_expression: bool = True - enable_relation: bool = True - enable_cross_context: bool = True - enable_knowledge: bool = True - - # 性能控制 - max_context_messages: int = 50 - - # 调试选项 - debug_mode: bool = False - - # 聊天历史和上下文 - chat_target_info: dict[str, Any] | None = None - message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list) - message_list_before_short: list[dict[str, Any]] = field(default_factory=list) - chat_talking_prompt_short: str = "" - target_user_info: dict[str, Any] | None = None - - # 已构建的内容块 - expression_habits_block: str = "" - relation_info_block: str = "" - memory_block: str = "" - tool_info_block: str = "" - knowledge_prompt: str = "" - cross_context_block: str = "" - - # 其他内容块 - keywords_reaction_prompt: str = "" - extra_info_block: str = "" - time_block: str = "" - identity_block: str = "" - schedule_block: str = "" - moderation_prompt_block: str = "" - safety_guidelines_block: str = "" - reply_target_block: str = "" - mood_prompt: str = "" - action_descriptions: str = "" - - # 可用动作信息 - available_actions: dict[str, Any] | None = None - - # 动态生成的聊天场景提示 - chat_scene: str = "" - - def validate(self) -> list[str]: - """参数验证""" - errors = [] - if not self.chat_id: - errors.append("chat_id不能为空") - if self.prompt_mode not in ["s4u", "normal", "minimal"]: - errors.append("prompt_mode必须是's4u'、'normal'或'minimal'") - if self.max_context_messages <= 0: - errors.append("max_context_messages必须大于0") - return errors - - class PromptContext: """提示词上下文管理器""" @@ -131,7 +58,7 @@ class PromptContext: context_id = None previous_context = self._current_context - token = self._current_context_var.set(context_id) if context_id else None + token = self._current_context_var.set(context_id) if context_id else None # type: ignore else: previous_context = self._current_context token = None @@ -184,16 +111,42 @@ class PromptManager: async with self._context.async_scope(message_id): yield self - async def get_prompt_async(self, name: str) -> "Prompt": - """异步获取提示模板""" + async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt": + """ + 异步获取提示模板,并动态注入插件内容 + """ + original_prompt = None context_prompt = await self._context.get_prompt_async(name) if context_prompt is not None: logger.debug(f"从上下文中获取提示词: {name} {context_prompt}") - return context_prompt - - if name not in self._prompts: + original_prompt = context_prompt + elif name in self._prompts: + original_prompt = self._prompts[name] + else: raise KeyError(f"Prompt '{name}' not found") - return self._prompts[name] + + # 动态注入插件内容 + if original_prompt.name: + # 确保我们有有效的parameters实例 + params_for_injection = parameters or original_prompt.parameters + + components_prefix = await prompt_component_manager.execute_components_for( + injection_point=original_prompt.name, params=params_for_injection + ) + logger.info(components_prefix) + if components_prefix: + logger.info(f"为'{name}'注入插件内容: \n{components_prefix}") + # 创建一个新的临时Prompt实例,不进行注册 + new_template = f"{components_prefix}\n\n{original_prompt.template}" + temp_prompt = Prompt( + template=new_template, + name=original_prompt.name, + parameters=original_prompt.parameters, + should_register=False, # 确保不重新注册 + ) + return temp_prompt + + return original_prompt def generate_name(self, template: str) -> str: """为未命名的prompt生成名称""" @@ -215,7 +168,9 @@ class PromptManager: async def format_prompt(self, name: str, **kwargs) -> str: """格式化提示模板""" - prompt = await self.get_prompt_async(name) + # 提取parameters用于注入 + parameters = kwargs.get("parameters") + prompt = await self.get_prompt_async(name, parameters=parameters) result = prompt.format(**kwargs) return result @@ -303,11 +258,14 @@ class Prompt: start_time = time.time() try: - # 构建上下文数据 + # 1. 构建核心上下文数据 context_data = await self._build_context_data() - # 格式化模板 - result = await self._format_with_context(context_data) + # 2. 格式化主模板 + main_formatted_prompt = await self._format_with_context(context_data) + + # 3. 拼接组件内容和主模板内容 (逻辑已前置到 get_prompt_async) + result = main_formatted_prompt total_time = time.time() - start_time logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s") @@ -467,9 +425,13 @@ class Prompt: if not self.parameters.message_list_before_now_long: return + target_user_id = "" + if self.parameters.target_user_info: + target_user_id = self.parameters.target_user_info.get("user_id") or "" + read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts( self.parameters.message_list_before_now_long, - self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "", + target_user_id, self.parameters.sender, self.parameters.chat_id, ) @@ -495,11 +457,14 @@ class Prompt: # 创建临时生成器实例来使用其方法 temp_generator = await get_replyer(None, chat_id, request_type="prompt_building") - return await temp_generator.build_s4u_chat_history_prompts( - message_list_before_now, target_user_id, sender, chat_id - ) + if temp_generator: + return await temp_generator.build_s4u_chat_history_prompts( + message_list_before_now, target_user_id, sender, chat_id + ) + return "", "" except Exception as e: logger.error(f"构建S4U历史消息prompt失败: {e}") + return "", "" async def _build_expression_habits(self) -> dict[str, Any]: """构建表达习惯""" @@ -586,10 +551,10 @@ class Prompt: running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True) # 处理可能的异常结果 - if isinstance(running_memories, Exception): + if isinstance(running_memories, BaseException): logger.warning(f"长期记忆查询失败: {running_memories}") running_memories = [] - if isinstance(instant_memory, Exception): + if isinstance(instant_memory, BaseException): logger.warning(f"即时记忆查询失败: {instant_memory}") instant_memory = None @@ -1103,8 +1068,24 @@ def create_prompt( async def create_prompt_async( template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs ) -> Prompt: - """异步创建Prompt实例""" - prompt = create_prompt(template, name, parameters, **kwargs) + """异步创建Prompt实例,并动态注入插件内容""" + # 确保有可用的parameters实例 + final_params = parameters or PromptParameters(**kwargs) + + # 动态注入插件内容 + if name: + components_prefix = await prompt_component_manager.execute_components_for( + injection_point=name, params=final_params + ) + if components_prefix: + logger.debug(f"为'{name}'注入插件内容: \n{components_prefix}") + template = f"{components_prefix}\n\n{template}" + + # 使用可能已修改的模板创建实例 + prompt = create_prompt(template, name, final_params) + + # 如果在特定上下文中,则异步注册 if global_prompt_manager._context._current_context: await global_prompt_manager._context.register_async(prompt) + return prompt diff --git a/src/chat/utils/prompt_component_manager.py b/src/chat/utils/prompt_component_manager.py new file mode 100644 index 000000000..58c7a097b --- /dev/null +++ b/src/chat/utils/prompt_component_manager.py @@ -0,0 +1,109 @@ +import asyncio +from typing import Type + +from src.chat.utils.prompt_params import PromptParameters +from src.common.logger import get_logger +from src.plugin_system.base.base_prompt import BasePrompt +from src.plugin_system.base.component_types import ComponentType, PromptInfo +from src.plugin_system.core.component_registry import component_registry + +logger = get_logger("prompt_component_manager") + + +class PromptComponentManager: + """ + 管理所有 `BasePrompt` 组件的单例类。 + + 该管理器负责: + 1. 从 `component_registry` 中查询 `BasePrompt` 子类。 + 2. 根据注入点(目标Prompt名称)对它们进行筛选。 + 3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。 + """ + + def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]: + """ + 获取指定注入点的所有已注册组件类。 + + Args: + injection_point: 目标Prompt的名称。 + + Returns: + list[Type[BasePrompt]]: 与该注入点关联的组件类列表。 + """ + # 从组件注册中心获取所有启用的Prompt组件 + enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT) + + matching_components: list[Type[BasePrompt]] = [] + + for prompt_name, prompt_info in enabled_prompts.items(): + # 确保 prompt_info 是 PromptInfo 类型 + if not isinstance(prompt_info, PromptInfo): + continue + + # 获取注入点信息 + injection_points = prompt_info.injection_point + if isinstance(injection_points, str): + injection_points = [injection_points] + + # 检查当前注入点是否匹配 + if injection_point in injection_points: + # 获取组件类 + component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT) + if component_class and issubclass(component_class, BasePrompt): + matching_components.append(component_class) + + return matching_components + + async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str: + """ + 实例化并执行指定注入点的所有组件,然后将它们的输出拼接成一个字符串。 + + Args: + injection_point: 目标Prompt的名称。 + params: 用于初始化组件的 PromptParameters 对象。 + + Returns: + str: 所有相关组件生成的、用换行符连接的文本内容。 + """ + component_classes = self.get_components_for(injection_point) + if not component_classes: + return "" + + tasks = [] + for component_class in component_classes: + try: + # 从注册中心获取组件信息 + prompt_info = component_registry.get_component_info( + component_class.prompt_name, ComponentType.PROMPT + ) + if not isinstance(prompt_info, PromptInfo): + logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置") + plugin_config = {} + else: + plugin_config = component_registry.get_plugin_config(prompt_info.plugin_name) + + instance = component_class(params=params, plugin_config=plugin_config) + tasks.append(instance.execute()) + except Exception as e: + logger.error(f"实例化 Prompt 组件 '{component_class.prompt_name}' 失败: {e}") + + if not tasks: + return "" + + # 并行执行所有组件 + results = await asyncio.gather(*tasks, return_exceptions=True) + + # 过滤掉执行失败的结果和空字符串 + valid_results = [] + for i, result in enumerate(results): + if isinstance(result, Exception): + logger.error(f"执行 Prompt 组件 '{component_classes[i].prompt_name}' 失败: {result}") + elif result and isinstance(result, str) and result.strip(): + valid_results.append(result.strip()) + + # 使用换行符拼接所有有效结果 + return "\n".join(valid_results) + + +# 创建全局单例 +prompt_component_manager = PromptComponentManager() \ No newline at end of file diff --git a/src/chat/utils/prompt_params.py b/src/chat/utils/prompt_params.py new file mode 100644 index 000000000..2722c3605 --- /dev/null +++ b/src/chat/utils/prompt_params.py @@ -0,0 +1,79 @@ +""" +This module contains the PromptParameters class, which is used to define the parameters for a prompt. +""" +from dataclasses import dataclass, field +from typing import Any, Literal + + +@dataclass +class PromptParameters: + """统一提示词参数系统""" + + # 基础参数 + chat_id: str = "" + is_group_chat: bool = False + sender: str = "" + target: str = "" + reply_to: str = "" + extra_info: str = "" + prompt_mode: Literal["s4u", "normal", "minimal"] = "s4u" + bot_name: str = "" + bot_nickname: str = "" + + # 功能开关 + enable_tool: bool = True + enable_memory: bool = True + enable_expression: bool = True + enable_relation: bool = True + enable_cross_context: bool = True + enable_knowledge: bool = True + + # 性能控制 + max_context_messages: int = 50 + + # 调试选项 + debug_mode: bool = False + + # 聊天历史和上下文 + chat_target_info: dict[str, Any] | None = None + message_list_before_now_long: list[dict[str, Any]] = field(default_factory=list) + message_list_before_short: list[dict[str, Any]] = field(default_factory=list) + chat_talking_prompt_short: str = "" + target_user_info: dict[str, Any] | None = None + + # 已构建的内容块 + expression_habits_block: str = "" + relation_info_block: str = "" + memory_block: str = "" + tool_info_block: str = "" + knowledge_prompt: str = "" + cross_context_block: str = "" + + # 其他内容块 + keywords_reaction_prompt: str = "" + extra_info_block: str = "" + time_block: str = "" + identity_block: str = "" + schedule_block: str = "" + moderation_prompt_block: str = "" + safety_guidelines_block: str = "" + reply_target_block: str = "" + mood_prompt: str = "" + action_descriptions: str = "" + + # 可用动作信息 + available_actions: dict[str, Any] | None = None + + # 动态生成的聊天场景提示 + chat_scene: str = "" + + def validate(self) -> list[str]: + """参数验证""" + errors = [] + if not self.chat_id: + errors.append("chat_id不能为空") + if self.prompt_mode not in ["s4u", "normal", "minimal"]: + errors.append("prompt_mode必须是's4u'、'normal'或'minimal'") + if self.max_context_messages <= 0: + errors.append("max_context_messages必须大于0") + return errors \ No newline at end of file diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index 6f9d14714..3f19de3a1 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -26,6 +26,7 @@ from .base import ( ActionInfo, BaseAction, BaseCommand, + BasePrompt, BaseEventHandler, BasePlugin, BaseTool, @@ -64,6 +65,7 @@ __all__ = [ "BaseEventHandler", # 基础类 "BasePlugin", + "BasePrompt", "BaseTool", "ChatMode", "ChatType", diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index 0e69b1206..1b62d2a78 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -8,6 +8,7 @@ from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .base_plugin import BasePlugin +from .base_prompt import BasePrompt from .base_tool import BaseTool from .command_args import CommandArgs from .component_types import ( @@ -37,6 +38,7 @@ __all__ = [ "BaseCommand", "BaseEventHandler", "BasePlugin", + "BasePrompt", "BaseTool", "ChatMode", "ChatType", diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index df48a3164..662af3a5e 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -8,6 +8,7 @@ from src.plugin_system.base.component_types import ( EventHandlerInfo, InterestCalculatorInfo, PlusCommandInfo, + PromptInfo, ToolInfo, ) @@ -15,6 +16,7 @@ from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .base_interest_calculator import BaseInterestCalculator +from .base_prompt import BasePrompt from .base_tool import BaseTool from .plugin_base import PluginBase from .plus_command import PlusCommand @@ -80,6 +82,13 @@ class BasePlugin(PluginBase): logger.warning("EventHandler的get_info逻辑尚未实现") return None + elif component_type == ComponentType.PROMPT: + if hasattr(component_class, "get_prompt_info"): + return component_class.get_prompt_info() + else: + logger.warning(f"Prompt类 {component_class.__name__} 缺少 get_prompt_info 方法") + return None + else: logger.error(f"不支持的组件类型: {component_type}") return None @@ -109,6 +118,7 @@ class BasePlugin(PluginBase): | tuple[EventHandlerInfo, type[BaseEventHandler]] | tuple[ToolInfo, type[BaseTool]] | tuple[InterestCalculatorInfo, type[BaseInterestCalculator]] + | tuple[PromptInfo, type[BasePrompt]] ]: """获取插件包含的组件列表 diff --git a/src/plugin_system/base/base_prompt.py b/src/plugin_system/base/base_prompt.py new file mode 100644 index 000000000..8947ea2f5 --- /dev/null +++ b/src/plugin_system/base/base_prompt.py @@ -0,0 +1,95 @@ +from abc import ABC, abstractmethod +from typing import Any + +from src.chat.utils.prompt_params import PromptParameters +from src.common.logger import get_logger +from src.plugin_system.base.component_types import ComponentType, PromptInfo + +logger = get_logger("base_prompt") + + +class BasePrompt(ABC): + """Prompt组件基类 + + Prompt是插件的一种组件类型,用于动态地向现有的核心Prompt模板中注入额外的上下文信息。 + 它的主要作用是在不修改核心代码的情况下,扩展和定制模型的行为。 + + 子类可以通过类属性定义其行为: + - prompt_name: Prompt组件的唯一名称。 + - injection_point: 指定要注入的目标Prompt名称(或名称列表)。 + """ + + prompt_name: str = "" + """Prompt组件的名称""" + prompt_description: str = "" + """Prompt组件的描述""" + + # 定义此组件希望注入到哪个或哪些核心Prompt中 + # 可以是一个字符串(单个目标)或字符串列表(多个目标) + # 例如: "planner_prompt" 或 ["s4u_style_prompt", "normal_style_prompt"] + injection_point: str | list[str] = "" + """要注入的目标Prompt名称或列表""" + + def __init__(self, params: PromptParameters, plugin_config: dict | None = None): + """初始化Prompt组件 + + Args: + params: 统一提示词参数,包含所有构建提示词所需的上下文信息。 + plugin_config: 插件配置字典。 + """ + self.params = params + self.plugin_config = plugin_config or {} + self.log_prefix = "[PromptComponent]" + + logger.debug(f"{self.log_prefix} Prompt组件 '{self.prompt_name}' 初始化完成") + + @abstractmethod + async def execute(self) -> str: + """执行Prompt生成的抽象方法,子类必须实现。 + + 此方法应根据初始化时传入的 `self.params` 来构建并返回一个字符串。 + 返回的字符串将被拼接到目标Prompt的最前面。 + + Returns: + str: 生成的文本内容。 + """ + pass + + def get_config(self, key: str, default: Any = None) -> Any: + """获取插件配置值,支持嵌套键访问。 + + Args: + key: 配置键名,使用点号进行嵌套访问,如 "section.subsection.key"。 + default: 未找到键时返回的默认值。 + + Returns: + Any: 配置值或默认值。 + """ + if not self.plugin_config: + return default + + keys = key.split(".") + current = self.plugin_config + for k in keys: + if isinstance(current, dict) and k in current: + current = current[k] + else: + return default + return current + + @classmethod + def get_prompt_info(cls) -> "PromptInfo": + """从类属性生成PromptInfo,用于组件注册和管理。 + + Returns: + PromptInfo: 生成的Prompt信息对象。 + """ + if not cls.prompt_name: + raise ValueError("Prompt组件必须定义 'prompt_name' 类属性。") + + return PromptInfo( + name=cls.prompt_name, + component_type=ComponentType.PROMPT, + description=cls.prompt_description, + injection_point=cls.injection_point, + ) \ No newline at end of file diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index 53952319e..5db5fdeb3 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -20,6 +20,7 @@ class ComponentType(Enum): EVENT_HANDLER = "event_handler" # 事件处理组件 CHATTER = "chatter" # 聊天处理器组件 INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件 + PROMPT = "prompt" # Prompt组件 def __str__(self) -> str: return self.value @@ -266,6 +267,18 @@ class EventInfo(ComponentInfo): self.component_type = ComponentType.EVENT_HANDLER +@dataclass +class PromptInfo(ComponentInfo): + """Prompt组件信息""" + + injection_point: str | list[str] = "" + """要注入的目标Prompt名称或列表""" + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.PROMPT + + @dataclass class PluginInfo: """插件信息""" diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 51da61743..310f95997 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -11,6 +11,7 @@ from src.plugin_system.base.base_chatter import BaseChatter from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_events_handler import BaseEventHandler from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator +from src.plugin_system.base.base_prompt import BasePrompt from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ( ActionInfo, @@ -22,6 +23,7 @@ from src.plugin_system.base.component_types import ( InterestCalculatorInfo, PluginInfo, PlusCommandInfo, + PromptInfo, ToolInfo, ) from src.plugin_system.base.plus_command import PlusCommand @@ -37,6 +39,7 @@ ComponentClassType = ( | type[PlusCommand] | type[BaseChatter] | type[BaseInterestCalculator] + | type[BasePrompt] ) @@ -183,6 +186,10 @@ class ComponentRegistry: assert isinstance(component_info, InterestCalculatorInfo) assert issubclass(component_class, BaseInterestCalculator) ret = self._register_interest_calculator_component(component_info, component_class) + case ComponentType.PROMPT: + assert isinstance(component_info, PromptInfo) + assert issubclass(component_class, BasePrompt) + ret = self._register_prompt_component(component_info, component_class) case _: logger.warning(f"未知组件类型: {component_type}") ret = False @@ -346,6 +353,31 @@ class ComponentRegistry: logger.debug(f"已注册InterestCalculator组件: {calculator_name}") return True + def _register_prompt_component( + self, prompt_info: PromptInfo, prompt_class: "ComponentClassType" + ) -> bool: + """注册Prompt组件到Prompt特定注册表""" + prompt_name = prompt_info.name + if not prompt_name: + logger.error(f"Prompt组件 {prompt_class.__name__} 必须指定名称") + return False + + if not hasattr(self, "_prompt_registry"): + self._prompt_registry: dict[str, type[BasePrompt]] = {} + if not hasattr(self, "_enabled_prompt_registry"): + self._enabled_prompt_registry: dict[str, type[BasePrompt]] = {} + + _assign_plugin_attrs( + prompt_class, prompt_info.plugin_name, self.get_plugin_config(prompt_info.plugin_name) or {} + ) + self._prompt_registry[prompt_name] = prompt_class # type: ignore + + if prompt_info.enabled: + self._enabled_prompt_registry[prompt_name] = prompt_class # type: ignore + + logger.debug(f"已注册Prompt组件: {prompt_name}") + return True + # === 组件移除相关 === async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool: @@ -580,7 +612,17 @@ class ComponentRegistry: component_name: str, component_type: ComponentType | None = None, ) -> ( - type[BaseCommand | BaseAction | BaseEventHandler | BaseTool | PlusCommand | BaseChatter | BaseInterestCalculator] | None + type[ + BaseCommand + | BaseAction + | BaseEventHandler + | BaseTool + | PlusCommand + | BaseChatter + | BaseInterestCalculator + | BasePrompt + ] + | None ): """获取组件类,支持自动命名空间解析 @@ -829,6 +871,7 @@ class ComponentRegistry: events_handlers: int = 0 plus_command_components: int = 0 chatter_components: int = 0 + prompt_components: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 @@ -842,6 +885,8 @@ class ComponentRegistry: plus_command_components += 1 elif component.component_type == ComponentType.CHATTER: chatter_components += 1 + elif component.component_type == ComponentType.PROMPT: + prompt_components += 1 return { "action_components": action_components, "command_components": command_components, @@ -849,6 +894,7 @@ class ComponentRegistry: "event_handlers": events_handlers, "plus_command_components": plus_command_components, "chatter_components": chatter_components, + "prompt_components": prompt_components, "total_components": len(self._components), "total_plugins": len(self._plugins), "components_by_type": { diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index 6cd89e5f4..3a59efeda 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -358,13 +358,14 @@ class PluginManager: event_handler_count = stats.get("event_handlers", 0) plus_command_count = stats.get("plus_command_components", 0) chatter_count = stats.get("chatter_components", 0) + prompt_count = stats.get("prompt_components", 0) total_components = stats.get("total_components", 0) # 📋 显示插件加载总览 if total_registered > 0: logger.info("🎉 插件系统加载完成!") logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count})" + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})" ) # 显示详细的插件列表 @@ -402,6 +403,9 @@ class PluginManager: plus_command_components = [ c for c in plugin_info.components if c.component_type == ComponentType.PLUS_COMMAND ] + prompt_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.PROMPT + ] if action_components: action_details = [format_component(c) for c in action_components] @@ -425,6 +429,9 @@ class PluginManager: if event_handler_components: event_handler_details = [format_component(c) for c in event_handler_components] logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_details)}") + if prompt_components: + prompt_details = [format_component(c) for c in prompt_components] + logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}") # 权限节点信息 if plugin_instance := self.loaded_plugins.get(plugin_name): diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py index 275efddf3..6b97c056b 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_filter.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_filter.py @@ -155,87 +155,22 @@ class ChatterPlanFilter: identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:" schedule_block = "" - # 优先检查是否被吵醒 - - angry_prompt_addition = "" - try: - from src.plugins.built_in.sleep_system.api import get_wakeup_manager - wakeup_mgr = get_wakeup_manager() - except ImportError: - logger.debug("无法导入睡眠系统API,将跳过相关检查。") - wakeup_mgr = None - - if wakeup_mgr: - - # 双重检查确保愤怒状态不会丢失 - # 检查1: 直接从 wakeup_manager 获取 - if wakeup_mgr.is_in_angry_state(): - angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition() - - # 检查2: 如果上面没获取到,再从 mood_manager 确认 - if not angry_prompt_addition: - chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id) - if chat_mood_for_check.is_angry_from_wakeup: - angry_prompt_addition = global_config.sleep_system.angry_prompt - - if angry_prompt_addition: - schedule_block = angry_prompt_addition - elif global_config.planning_system.schedule_enable: + if global_config.planning_system.schedule_enable: if activity_info := schedule_manager.get_current_activity(): activity = activity_info.get("activity", "未知活动") schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。" mood_block = "" - # 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块 - if not angry_prompt_addition and global_config.mood.enable_mood: + # 需要情绪模块打开才能获得情绪,否则会引发报错 + if global_config.mood.enable_mood: chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id) mood_block = f"你现在的心情是:{chat_mood.mood_state}" - if plan.mode == ChatMode.PROACTIVE: - long_term_memory_block = await self._get_long_term_memory_context() - - chat_content_block, message_id_list = await build_readable_messages_with_id( - messages=[msg.flatten() for msg in plan.chat_history], - timestamp_mode="normal", - truncate=False, - show_actions=False, - ) - - prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt") - actions_before_now = await get_actions_by_timestamp_with_chat( - chat_id=plan.chat_id, - timestamp_start=time.time() - 3600, - timestamp_end=time.time(), - limit=5, - ) - actions_before_now_block = build_readable_actions(actions=actions_before_now) - actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}" - - prompt = prompt_template.format( - time_block=time_block, - identity_block=identity_block, - schedule_block=schedule_block, - mood_block=mood_block, - long_term_memory_block=long_term_memory_block, - chat_content_block=chat_content_block or "最近没有聊天内容。", - actions_before_now_block=actions_before_now_block, - ) - return prompt, message_id_list - # 构建已读/未读历史消息 read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks( plan ) - # 为了兼容性,保留原有的chat_content_block - chat_content_block, _ = await build_readable_messages_with_id( - messages=[msg.flatten() for msg in plan.chat_history], - timestamp_mode="normal", - read_mark=self.last_obs_time_mark, - truncate=True, - show_actions=True, - ) - actions_before_now = await get_actions_by_timestamp_with_chat( chat_id=plan.chat_id, timestamp_start=time.time() - 3600, @@ -285,7 +220,7 @@ class ChatterPlanFilter: is_group_chat = plan.chat_type == ChatType.GROUP chat_context_description = "你现在正在一个群聊中" if not is_group_chat and plan.target_info: - chat_target_name = plan.target_info.get("person_name") or plan.target_info.get("user_nickname") or "对方" + chat_target_name = plan.target_info.person_name or plan.target_info.user_nickname or "对方" chat_context_description = f"你正在和 {chat_target_name} 私聊" action_options_block = await self._build_action_options(plan.available_actions) diff --git a/src/plugins/built_in/affinity_flow_chatter/plan_generator.py b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py index 193477273..498471ff7 100644 --- a/src/plugins/built_in/affinity_flow_chatter/plan_generator.py +++ b/src/plugins/built_in/affinity_flow_chatter/plan_generator.py @@ -9,7 +9,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.info_data_model import Plan, TargetPersonInfo from src.config.config import global_config -from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType +from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType, ComponentType from src.plugin_system.core.component_registry import component_registry @@ -55,6 +55,11 @@ class ChatterPlanGenerator: try: # 获取聊天类型和目标信息 chat_type, target_info = await get_chat_type_and_target_info(self.chat_id) + if chat_type: + chat_type = ChatType.GROUP + else: + #遇到未知类型也当私聊处理 + chat_type = ChatType.PRIVATE # 获取可用动作列表 available_actions = await self._get_available_actions(chat_type, mode) @@ -62,12 +67,16 @@ class ChatterPlanGenerator: # 获取聊天历史记录 recent_messages = await self._get_recent_messages() + # 构建计划对象 + # 使用 target_info 字典创建 TargetPersonInfo 实例 + target_person_info = TargetPersonInfo(**target_info) if target_info else TargetPersonInfo() + # 构建计划对象 plan = Plan( chat_id=self.chat_id, chat_type=chat_type, mode=mode, - target_info=target_info, + target_info=target_person_info, available_actions=available_actions, chat_history=recent_messages, ) @@ -77,6 +86,7 @@ class ChatterPlanGenerator: except Exception: # 如果生成失败,返回一个基本的空计划 return Plan( + chat_type = ChatType.PRIVATE,#空计划默认当成私聊 chat_id=self.chat_id, mode=mode, target_info=TargetPersonInfo(), @@ -124,7 +134,7 @@ class ChatterPlanGenerator: try: # 获取最近的消息记录 raw_messages = await get_raw_msg_before_timestamp_with_chat( - chat_id=self.chat_id, timestamp=time.time(), limit=global_config.memory.short_memory_length + chat_id=self.chat_id, timestamp=time.time(), limit=global_config.chat.max_context_size ) # 转换为 DatabaseMessages 对象 diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py index 703141d70..9be3235e9 100644 --- a/src/plugins/built_in/affinity_flow_chatter/planner.py +++ b/src/plugins/built_in/affinity_flow_chatter/planner.py @@ -70,6 +70,7 @@ class ChatterActionPlanner: "replies_generated": 0, "other_actions_executed": 0, } + self._background_tasks: set[asyncio.Task] = set() async def plan(self, context: "StreamContext | None" = None) -> tuple[list[dict[str, Any]], Any | None]: """ @@ -157,7 +158,9 @@ class ChatterActionPlanner: ) if interest_updates: - asyncio.create_task(self._commit_interest_updates(interest_updates)) + task = asyncio.create_task(self._commit_interest_updates(interest_updates)) + self._background_tasks.add(task) + task.add_done_callback(self._handle_task_result) # 检查兴趣度是否达到非回复动作阈值 non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold @@ -266,6 +269,17 @@ class ChatterActionPlanner: return final_actions_dict, final_target_message_dict + def _handle_task_result(self, task: asyncio.Task) -> None: + """处理后台任务的结果,记录异常。""" + try: + task.result() + except asyncio.CancelledError: + pass # 任务被取消是正常现象 + except Exception as e: + logger.error(f"后台任务执行失败: {e}", exc_info=True) + finally: + self._background_tasks.discard(task) + def get_planner_stats(self) -> dict[str, Any]: """获取规划器统计""" return self.planner_stats.copy()