refactor(schedule): 重构月度计划生成器以增强稳定性和可维护性
将计划生成逻辑重构为使用统一的 `LLMRequest` 类，以简化模型调用流程。此更改消除了原有的手动选择模型、构建客户端和处理响应的复杂逻辑。引入了 Pydantic 模型 `PlanResponse`，对 LLM 的 JSON 输出进行严格的验证和解析，并集成了 `json_repair` 库来自动修复格式错误的 JSON。这些措施显著提高了计划生成功能的健壮性和对 LLM 异常输出的容错能力。
This commit is contained in:
committed by
Windpicker-owo
parent
f53993e34a
commit
fcb7a85e69
@@ -1,16 +1,22 @@
|
|||||||
# mmc/src/schedule/plan_generator.py
|
# mmc/src/schedule/plan_generator.py
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import random
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.model_client.base_client import client_registry
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.llm_models.payload_content.message import Message, RoleType
|
|
||||||
from src.llm_models.payload_content.resp_format import RespFormat, RespFormatType
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger("plan_generator")
|
logger = get_logger("plan_generator")
|
||||||
|
|
||||||
|
class PlanResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
用于验证月度计划LLM响应的Pydantic模型。
|
||||||
|
"""
|
||||||
|
plans: List[str]
|
||||||
|
|
||||||
class PlanGenerator:
|
class PlanGenerator:
|
||||||
"""
|
"""
|
||||||
负责生成月度计划。
|
负责生成月度计划。
|
||||||
@@ -18,6 +24,8 @@ class PlanGenerator:
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.bot_personality = self._get_bot_personality()
|
self.bot_personality = self._get_bot_personality()
|
||||||
|
task_config = model_config.model_task_config.get_task("monthly_plan_generator")
|
||||||
|
self.llm_request = LLMRequest(model_set=task_config, request_type="monthly_plan_generator")
|
||||||
|
|
||||||
def _get_bot_personality(self) -> str:
|
def _get_bot_personality(self) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -61,56 +69,46 @@ class PlanGenerator:
|
|||||||
:return: 生成的计划文本列表
|
:return: 生成的计划文本列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 1. 获取模型任务配置
|
# 1. 构建Prompt
|
||||||
task_config = model_config.model_task_config.get_task("monthly_plan_generator")
|
|
||||||
|
|
||||||
# 2. 随机选择一个模型
|
|
||||||
model_name = random.choice(task_config.model_list)
|
|
||||||
model_info = model_config.get_model_info(model_name)
|
|
||||||
api_provider = model_config.get_provider(model_info.api_provider)
|
|
||||||
|
|
||||||
# 3. 获取客户端实例
|
|
||||||
llm_client = client_registry.get_client_class_instance(api_provider)
|
|
||||||
|
|
||||||
# 4. 构建Prompt和消息体
|
|
||||||
prompt = self._build_prompt(year, month, count)
|
prompt = self._build_prompt(year, month, count)
|
||||||
message_list = [Message(role=RoleType.User, content=prompt)]
|
logger.info(f"正在为 {year}-{month} 生成 {count} 个月度计划...")
|
||||||
|
|
||||||
logger.info(f"正在使用模型 '{model_name}' 为 {year}-{month} 生成 {count} 个月度计划...")
|
# 2. 调用LLM
|
||||||
|
llm_content, (reasoning, model_name, _) = await self.llm_request.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
# 5. 调用LLM
|
logger.info(f"使用模型 '{model_name}' 生成完成。")
|
||||||
response = await llm_client.get_response(
|
if reasoning:
|
||||||
model_info=model_info,
|
logger.debug(f"模型推理过程: {reasoning}")
|
||||||
message_list=message_list,
|
|
||||||
temperature=task_config.temperature,
|
|
||||||
max_tokens=task_config.max_tokens,
|
|
||||||
response_format=RespFormat(format_type=RespFormatType.JSON_OBJ) # 请求JSON输出
|
|
||||||
)
|
|
||||||
|
|
||||||
if not response or not response.content:
|
if not llm_content:
|
||||||
logger.error("LLM未能返回有效的计划内容。")
|
logger.error("LLM未能返回有效的计划内容。")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 6. 解析LLM返回的JSON
|
# 3. 解析并验证LLM返回的JSON
|
||||||
try:
|
try:
|
||||||
# 移除可能的Markdown代码块标记
|
# 移除可能的Markdown代码块标记
|
||||||
clean_content = response.content.strip()
|
clean_content = llm_content.strip()
|
||||||
if clean_content.startswith("```json"):
|
if clean_content.startswith("```json"):
|
||||||
clean_content = clean_content[7:]
|
clean_content = clean_content[7:]
|
||||||
if clean_content.endswith("```"):
|
if clean_content.endswith("```"):
|
||||||
clean_content = clean_content[:-3]
|
clean_content = clean_content[:-3]
|
||||||
|
|
||||||
data = json.loads(clean_content.strip())
|
# 修复并解析JSON
|
||||||
plans = data.get("plans", [])
|
repaired_json_str = repair_json(clean_content)
|
||||||
|
data = json.loads(repaired_json_str)
|
||||||
|
|
||||||
if isinstance(plans, list) and all(isinstance(p, str) for p in plans):
|
# 使用Pydantic进行验证
|
||||||
logger.info(f"成功生成并解析了 {len(plans)} 个月度计划。")
|
validated_response = PlanResponse.model_validate(data)
|
||||||
|
plans = validated_response.plans
|
||||||
|
|
||||||
|
logger.info(f"成功生成并验证了 {len(plans)} 个月度计划。")
|
||||||
return plans
|
return plans
|
||||||
else:
|
|
||||||
logger.error(f"LLM返回的JSON格式不正确或'plans'键不是字符串列表: {response.content}")
|
|
||||||
return []
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"无法解析LLM返回的JSON: {response.content}")
|
logger.error(f"修复后仍然无法解析LLM返回的JSON: {llm_content}")
|
||||||
|
return []
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.error(f"LLM返回的JSON格式不符合预期: {e}\n原始响应: {llm_content}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
Reference in New Issue
Block a user