feat(attention): 添加注意力优化器以增强提示词多样性和防止注意力退化
refactor(prompt): 使用 asyncio.gather 替代 as_completed 以提升并发性能 refactor(config): 添加注意力优化配置选项 refactor(prompt_params): 增加注意力优化开关
This commit is contained in:
@@ -375,6 +375,15 @@ class Prompt:
|
||||
# 这样做可以更早地组合模板,也使得`Prompt`类的职责更单一。
|
||||
result = main_formatted_prompt
|
||||
|
||||
# 步骤 4: 注意力优化(如果启用)
|
||||
# 通过轻量级随机化避免提示词过度相似导致LLM注意力退化
|
||||
if self.parameters.enable_attention_optimization:
|
||||
from src.chat.utils.attention_optimizer import get_attention_optimizer
|
||||
|
||||
optimizer = get_attention_optimizer()
|
||||
result = optimizer.optimize_prompt(result, context_data)
|
||||
logger.debug("已应用注意力优化")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.debug(
|
||||
f"Prompt构建完成,模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s"
|
||||
@@ -492,11 +501,12 @@ class Prompt:
|
||||
"expression_habits": 10.0,
|
||||
}
|
||||
|
||||
# 使用 as_completed 并发执行任务,提供更好的性能和错误处理
|
||||
# 使用 asyncio.gather 实现并发执行,提供更好的错误处理和性能
|
||||
results = [None] * len(tasks) # 预分配结果列表,保持任务顺序
|
||||
task_with_meta = []
|
||||
tasks_to_run = [] # 存储带超时的任务
|
||||
task_info = [] # 存储任务信息,用于结果处理
|
||||
|
||||
# 准备任务和元数据
|
||||
# 准备任务并创建带超时的协程
|
||||
for i, task in enumerate(tasks):
|
||||
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
|
||||
task_timeout = task_timeouts.get(
|
||||
@@ -505,48 +515,41 @@ class Prompt:
|
||||
|
||||
# 检查任务是否为协程,非协程任务直接使用默认值
|
||||
if asyncio.iscoroutine(task):
|
||||
task_with_meta.append(
|
||||
(
|
||||
asyncio.wait_for(task, timeout=task_timeout),
|
||||
task_name,
|
||||
i,
|
||||
task_timeout,
|
||||
)
|
||||
)
|
||||
# 创建带超时的任务
|
||||
timeout_task = asyncio.wait_for(task, timeout=task_timeout)
|
||||
tasks_to_run.append(timeout_task)
|
||||
task_info.append({"index": i, "name": task_name, "timeout": task_timeout})
|
||||
else:
|
||||
logger.warning(
|
||||
f"任务{task_name}不是协程对象,类型: {type(task)},跳过处理"
|
||||
)
|
||||
results[i] = self._get_default_result_for_task(task_name) # type: ignore
|
||||
|
||||
# 并发执行任务,使用 as_completed 获得更好的性能
|
||||
for future in asyncio.as_completed(
|
||||
[task_meta[0] for task_meta in task_with_meta]
|
||||
):
|
||||
# 找到对应的任务元数据
|
||||
task_index = None
|
||||
task_name = None
|
||||
task_timeout = None
|
||||
# 使用 gather 并发执行所有任务,return_exceptions=True 确保单个任务失败不影响其他任务
|
||||
if tasks_to_run:
|
||||
task_results = await asyncio.gather(*tasks_to_run, return_exceptions=True)
|
||||
|
||||
for idx, (task, name, index, timeout) in enumerate(task_with_meta):
|
||||
if task == future:
|
||||
task_index = index
|
||||
task_name = name
|
||||
task_timeout = timeout
|
||||
break
|
||||
# 处理任务结果
|
||||
for i, result in enumerate(task_results):
|
||||
info = task_info[i]
|
||||
task_index = info["index"]
|
||||
task_name = info["name"]
|
||||
task_timeout = info["timeout"]
|
||||
|
||||
try:
|
||||
result = await future
|
||||
results[task_index] = result # type: ignore
|
||||
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"构建任务{task_name}超时 ({task_timeout}s),使用默认值"
|
||||
)
|
||||
results[task_index] = self._get_default_result_for_task(task_name) # type: ignore
|
||||
except Exception as e:
|
||||
logger.error(f"构建任务{task_name}失败: {e!s}")
|
||||
results[task_index] = self._get_default_result_for_task(task_name) # type: ignore
|
||||
if isinstance(result, asyncio.TimeoutError):
|
||||
# 处理超时错误
|
||||
logger.warning(
|
||||
f"构建任务{task_name}超时 ({task_timeout}s),使用默认值"
|
||||
)
|
||||
results[task_index] = self._get_default_result_for_task(task_name)
|
||||
elif isinstance(result, Exception):
|
||||
# 处理其他异常
|
||||
logger.error(f"构建任务{task_name}失败: {result!s}")
|
||||
results[task_index] = self._get_default_result_for_task(task_name)
|
||||
else:
|
||||
# 成功完成
|
||||
results[task_index] = result
|
||||
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
|
||||
|
||||
# --- 步骤 3: 合并所有结果 ---
|
||||
context_data = {}
|
||||
|
||||
Reference in New Issue
Block a user