refactor(prompt): optimize concurrent task execution with as_completed for better performance and error handling

Author: Windpicker-owo
Date: 2025-11-12 21:19:19 +08:00
parent 61d86875ad
commit 17abfc74ae


@@ -24,7 +24,6 @@ install(extra_lines=3)
logger = get_logger("unified_prompt")
class PromptContext:
"""提示词上下文管理器.
@@ -41,7 +40,9 @@ class PromptContext:
# _current_context_var: 使用contextvars来存储当前协程的上下文ID。
# 这确保了在并发执行的异步任务中每个任务都能访问到正确的上下文ID。
self._current_context_var = contextvars.ContextVar("current_context", default=None)
self._current_context_var = contextvars.ContextVar(
"current_context", default=None
)
# _context_lock: 一个异步锁用于保护对共享资源_context_prompts的并发访问。
self._context_lock = asyncio.Lock()
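The comment above relies on `contextvars` isolating the current context ID per asyncio task. A minimal, standalone sketch of that behavior (the variable and helper names are illustrative, not part of this module):

```python
import asyncio
import contextvars

# Each concurrently running task sets its own context ID and still reads back
# the right one, because contextvars values are copied per asyncio task.
_current_context: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_context", default=None
)

async def handle(ctx_id: str) -> str | None:
    _current_context.set(ctx_id)   # only visible to this task's context copy
    await asyncio.sleep(0.01)      # yield so the tasks actually interleave
    return _current_context.get()  # still returns this task's own value

async def main() -> None:
    results = await asyncio.gather(*(handle(f"ctx-{i}") for i in range(3)))
    assert results == ["ctx-0", "ctx-1", "ctx-2"]

asyncio.run(main())
```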
@@ -120,7 +121,9 @@ class PromptContext:
return self._context_prompts[current_context][name]
return None
async def register_async(self, prompt: "Prompt", context_id: str | None = None) -> None:
async def register_async(
self, prompt: "Prompt", context_id: str | None = None
) -> None:
"""异步、安全地将提示模板注册到指定的作用域.
如果未指定context_id则注册到当前协程的上下文中。
@@ -130,7 +133,9 @@ class PromptContext:
if target_context := context_id or self._current_context:
if prompt.name:
# 使用setdefault确保目标上下文的字典存在然后注册prompt
self._context_prompts.setdefault(target_context, {})[prompt.name] = prompt
self._context_prompts.setdefault(target_context, {})[
prompt.name
] = prompt
class PromptManager:
@@ -158,7 +163,9 @@ class PromptManager:
async with self._context.async_scope(message_id):
yield self
async def get_prompt_async(self, name: str, parameters: PromptParameters | None = None) -> "Prompt":
async def get_prompt_async(
self, name: str, parameters: PromptParameters | None = None
) -> "Prompt":
"""异步获取提示模板,并动态地将插件内容注入其中.
获取提示词的优先级顺序为:
@@ -309,12 +316,16 @@ class Prompt:
elif not isinstance(template, str):
template = str(template)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace("\\}", Prompt._TEMP_RIGHT_BRACE)
return template.replace("\\{", Prompt._TEMP_LEFT_BRACE).replace(
"\\}", Prompt._TEMP_RIGHT_BRACE
)
@staticmethod
def _restore_escaped_braces(template: str) -> str:
"""在格式化完成后,将临时标记还原为实际的花括号字符 `{` 和 `}`."""
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(Prompt._TEMP_RIGHT_BRACE, "}")
return template.replace(Prompt._TEMP_LEFT_BRACE, "{").replace(
Prompt._TEMP_RIGHT_BRACE, "}"
)
def _parse_template_args(self, template: str) -> list[str]:
"""从模板字符串中解析出所有占位符(例如 "{user_name}" -> "user_name"."""
@@ -365,7 +376,9 @@ class Prompt:
result = main_formatted_prompt
total_time = time.time() - start_time
logger.debug(f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s")
logger.debug(
f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s"
)
# 缓存结果
self._formatted_result = result
@@ -410,9 +423,13 @@ class Prompt:
# 如果参数对象中已经包含了某些block说明它们是外部预构建的
# 我们将它们存起来,并跳过对应的实时构建任务。
if self.parameters.expression_habits_block:
pre_built_params["expression_habits_block"] = self.parameters.expression_habits_block
pre_built_params["expression_habits_block"] = (
self.parameters.expression_habits_block
)
if self.parameters.relation_info_block:
pre_built_params["relation_info_block"] = self.parameters.relation_info_block
pre_built_params["relation_info_block"] = (
self.parameters.relation_info_block
)
if self.parameters.memory_block:
pre_built_params["memory_block"] = self.parameters.memory_block
logger.debug("使用预构建的memory_block跳过实时构建")
@@ -421,12 +438,16 @@ class Prompt:
if self.parameters.knowledge_prompt:
pre_built_params["knowledge_prompt"] = self.parameters.knowledge_prompt
if self.parameters.cross_context_block:
pre_built_params["cross_context_block"] = self.parameters.cross_context_block
pre_built_params["cross_context_block"] = (
self.parameters.cross_context_block
)
if self.parameters.notice_block:
pre_built_params["notice_block"] = self.parameters.notice_block
# --- 步骤 1.2: 根据参数和预构建情况,决定需要实时运行的任务 ---
if self.parameters.enable_expression and not pre_built_params.get("expression_habits_block"):
if self.parameters.enable_expression and not pre_built_params.get(
"expression_habits_block"
):
tasks.append(self._build_expression_habits())
task_names.append("expression_habits")
@@ -434,19 +455,27 @@ class Prompt:
# 使用新的记忆图系统,不再在 prompt.py 中构建记忆
# 如果需要记忆,必须通过 pre_built_params 传入
if self.parameters.enable_relation and not pre_built_params.get("relation_info_block"):
if self.parameters.enable_relation and not pre_built_params.get(
"relation_info_block"
):
tasks.append(self._build_relation_info())
task_names.append("relation_info")
if self.parameters.enable_tool and not pre_built_params.get("tool_info_block"):
if self.parameters.enable_tool and not pre_built_params.get(
"tool_info_block"
):
tasks.append(self._build_tool_info())
task_names.append("tool_info")
if self.parameters.enable_knowledge and not pre_built_params.get("knowledge_prompt"):
if self.parameters.enable_knowledge and not pre_built_params.get(
"knowledge_prompt"
):
tasks.append(self._build_knowledge_info())
task_names.append("knowledge_info")
if self.parameters.enable_cross_context and not pre_built_params.get("cross_context_block"):
if self.parameters.enable_cross_context and not pre_built_params.get(
"cross_context_block"
):
tasks.append(self._build_cross_context())
task_names.append("cross_context")
@@ -463,32 +492,61 @@ class Prompt:
"expression_habits": 10.0,
}
results = []
# 循环等待每个任务,而不是使用`asyncio.gather`
# 这样可以为每个任务应用独立的超时控制。
# 使用 as_completed 并发执行任务,提供更好的性能和错误处理
results = [None] * len(tasks) # 预分配结果列表,保持任务顺序
task_with_meta = []
# 准备任务和元数据
for i, task in enumerate(tasks):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
task_timeout = task_timeouts.get(task_name, 2.0) # 未指定超时的任务默认为2秒
task_timeout = task_timeouts.get(
task_name, 2.0
) # 未指定超时的任务默认为2秒
# 检查任务是否为协程,非协程任务直接使用默认值
if asyncio.iscoroutine(task):
task_with_meta.append(
(
asyncio.wait_for(task, timeout=task_timeout),
task_name,
i,
task_timeout,
)
)
else:
logger.warning(
f"任务{task_name}不是协程对象,类型: {type(task)},跳过处理"
)
results[i] = self._get_default_result_for_task(task_name) # type: ignore
# 并发执行任务,使用 as_completed 获得更好的性能
for future in asyncio.as_completed(
[task_meta[0] for task_meta in task_with_meta]
):
# 找到对应的任务元数据
task_index = None
task_name = None
task_timeout = None
for idx, (task, name, index, timeout) in enumerate(task_with_meta):
if task == future:
task_index = index
task_name = name
task_timeout = timeout
break
try:
# 确保任务是可等待的协程
if asyncio.iscoroutine(task):
# 使用 `asyncio.wait_for` 来执行带超时的任务
result = await asyncio.wait_for(task, timeout=task_timeout)
results.append(result)
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
else:
# 如果任务不是协程,记录警告并使用默认值
logger.warning(f"任务{task_name}不是协程对象,类型: {type(task)},跳过处理")
results.append(self._get_default_result_for_task(task_name))
result = await future
results[task_index] = result # type: ignore
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
except asyncio.TimeoutError:
# 如果任务超时,记录警告并使用默认值
logger.warning(f"构建任务{task_name}超时 ({task_timeout}s),使用默认值")
results.append(self._get_default_result_for_task(task_name))
logger.warning(
f"构建任务{task_name}超时 ({task_timeout}s),使用默认值"
)
results[task_index] = self._get_default_result_for_task(task_name) # type: ignore
except Exception as e:
# 如果任务执行出错,记录错误并使用默认值
logger.error(f"构建任务{task_name}失败: {e!s}")
results.append(self._get_default_result_for_task(task_name))
results[task_index] = self._get_default_result_for_task(task_name) # type: ignore
# --- 步骤 3: 合并所有结果 ---
context_data = {}
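Note that `asyncio.as_completed` generally yields new awaitables rather than the exact objects passed in, so matching a completed future back to its metadata by equality (as the loop above does) may never find the entry. A common alternative, sketched below with hypothetical names (`_run_block`, `fake_builder`), is to wrap each block coroutine so its result already carries its slot index along with its own timeout and failure handling:

```python
import asyncio
from typing import Any

async def _run_block(index: int, name: str, coro: Any, timeout: float) -> tuple[int, Any]:
    try:
        return index, await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        print(f"block {name} timed out after {timeout}s, using default")
        return index, {}  # stand-in for _get_default_result_for_task(name)
    except Exception as exc:
        print(f"block {name} failed: {exc}")
        return index, {}

async def build_all(blocks: list[tuple[str, Any, float]]) -> list[Any]:
    results: list[Any] = [None] * len(blocks)  # pre-allocate to keep slot order
    wrapped = [
        _run_block(i, name, coro, timeout)
        for i, (name, coro, timeout) in enumerate(blocks)
    ]
    for fut in asyncio.as_completed(wrapped):
        index, value = await fut  # consumed in completion order, fastest first
        results[index] = value    # stored back in original order
    return results

# Usage with dummy builders and per-block timeouts similar to task_timeouts.
async def fake_builder(tag: str, delay: float) -> dict[str, str]:
    await asyncio.sleep(delay)
    return {tag: "ok"}

async def main() -> None:
    blocks = [
        ("tool_info", fake_builder("tool_info", 0.1), 2.0),
        ("expression_habits", fake_builder("expression_habits", 0.2), 10.0),
        ("knowledge_info", fake_builder("knowledge_info", 5.0), 0.5),  # times out
    ]
    print(await build_all(blocks))

asyncio.run(main())
```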
@@ -501,7 +559,9 @@ class Prompt:
context_data.update(result)
# 合并预构建的参数,这会覆盖任何同名的实时构建结果
context_data.update({key: value for key, value in pre_built_params.items() if value})
context_data.update(
{key: value for key, value in pre_built_params.items() if value}
)
except asyncio.TimeoutError:
# 这是一个不太可能发生的、总体的构建超时,作为最后的保障
@@ -522,7 +582,8 @@ class Prompt:
{
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt,
"extra_info_block": self.parameters.extra_info_block,
"time_block": self.parameters.time_block or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"time_block": self.parameters.time_block
or f"当前时间:{time.strftime('%Y-%m-%d %H:%M:%S')}",
"identity": self.parameters.identity_block,
"schedule_block": self.parameters.schedule_block,
"moderation_prompt": self.parameters.moderation_prompt_block,
@@ -549,20 +610,25 @@ class Prompt:
target_user_id = self.parameters.target_user_info.get("user_id") or ""
# 调用核心构建逻辑
read_history_prompt, unread_history_prompt = await self._build_s4u_chat_history_prompts(
self.parameters.message_list_before_now_long,
target_user_id,
self.parameters.sender,
self.parameters.chat_id,
read_history_prompt, unread_history_prompt = (
await self._build_s4u_chat_history_prompts(
self.parameters.message_list_before_now_long,
target_user_id,
self.parameters.sender,
self.parameters.chat_id,
)
)
# 将构建好的prompt添加到上下文数据中
context_data["read_history_prompt"] = read_history_prompt
context_data["unread_history_prompt"] = unread_history_prompt
async def _build_s4u_chat_history_prompts(
self, message_list_before_now: list[dict[str, Any]], target_user_id: str, sender: str, chat_id: str
self,
message_list_before_now: list[dict[str, Any]],
target_user_id: str,
sender: str,
chat_id: str,
) -> tuple[str, str]:
"""构建S4U风格的已读/未读历史消息prompt.
@@ -574,7 +640,9 @@ class Prompt:
from src.plugin_system.apis.generator_api import get_replyer
# 获取一个临时的生成器实例来访问其方法
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
temp_generator = await get_replyer(
None, chat_id, request_type="prompt_building"
)
if temp_generator:
# 调用实际的构建方法
return await temp_generator.build_s4u_chat_history_prompts(
@@ -588,7 +656,9 @@ class Prompt:
async def _build_expression_habits(self) -> dict[str, Any]:
"""构建表达习惯(如表情、口癖)的上下文块."""
# 检查当前聊天是否启用了表达习惯功能
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(self.parameters.chat_id)
use_expression, _, _ = global_config.expression.get_expression_config_for_chat(
self.parameters.chat_id
)
if not use_expression:
return {"expression_habits_block": ""}
@@ -601,15 +671,20 @@ class Prompt:
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-10:]
chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
# 使用统一的表达方式选择入口支持classic和exp_model模式
expression_selector = ExpressionSelector(self.parameters.chat_id)
selected_expressions = await expression_selector.select_suitable_expressions(
chat_id=self.parameters.chat_id,
chat_history=chat_history,
target_message=self.parameters.target,
selected_expressions = (
await expression_selector.select_suitable_expressions(
chat_id=self.parameters.chat_id,
chat_history=chat_history,
target_message=self.parameters.target,
)
)
# 将选择的表达习惯格式化为提示词的一部分
@@ -641,7 +716,9 @@ class Prompt:
"""构建与对话目标相关的关系信息."""
try:
# 调用静态方法来执行实际的构建逻辑
relation_info = await Prompt.build_relation_info(self.parameters.chat_id, self.parameters.reply_to)
relation_info = await Prompt.build_relation_info(
self.parameters.chat_id, self.parameters.reply_to
)
return {"relation_info_block": relation_info}
except Exception as e:
logger.error(f"构建关系信息失败: {e}")
@@ -660,7 +737,10 @@ class Prompt:
if self.parameters.message_list_before_now_long:
recent_messages = self.parameters.message_list_before_now_long[-15:]
chat_history = await build_readable_messages(
recent_messages, replace_bot_name=True, timestamp_mode="normal", truncate=True
recent_messages,
replace_bot_name=True,
timestamp_mode="normal",
truncate=True,
)
# 决定是否调用工具并执行
@@ -682,7 +762,9 @@ class Prompt:
tool_info_parts.append(f"- 【{tool_name}{result_type}: {content}")
tool_info_parts.append("以上是你获取到的实时信息,请在回复时参考这些信息。")
tool_info_parts.append(
"以上是你获取到的实时信息,请在回复时参考这些信息。"
)
tool_info_block = "\n".join(tool_info_parts)
else:
tool_info_block = ""
@@ -710,7 +792,10 @@ class Prompt:
# 将检索结果格式化为提示词
if knowledge_results and knowledge_results.get("knowledge_items"):
knowledge_parts = ["## 知识库信息", "以下是与你当前对话相关的知识信息:"]
knowledge_parts = [
"## 知识库信息",
"以下是与你当前对话相关的知识信息:",
]
for item in knowledge_results["knowledge_items"]:
content = item.get("content", "")
@@ -720,20 +805,27 @@ class Prompt:
# 过滤掉相关性低于阈值的知识
try:
relevance_float = float(relevance)
if relevance_float < global_config.lpmm_knowledge.qa_paragraph_threshold:
if (
relevance_float
< global_config.lpmm_knowledge.qa_paragraph_threshold
):
continue
relevance_str = f"{relevance_float:.2f}"
except (ValueError, TypeError):
relevance_str = str(relevance)
if source:
knowledge_parts.append(f"- [{relevance_str}] {content} (来源: {source})")
knowledge_parts.append(
f"- [{relevance_str}] {content} (来源: {source})"
)
else:
knowledge_parts.append(f"- [{relevance_str}] {content}")
# 如果有总结,也一并加入
if knowledge_results.get("summary"):
knowledge_parts.append(f"\n知识总结: {knowledge_results['summary']}")
knowledge_parts.append(
f"\n知识总结: {knowledge_results['summary']}"
)
knowledge_prompt = "\n".join(knowledge_parts)
else:
@@ -750,7 +842,9 @@ class Prompt:
try:
# 调用静态方法来执行实际的构建逻辑
cross_context = await Prompt.build_cross_context(
self.parameters.chat_id, self.parameters.prompt_mode, self.parameters.target_user_info
self.parameters.chat_id,
self.parameters.prompt_mode,
self.parameters.target_user_info,
)
return {"cross_context_block": cross_context}
except Exception as e:
@@ -769,7 +863,11 @@ class Prompt:
params = self._prepare_default_params(context_data)
# 如果prompt有名称则通过全局管理器格式化这样可以应用注入逻辑否则直接格式化
return await global_prompt_manager.format_prompt(self.name, **params) if self.name else self.format(**params)
return (
await global_prompt_manager.format_prompt(self.name, **params)
if self.name
else self.format(**params)
)
def _prepare_s4u_params(self, context_data: dict[str, Any]) -> dict[str, Any]:
"""为S4UScene for You模式准备最终用于格式化的参数字典."""
@@ -780,14 +878,20 @@ class Prompt:
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"extra_info_block": self.parameters.extra_info_block
or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"notice_block": self.parameters.notice_block or context_data.get("notice_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"notice_block": self.parameters.notice_block
or context_data.get("notice_block", ""),
"identity": self.parameters.identity_block
or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions
or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block
or context_data.get("schedule_block", ""),
"sender_name": self.parameters.sender or "未知用户",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"mood_state": self.parameters.mood_prompt
or context_data.get("mood_state", ""),
"read_history_prompt": context_data.get("read_history_prompt", ""),
"unread_history_prompt": context_data.get("unread_history_prompt", ""),
"time_block": context_data.get("time_block", ""),
@@ -795,7 +899,8 @@ class Prompt:
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block
or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
"auth_role_prompt_block": self.parameters.auth_role_prompt_block
@@ -813,22 +918,29 @@ class Prompt:
"knowledge_prompt": context_data.get("knowledge_prompt", ""),
"memory_block": context_data.get("memory_block", ""),
"relation_info_block": context_data.get("relation_info_block", ""),
"extra_info_block": self.parameters.extra_info_block or context_data.get("extra_info_block", ""),
"extra_info_block": self.parameters.extra_info_block
or context_data.get("extra_info_block", ""),
"cross_context_block": context_data.get("cross_context_block", ""),
"notice_block": self.parameters.notice_block or context_data.get("notice_block", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"notice_block": self.parameters.notice_block
or context_data.get("notice_block", ""),
"identity": self.parameters.identity_block
or context_data.get("identity", ""),
"action_descriptions": self.parameters.action_descriptions
or context_data.get("action_descriptions", ""),
"schedule_block": self.parameters.schedule_block
or context_data.get("schedule_block", ""),
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"reply_target_block": context_data.get("reply_target_block", ""),
"reply_style": global_config.personality.reply_style,
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"mood_state": self.parameters.mood_prompt
or context_data.get("mood_state", ""),
"read_history_prompt": context_data.get("read_history_prompt", ""),
"unread_history_prompt": context_data.get("unread_history_prompt", ""),
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block
or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
"auth_role_prompt_block": self.parameters.auth_role_prompt_block
@@ -847,17 +959,21 @@ class Prompt:
"chat_target": "",
"time_block": context_data.get("time_block", ""),
"chat_info": context_data.get("chat_info", ""),
"identity": self.parameters.identity_block or context_data.get("identity", ""),
"schedule_block": self.parameters.schedule_block or context_data.get("schedule_block", ""),
"identity": self.parameters.identity_block
or context_data.get("identity", ""),
"schedule_block": self.parameters.schedule_block
or context_data.get("schedule_block", ""),
"chat_target_2": "",
"reply_target_block": context_data.get("reply_target_block", ""),
"raw_reply": self.parameters.target,
"reason": "",
"mood_state": self.parameters.mood_prompt or context_data.get("mood_state", ""),
"mood_state": self.parameters.mood_prompt
or context_data.get("mood_state", ""),
"reply_style": global_config.personality.reply_style,
"keywords_reaction_prompt": self.parameters.keywords_reaction_prompt
or context_data.get("keywords_reaction_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block or context_data.get("moderation_prompt", ""),
"moderation_prompt": self.parameters.moderation_prompt_block
or context_data.get("moderation_prompt", ""),
"safety_guidelines_block": self.parameters.safety_guidelines_block
or context_data.get("safety_guidelines_block", ""),
"auth_role_prompt_block": self.parameters.auth_role_prompt_block
@@ -902,7 +1018,9 @@ class Prompt:
return result
except (IndexError, KeyError) as e:
# 捕获格式化错误并抛出更具信息量的异常
raise ValueError(f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}") from e
raise ValueError(
f"格式化模板失败: {self.template}, args={args}, kwargs={kwargs} {e!s}"
) from e
def __str__(self) -> str:
"""返回格式化后的结果,如果还未格式化,则返回原始模板."""
@@ -976,8 +1094,12 @@ class Prompt:
return f"你完全不认识{sender}不理解ta的相关信息。"
# 使用关系提取器构建用户关系信息和聊天流印象
user_relation_info = await relationship_fetcher.build_relation_info(person_id, points_num=5)
stream_impression = await relationship_fetcher.build_chat_stream_impression(chat_id)
user_relation_info = await relationship_fetcher.build_relation_info(
person_id, points_num=5
)
stream_impression = await relationship_fetcher.build_chat_stream_impression(
chat_id
)
# 组合两部分信息
info_parts = []
@@ -1016,7 +1138,9 @@ class Prompt:
return {}
@staticmethod
async def build_cross_context(chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None) -> str:
async def build_cross_context(
chat_id: str, prompt_mode: str, target_user_info: dict[str, Any] | None
) -> str:
"""构建跨群聊的上下文信息.
Args:
@@ -1039,7 +1163,9 @@ class Prompt:
# 目前只为s4u模式构建跨群上下文
if prompt_mode == "s4u":
return await cross_context_api.build_cross_context_s4u(chat_stream, target_user_info)
return await cross_context_api.build_cross_context_s4u(
chat_stream, target_user_info
)
return ""
@@ -1073,7 +1199,10 @@ class Prompt:
# 工厂函数
def create_prompt(
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
template: str,
name: str | None = None,
parameters: PromptParameters | None = None,
**kwargs,
) -> Prompt:
"""一个用于快速创建`Prompt`实例的工厂函数.
@@ -1094,7 +1223,10 @@ def create_prompt(
async def create_prompt_async(
template: str, name: str | None = None, parameters: PromptParameters | None = None, **kwargs
template: str,
name: str | None = None,
parameters: PromptParameters | None = None,
**kwargs,
) -> Prompt:
"""异步创建`Prompt`实例,并自动处理插件内容的动态注入.