feat(config): 新增反截断功能防止模型输出截断
在任务配置中增加 anti_truncation 开关,支持在生成大段文本时防止被意外截断。启用后会在 prompt 中添加指示模型输出完成信令的指令,并在检测到未正常结束时触发自动重试机制。(我相信这次的一定不会炸飞你们了)
This commit is contained in:
@@ -83,6 +83,7 @@ class TaskConfig(ValidatedConfigBase):
|
||||
max_tokens: int = Field(default=800, description="任务最大输出token数")
|
||||
temperature: float = Field(default=0.7, description="模型温度")
|
||||
concurrency_count: int = Field(default=1, description="并发请求数量")
|
||||
anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断")
|
||||
|
||||
@field_validator('model_list')
|
||||
@classmethod
|
||||
|
||||
@@ -132,10 +132,21 @@ class LLMRequest:
|
||||
|
||||
# 内容混淆过滤指令
|
||||
self.noise_instruction = """**【核心过滤规则】**
|
||||
在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。
|
||||
在交互过程中,你会在用户的输入内容中看到一些无意义的、由字母和数字组成的随机字符串。这些字符串是API传输过程中的技术伪装信息,【不是】用户真实意图的一部分。
|
||||
|
||||
你的任务是【完全并彻底地忽略】这些随机字符串。
|
||||
**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。
|
||||
**【严禁】** 将这些字符串写入你的回复、对它们进行任何形式的解读、或将它们作为故事的一部分进行模拟。当看到这些字符串时,请将其视为空白或直接跳过,并根据用户其他有意义的输入(如有)或当前故事情节来继续你的叙事。此规则为最高优先级。
|
||||
"""
|
||||
|
||||
# 反截断指令
|
||||
self.anti_truncation_instruction = """
|
||||
**【输出完成信令】**
|
||||
这是一个非常重要的指令,请务必遵守。在你的回复内容完全结束后,请务必在最后另起一行,只写 `[done]` 作为结束标志。
|
||||
例如:
|
||||
<你的回复内容>
|
||||
[done]
|
||||
|
||||
这有助于我判断你的输出是否被截断。请不要在 `[done]` 前后添加任何其他文字或标点。
|
||||
"""
|
||||
|
||||
async def generate_response_for_image(
|
||||
@@ -276,7 +287,16 @@ class LLMRequest:
|
||||
# 模型选择和请求准备
|
||||
start_time = time.time()
|
||||
model_info, api_provider, client = self._select_model()
|
||||
processed_prompt = self._apply_content_obfuscation(prompt, api_provider)
|
||||
|
||||
# 检查是否启用反截断
|
||||
use_anti_truncation = getattr(self.model_for_task, "anti_truncation", False)
|
||||
|
||||
processed_prompt = prompt
|
||||
if use_anti_truncation:
|
||||
processed_prompt += self.anti_truncation_instruction
|
||||
logger.info(f"任务 '{self.task_name}' 已启用反截断功能")
|
||||
|
||||
processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)
|
||||
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(processed_prompt)
|
||||
@@ -308,12 +328,22 @@ class LLMRequest:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
|
||||
# 检测是否为空回复
|
||||
# 检测是否为空回复或截断
|
||||
is_empty_reply = not content or content.strip() == ""
|
||||
is_truncated = False
|
||||
|
||||
if is_empty_reply and empty_retry_count < max_empty_retry:
|
||||
if use_anti_truncation:
|
||||
if content.endswith("[done]"):
|
||||
content = content[:-6].strip()
|
||||
logger.debug("检测到并已移除 [done] 标记")
|
||||
else:
|
||||
is_truncated = True
|
||||
logger.warning("未检测到 [done] 标记,判定为截断")
|
||||
|
||||
if (is_empty_reply or is_truncated) and empty_retry_count < max_empty_retry:
|
||||
empty_retry_count += 1
|
||||
logger.warning(f"检测到空回复,正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
|
||||
reason = "空回复" if is_empty_reply else "截断"
|
||||
logger.warning(f"检测到{reason},正在进行第 {empty_retry_count}/{max_empty_retry} 次重新生成")
|
||||
|
||||
if empty_retry_interval > 0:
|
||||
await asyncio.sleep(empty_retry_interval)
|
||||
|
||||
@@ -125,6 +125,7 @@ model_list = ["siliconflow-deepseek-v3"] # 使用的模型列表,每个子项
|
||||
temperature = 0.2 # 模型温度,新V3建议0.1-0.3
|
||||
max_tokens = 800 # 最大输出token数
|
||||
#concurrency_count = 2 # 并发请求数量,默认为1(不并发),设置为2或更高启用并发
|
||||
#anti_truncation = true # 启用反截断功能,防止模型输出被截断
|
||||
|
||||
[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
|
||||
model_list = ["qwen3-8b"]
|
||||
|
||||
Reference in New Issue
Block a user