refactor(prompt): 优化Prompt组件注入逻辑并简化代码
将Prompt组件的注入逻辑从`Prompt.format`方法前置到`PromptManager.get_prompt_async`和`create_prompt_async`中。这使得注入时机更早,逻辑更清晰,并允许在获取Prompt时就能动态传入参数以影响注入内容。 主要变更: - `PromptManager`: `get_prompt_async`现在负责处理组件注入,并接收可选的`parameters`参数。`format_prompt`相应地传递参数。 - `create_prompt_async`: 现在也支持在创建时进行动态注入。 - `Prompt.format`: 移除了原有的组件注入逻辑,简化了方法实现。 - `PromptComponentManager`: 重构为直接从全局`component_registry`获取组件,移除了自身的注册和存储逻辑,减少了状态管理的复杂性。 - `plan_filter.py`: 删除了大量冗余和重复的代码块,包括主动聊天模式的独立逻辑和旧的历史消息构建方式。
This commit is contained in:
@@ -58,7 +58,7 @@ class PromptContext:
|
|||||||
context_id = None
|
context_id = None
|
||||||
|
|
||||||
previous_context = self._current_context
|
previous_context = self._current_context
|
||||||
token = self._current_context_var.set(context_id) if context_id else None
|
token = self._current_context_var.set(context_id) if context_id else None # type: ignore
|
||||||
else:
|
else:
|
||||||
previous_context = self._current_context
|
previous_context = self._current_context
|
||||||
token = None
|
token = None
|
||||||
@@ -111,16 +111,42 @@ class PromptManager:
|
|||||||
async with self._context.async_scope(message_id):
|
async with self._context.async_scope(message_id):
|
||||||
yield self
|
yield self
|
||||||
|
|
||||||
async def get_prompt_async(self, name: str) -> "Prompt":
|
async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt":
|
||||||
"""异步获取提示模板"""
|
"""
|
||||||
|
异步获取提示模板,并动态注入插件内容
|
||||||
|
"""
|
||||||
|
original_prompt = None
|
||||||
context_prompt = await self._context.get_prompt_async(name)
|
context_prompt = await self._context.get_prompt_async(name)
|
||||||
if context_prompt is not None:
|
if context_prompt is not None:
|
||||||
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
logger.debug(f"从上下文中获取提示词: {name} {context_prompt}")
|
||||||
return context_prompt
|
original_prompt = context_prompt
|
||||||
|
elif name in self._prompts:
|
||||||
if name not in self._prompts:
|
original_prompt = self._prompts[name]
|
||||||
|
else:
|
||||||
raise KeyError(f"Prompt '{name}' not found")
|
raise KeyError(f"Prompt '{name}' not found")
|
||||||
return self._prompts[name]
|
|
||||||
|
# 动态注入插件内容
|
||||||
|
if original_prompt.name:
|
||||||
|
# 确保我们有有效的parameters实例
|
||||||
|
params_for_injection = parameters or original_prompt.parameters
|
||||||
|
|
||||||
|
components_prefix = await prompt_component_manager.execute_components_for(
|
||||||
|
injection_point=original_prompt.name, params=params_for_injection
|
||||||
|
)
|
||||||
|
logger.info(components_prefix)
|
||||||
|
if components_prefix:
|
||||||
|
logger.info(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||||
|
# 创建一个新的临时Prompt实例,不进行注册
|
||||||
|
new_template = f"{components_prefix}\n\n{original_prompt.template}"
|
||||||
|
temp_prompt = Prompt(
|
||||||
|
template=new_template,
|
||||||
|
name=original_prompt.name,
|
||||||
|
parameters=original_prompt.parameters,
|
||||||
|
should_register=False, # 确保不重新注册
|
||||||
|
)
|
||||||
|
return temp_prompt
|
||||||
|
|
||||||
|
return original_prompt
|
||||||
|
|
||||||
def generate_name(self, template: str) -> str:
|
def generate_name(self, template: str) -> str:
|
||||||
"""为未命名的prompt生成名称"""
|
"""为未命名的prompt生成名称"""
|
||||||
@@ -142,7 +168,9 @@ class PromptManager:
|
|||||||
|
|
||||||
async def format_prompt(self, name: str, **kwargs) -> str:
|
async def format_prompt(self, name: str, **kwargs) -> str:
|
||||||
"""格式化提示模板"""
|
"""格式化提示模板"""
|
||||||
prompt = await self.get_prompt_async(name)
|
# 提取parameters用于注入
|
||||||
|
parameters = kwargs.get("parameters")
|
||||||
|
prompt = await self.get_prompt_async(name, parameters=parameters)
|
||||||
result = prompt.format(**kwargs)
|
result = prompt.format(**kwargs)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -230,23 +258,14 @@ class Prompt:
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
try:
|
try:
|
||||||
# 1. 从组件管理器获取注入内容
|
# 1. 构建核心上下文数据
|
||||||
components_prefix = ""
|
|
||||||
if self.name:
|
|
||||||
components_prefix = await prompt_component_manager.execute_components_for(
|
|
||||||
injection_point=self.name, params=self.parameters
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 构建核心上下文数据
|
|
||||||
context_data = await self._build_context_data()
|
context_data = await self._build_context_data()
|
||||||
|
|
||||||
# 3. 格式化主模板
|
# 2. 格式化主模板
|
||||||
main_formatted_prompt = await self._format_with_context(context_data)
|
main_formatted_prompt = await self._format_with_context(context_data)
|
||||||
|
|
||||||
# 4. 拼接组件内容和主模板内容
|
# 3. 拼接组件内容和主模板内容 (逻辑已前置到 get_prompt_async)
|
||||||
result = main_formatted_prompt
|
result = main_formatted_prompt
|
||||||
if components_prefix:
|
|
||||||
result = f"{components_prefix}\n\n{main_formatted_prompt}"
|
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
logger.debug(f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
|
||||||
@@ -406,9 +425,13 @@ class Prompt:
|
|||||||
if not self.parameters.message_list_before_now_long:
|
if not self.parameters.message_list_before_now_long:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
target_user_id = ""
|
||||||
|
if self.parameters.target_user_info:
|
||||||
|
target_user_id = self.parameters.target_user_info.get("user_id") or ""
|
||||||
|
|
||||||
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
|
||||||
self.parameters.message_list_before_now_long,
|
self.parameters.message_list_before_now_long,
|
||||||
self.parameters.target_user_info.get("user_id") if self.parameters.target_user_info else "",
|
target_user_id,
|
||||||
self.parameters.sender,
|
self.parameters.sender,
|
||||||
self.parameters.chat_id,
|
self.parameters.chat_id,
|
||||||
)
|
)
|
||||||
@@ -434,11 +457,14 @@ class Prompt:
|
|||||||
|
|
||||||
# 创建临时生成器实例来使用其方法
|
# 创建临时生成器实例来使用其方法
|
||||||
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
|
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
|
||||||
return await temp_generator.build_s4u_chat_history_prompts(
|
if temp_generator:
|
||||||
message_list_before_now, target_user_id, sender, chat_id
|
return await temp_generator.build_s4u_chat_history_prompts(
|
||||||
)
|
message_list_before_now, target_user_id, sender, chat_id
|
||||||
|
)
|
||||||
|
return "", ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
logger.error(f"构建S4U历史消息prompt失败: {e}")
|
||||||
|
return "", ""
|
||||||
|
|
||||||
async def _build_expression_habits(self) -> dict[str, Any]:
|
async def _build_expression_habits(self) -> dict[str, Any]:
|
||||||
"""构建表达习惯"""
|
"""构建表达习惯"""
|
||||||
@@ -525,10 +551,10 @@ class Prompt:
|
|||||||
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
running_memories, instant_memory = await asyncio.gather(*memory_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 处理可能的异常结果
|
# 处理可能的异常结果
|
||||||
if isinstance(running_memories, Exception):
|
if isinstance(running_memories, BaseException):
|
||||||
logger.warning(f"长期记忆查询失败: {running_memories}")
|
logger.warning(f"长期记忆查询失败: {running_memories}")
|
||||||
running_memories = []
|
running_memories = []
|
||||||
if isinstance(instant_memory, Exception):
|
if isinstance(instant_memory, BaseException):
|
||||||
logger.warning(f"即时记忆查询失败: {instant_memory}")
|
logger.warning(f"即时记忆查询失败: {instant_memory}")
|
||||||
instant_memory = None
|
instant_memory = None
|
||||||
|
|
||||||
@@ -1042,8 +1068,24 @@ def create_prompt(
|
|||||||
async def create_prompt_async(
|
async def create_prompt_async(
|
||||||
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
|
||||||
) -> Prompt:
|
) -> Prompt:
|
||||||
"""异步创建Prompt实例"""
|
"""异步创建Prompt实例,并动态注入插件内容"""
|
||||||
prompt = create_prompt(template, name, parameters, **kwargs)
|
# 确保有可用的parameters实例
|
||||||
|
final_params = parameters or PromptParameters(**kwargs)
|
||||||
|
|
||||||
|
# 动态注入插件内容
|
||||||
|
if name:
|
||||||
|
components_prefix = await prompt_component_manager.execute_components_for(
|
||||||
|
injection_point=name, params=final_params
|
||||||
|
)
|
||||||
|
if components_prefix:
|
||||||
|
logger.debug(f"为'{name}'注入插件内容: \n{components_prefix}")
|
||||||
|
template = f"{components_prefix}\n\n{template}"
|
||||||
|
|
||||||
|
# 使用可能已修改的模板创建实例
|
||||||
|
prompt = create_prompt(template, name, final_params)
|
||||||
|
|
||||||
|
# 如果在特定上下文中,则异步注册
|
||||||
if global_prompt_manager._context._current_context:
|
if global_prompt_manager._context._current_context:
|
||||||
await global_prompt_manager._context.register_async(prompt)
|
await global_prompt_manager._context.register_async(prompt)
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import defaultdict
|
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
from src.chat.utils.prompt_params import PromptParameters
|
from src.chat.utils.prompt_params import PromptParameters
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.base.base_prompt import BasePrompt
|
from src.plugin_system.base.base_prompt import BasePrompt
|
||||||
from src.plugin_system.base.component_types import PromptInfo
|
from src.plugin_system.base.component_types import ComponentType, PromptInfo
|
||||||
|
from src.plugin_system.core.component_registry import component_registry
|
||||||
|
|
||||||
logger = get_logger("prompt_component_manager")
|
logger = get_logger("prompt_component_manager")
|
||||||
|
|
||||||
@@ -15,48 +15,11 @@ class PromptComponentManager:
|
|||||||
管理所有 `BasePrompt` 组件的单例类。
|
管理所有 `BasePrompt` 组件的单例类。
|
||||||
|
|
||||||
该管理器负责:
|
该管理器负责:
|
||||||
1. 注册由插件定义的 `BasePrompt` 子类。
|
1. 从 `component_registry` 中查询 `BasePrompt` 子类。
|
||||||
2. 根据注入点(目标Prompt名称)对它们进行分类存储。
|
2. 根据注入点(目标Prompt名称)对它们进行筛选。
|
||||||
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._registry: dict[str, list[Type[BasePrompt]]] = defaultdict(list)
|
|
||||||
self._prompt_infos: dict[str, PromptInfo] = {}
|
|
||||||
|
|
||||||
def register(self, component_class: Type[BasePrompt]):
|
|
||||||
"""
|
|
||||||
注册一个 `BasePrompt` 组件类。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_class: 要注册的 `BasePrompt` 子类。
|
|
||||||
"""
|
|
||||||
if not issubclass(component_class, BasePrompt):
|
|
||||||
logger.warning(f"尝试注册一个非 BasePrompt 的类: {component_class.__name__}")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
prompt_info = component_class.get_prompt_info()
|
|
||||||
if prompt_info.name in self._prompt_infos:
|
|
||||||
logger.warning(f"重复注册 Prompt 组件: {prompt_info.name}。将覆盖旧组件。")
|
|
||||||
|
|
||||||
injection_points = prompt_info.injection_point
|
|
||||||
if isinstance(injection_points, str):
|
|
||||||
injection_points = [injection_points]
|
|
||||||
|
|
||||||
if not injection_points or not all(injection_points):
|
|
||||||
logger.debug(f"Prompt 组件 '{prompt_info.name}' 未指定有效的 injection_point,将不会被自动注入。")
|
|
||||||
return
|
|
||||||
|
|
||||||
for point in injection_points:
|
|
||||||
self._registry[point].append(component_class)
|
|
||||||
|
|
||||||
self._prompt_infos[prompt_info.name] = prompt_info
|
|
||||||
logger.info(f"成功注册 Prompt 组件 '{prompt_info.name}' 到注入点: {injection_points}")
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"注册 Prompt 组件失败 {component_class.__name__}: {e}")
|
|
||||||
|
|
||||||
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
|
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
|
||||||
"""
|
"""
|
||||||
获取指定注入点的所有已注册组件类。
|
获取指定注入点的所有已注册组件类。
|
||||||
@@ -67,7 +30,29 @@ class PromptComponentManager:
|
|||||||
Returns:
|
Returns:
|
||||||
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
|
list[Type[BasePrompt]]: 与该注入点关联的组件类列表。
|
||||||
"""
|
"""
|
||||||
return self._registry.get(injection_point, [])
|
# 从组件注册中心获取所有启用的Prompt组件
|
||||||
|
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
|
||||||
|
|
||||||
|
matching_components: list[Type[BasePrompt]] = []
|
||||||
|
|
||||||
|
for prompt_name, prompt_info in enabled_prompts.items():
|
||||||
|
# 确保 prompt_info 是 PromptInfo 类型
|
||||||
|
if not isinstance(prompt_info, PromptInfo):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 获取注入点信息
|
||||||
|
injection_points = prompt_info.injection_point
|
||||||
|
if isinstance(injection_points, str):
|
||||||
|
injection_points = [injection_points]
|
||||||
|
|
||||||
|
# 检查当前注入点是否匹配
|
||||||
|
if injection_point in injection_points:
|
||||||
|
# 获取组件类
|
||||||
|
component_class = component_registry.get_component_class(prompt_name, ComponentType.PROMPT)
|
||||||
|
if component_class and issubclass(component_class, BasePrompt):
|
||||||
|
matching_components.append(component_class)
|
||||||
|
|
||||||
|
return matching_components
|
||||||
|
|
||||||
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
|
async def execute_components_for(self, injection_point: str, params: PromptParameters) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -85,12 +70,13 @@ class PromptComponentManager:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
tasks = []
|
tasks = []
|
||||||
from src.plugin_system.core.component_registry import component_registry
|
|
||||||
|
|
||||||
for component_class in component_classes:
|
for component_class in component_classes:
|
||||||
try:
|
try:
|
||||||
prompt_info = self._prompt_infos.get(component_class.prompt_name)
|
# 从注册中心获取组件信息
|
||||||
if not prompt_info:
|
prompt_info = component_registry.get_component_info(
|
||||||
|
component_class.prompt_name, ComponentType.PROMPT
|
||||||
|
)
|
||||||
|
if not isinstance(prompt_info, PromptInfo):
|
||||||
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
|
logger.warning(f"找不到 Prompt 组件 '{component_class.prompt_name}' 的信息,无法获取插件配置")
|
||||||
plugin_config = {}
|
plugin_config = {}
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -155,87 +155,22 @@ class ChatterPlanFilter:
|
|||||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
identity_block = f"你的名字是{bot_name}{bot_nickname},你{bot_core_personality}:"
|
||||||
|
|
||||||
schedule_block = ""
|
schedule_block = ""
|
||||||
# 优先检查是否被吵醒
|
if global_config.planning_system.schedule_enable:
|
||||||
|
|
||||||
angry_prompt_addition = ""
|
|
||||||
try:
|
|
||||||
from src.plugins.built_in.sleep_system.api import get_wakeup_manager
|
|
||||||
wakeup_mgr = get_wakeup_manager()
|
|
||||||
except ImportError:
|
|
||||||
logger.debug("无法导入睡眠系统API,将跳过相关检查。")
|
|
||||||
wakeup_mgr = None
|
|
||||||
|
|
||||||
if wakeup_mgr:
|
|
||||||
|
|
||||||
# 双重检查确保愤怒状态不会丢失
|
|
||||||
# 检查1: 直接从 wakeup_manager 获取
|
|
||||||
if wakeup_mgr.is_in_angry_state():
|
|
||||||
angry_prompt_addition = wakeup_mgr.get_angry_prompt_addition()
|
|
||||||
|
|
||||||
# 检查2: 如果上面没获取到,再从 mood_manager 确认
|
|
||||||
if not angry_prompt_addition:
|
|
||||||
chat_mood_for_check = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
|
||||||
if chat_mood_for_check.is_angry_from_wakeup:
|
|
||||||
angry_prompt_addition = global_config.sleep_system.angry_prompt
|
|
||||||
|
|
||||||
if angry_prompt_addition:
|
|
||||||
schedule_block = angry_prompt_addition
|
|
||||||
elif global_config.planning_system.schedule_enable:
|
|
||||||
if activity_info := schedule_manager.get_current_activity():
|
if activity_info := schedule_manager.get_current_activity():
|
||||||
activity = activity_info.get("activity", "未知活动")
|
activity = activity_info.get("activity", "未知活动")
|
||||||
schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。"
|
schedule_block = f"你当前正在:{activity},但注意它与群聊的聊天无关。"
|
||||||
|
|
||||||
mood_block = ""
|
mood_block = ""
|
||||||
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
# 如果被吵醒,则心情也是愤怒的,不需要另外的情绪模块
|
||||||
if not angry_prompt_addition and global_config.mood.enable_mood:
|
if global_config.mood.enable_mood:
|
||||||
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
chat_mood = mood_manager.get_mood_by_chat_id(plan.chat_id)
|
||||||
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
mood_block = f"你现在的心情是:{chat_mood.mood_state}"
|
||||||
|
|
||||||
if plan.mode == ChatMode.PROACTIVE:
|
|
||||||
long_term_memory_block = await self._get_long_term_memory_context()
|
|
||||||
|
|
||||||
chat_content_block, message_id_list = await build_readable_messages_with_id(
|
|
||||||
messages=[msg.flatten() for msg in plan.chat_history],
|
|
||||||
timestamp_mode="normal",
|
|
||||||
truncate=False,
|
|
||||||
show_actions=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_template = await global_prompt_manager.get_prompt_async("proactive_planner_prompt")
|
|
||||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
|
||||||
chat_id=plan.chat_id,
|
|
||||||
timestamp_start=time.time() - 3600,
|
|
||||||
timestamp_end=time.time(),
|
|
||||||
limit=5,
|
|
||||||
)
|
|
||||||
actions_before_now_block = build_readable_actions(actions=actions_before_now)
|
|
||||||
actions_before_now_block = f"你刚刚选择并执行过的action是:\n{actions_before_now_block}"
|
|
||||||
|
|
||||||
prompt = prompt_template.format(
|
|
||||||
time_block=time_block,
|
|
||||||
identity_block=identity_block,
|
|
||||||
schedule_block=schedule_block,
|
|
||||||
mood_block=mood_block,
|
|
||||||
long_term_memory_block=long_term_memory_block,
|
|
||||||
chat_content_block=chat_content_block or "最近没有聊天内容。",
|
|
||||||
actions_before_now_block=actions_before_now_block,
|
|
||||||
)
|
|
||||||
return prompt, message_id_list
|
|
||||||
|
|
||||||
# 构建已读/未读历史消息
|
# 构建已读/未读历史消息
|
||||||
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
read_history_block, unread_history_block, message_id_list = await self._build_read_unread_history_blocks(
|
||||||
plan
|
plan
|
||||||
)
|
)
|
||||||
|
|
||||||
# 为了兼容性,保留原有的chat_content_block
|
|
||||||
chat_content_block, _ = await build_readable_messages_with_id(
|
|
||||||
messages=[msg.flatten() for msg in plan.chat_history],
|
|
||||||
timestamp_mode="normal",
|
|
||||||
read_mark=self.last_obs_time_mark,
|
|
||||||
truncate=True,
|
|
||||||
show_actions=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
actions_before_now = await get_actions_by_timestamp_with_chat(
|
actions_before_now = await get_actions_by_timestamp_with_chat(
|
||||||
chat_id=plan.chat_id,
|
chat_id=plan.chat_id,
|
||||||
timestamp_start=time.time() - 3600,
|
timestamp_start=time.time() - 3600,
|
||||||
|
|||||||
Reference in New Issue
Block a user