feat(attention): 添加注意力优化器以增强提示词多样性和防止注意力退化

refactor(prompt): 使用 asyncio.gather 替代 as_completed 以提升并发性能
refactor(config): 添加注意力优化配置选项
refactor(prompt_params): 增加注意力优化开关
This commit is contained in:
Windpicker-owo
2025-11-12 22:37:35 +08:00
parent c1cda89d65
commit 310256e24d
8 changed files with 420 additions and 48 deletions

View File

@@ -375,6 +375,15 @@ class Prompt:
# 这样做可以更早地组合模板,也使得`Prompt`类的职责更单一。
result = main_formatted_prompt
# 步骤 4: 注意力优化(如果启用)
# 通过轻量级随机化避免提示词过度相似导致LLM注意力退化
if self.parameters.enable_attention_optimization:
from src.chat.utils.attention_optimizer import get_attention_optimizer
optimizer = get_attention_optimizer()
result = optimizer.optimize_prompt(result, context_data)
logger.debug("已应用注意力优化")
total_time = time.time() - start_time
logger.debug(
f"Prompt构建完成模式: {self.parameters.prompt_mode}, 耗时: {total_time:.2f}s"
@@ -492,11 +501,12 @@ class Prompt:
"expression_habits": 10.0,
}
# 使用 as_completed 并发执行任务,提供更好的性能和错误处理
# 使用 asyncio.gather 实现并发执行,提供更好的错误处理和性能
results = [None] * len(tasks) # 预分配结果列表,保持任务顺序
task_with_meta = []
tasks_to_run = [] # 存储带超时的任务
task_info = [] # 存储任务信息,用于结果处理
# 准备任务和元数据
# 准备任务并创建带超时的协程
for i, task in enumerate(tasks):
task_name = task_names[i] if i < len(task_names) else f"task_{i}"
task_timeout = task_timeouts.get(
@@ -505,48 +515,41 @@ class Prompt:
# 检查任务是否为协程,非协程任务直接使用默认值
if asyncio.iscoroutine(task):
task_with_meta.append(
(
asyncio.wait_for(task, timeout=task_timeout),
task_name,
i,
task_timeout,
)
)
# 创建带超时的任务
timeout_task = asyncio.wait_for(task, timeout=task_timeout)
tasks_to_run.append(timeout_task)
task_info.append({"index": i, "name": task_name, "timeout": task_timeout})
else:
logger.warning(
f"任务{task_name}不是协程对象,类型: {type(task)},跳过处理"
)
results[i] = self._get_default_result_for_task(task_name) # type: ignore
# 并发执行任务,使用 as_completed 获得更好的性能
for future in asyncio.as_completed(
[task_meta[0] for task_meta in task_with_meta]
):
# 找到对应的任务元数据
task_index = None
task_name = None
task_timeout = None
# 使用 gather 并发执行所有任务,return_exceptions=True 确保单个任务失败不影响其他任务
if tasks_to_run:
task_results = await asyncio.gather(*tasks_to_run, return_exceptions=True)
for idx, (task, name, index, timeout) in enumerate(task_with_meta):
if task == future:
task_index = index
task_name = name
task_timeout = timeout
break
# 处理任务结果
for i, result in enumerate(task_results):
info = task_info[i]
task_index = info["index"]
task_name = info["name"]
task_timeout = info["timeout"]
try:
result = await future
results[task_index] = result # type: ignore
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
except asyncio.TimeoutError:
logger.warning(
f"构建任务{task_name}超时 ({task_timeout}s),使用默认值"
)
results[task_index] = self._get_default_result_for_task(task_name) # type: ignore
except Exception as e:
logger.error(f"构建任务{task_name}失败: {e!s}")
results[task_index] = self._get_default_result_for_task(task_name) # type: ignore
if isinstance(result, asyncio.TimeoutError):
# 处理超时错误
logger.warning(
f"构建任务{task_name}超时 ({task_timeout}s),使用默认值"
)
results[task_index] = self._get_default_result_for_task(task_name)
elif isinstance(result, Exception):
# 处理其他异常
logger.error(f"构建任务{task_name}失败: {result!s}")
results[task_index] = self._get_default_result_for_task(task_name)
else:
# 成功完成
results[task_index] = result
logger.debug(f"构建任务{task_name}完成 ({task_timeout}s)")
# --- 步骤 3: 合并所有结果 ---
context_data = {}