三次修改
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# mmc/src/schedule/database.py
|
||||
|
||||
from typing import List
|
||||
from sqlalchemy import select, func, update, delete
|
||||
from src.common.database.sqlalchemy_models import MonthlyPlan, get_db_session
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -8,21 +9,22 @@ from src.config.config import global_config
|
||||
logger = get_logger("schedule_database")
|
||||
|
||||
|
||||
def add_new_plans(plans: List[str], month: str):
|
||||
async def add_new_plans(plans: List[str], month: str):
|
||||
"""
|
||||
批量添加新生成的月度计划到数据库,并确保不超过上限。
|
||||
|
||||
:param plans: 计划内容列表。
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 1. 获取当前有效计划数量(状态为 'active')
|
||||
current_plan_count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.count()
|
||||
result = await session.execute(
|
||||
select(func.count(MonthlyPlan.id)).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
|
||||
)
|
||||
)
|
||||
current_plan_count = result.scalar_one()
|
||||
|
||||
# 2. 从配置获取上限
|
||||
max_plans = global_config.planning_system.max_plans_per_month
|
||||
@@ -41,7 +43,7 @@ def add_new_plans(plans: List[str], month: str):
|
||||
MonthlyPlan(plan_text=plan, target_month=month, status="active") for plan in plans_to_add
|
||||
]
|
||||
session.add_all(new_plan_objects)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"成功向数据库添加了 {len(new_plan_objects)} 条 {month} 的月度计划。")
|
||||
if len(plans) > len(plans_to_add):
|
||||
@@ -49,32 +51,31 @@ def add_new_plans(plans: List[str], month: str):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"添加月度计划时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
async def get_active_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
获取指定月份所有状态为 'active' 的计划。
|
||||
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
:return: MonthlyPlan 对象列表。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
plans = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
result = await session.execute(
|
||||
select(MonthlyPlan)
|
||||
.where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.order_by(MonthlyPlan.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
return plans
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def mark_plans_completed(plan_ids: List[int]):
|
||||
async def mark_plans_completed(plan_ids: List[int]):
|
||||
"""
|
||||
将指定ID的计划标记为已完成。
|
||||
|
||||
@@ -83,9 +84,10 @@ def mark_plans_completed(plan_ids: List[int]):
|
||||
if not plan_ids:
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
plans_to_mark = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all()
|
||||
result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
|
||||
plans_to_mark = result.scalars().all()
|
||||
if not plans_to_mark:
|
||||
logger.info("没有需要标记为完成的月度计划。")
|
||||
return
|
||||
@@ -93,17 +95,17 @@ def mark_plans_completed(plan_ids: List[int]):
|
||||
plan_details = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans_to_mark)])
|
||||
logger.info(f"以下 {len(plans_to_mark)} 条月度计划将被标记为已完成:\n{plan_details}")
|
||||
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
|
||||
{"status": "completed"}, synchronize_session=False
|
||||
await session.execute(
|
||||
update(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)).values(status="completed")
|
||||
)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"标记月度计划为完成时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def delete_plans_by_ids(plan_ids: List[int]):
|
||||
async def delete_plans_by_ids(plan_ids: List[int]):
|
||||
"""
|
||||
根据ID列表从数据库中物理删除月度计划。
|
||||
|
||||
@@ -112,10 +114,11 @@ def delete_plans_by_ids(plan_ids: List[int]):
|
||||
if not plan_ids:
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 先查询要删除的计划,用于日志记录
|
||||
plans_to_delete = session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).all()
|
||||
result = await session.execute(select(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
|
||||
plans_to_delete = result.scalars().all()
|
||||
if not plans_to_delete:
|
||||
logger.info("没有找到需要删除的月度计划。")
|
||||
return
|
||||
@@ -124,16 +127,16 @@ def delete_plans_by_ids(plan_ids: List[int]):
|
||||
logger.info(f"检测到月度计划超额,将删除以下 {len(plans_to_delete)} 条计划:\n{plan_details}")
|
||||
|
||||
# 执行删除
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).delete(synchronize_session=False)
|
||||
session.commit()
|
||||
await session.execute(delete(MonthlyPlan).where(MonthlyPlan.id.in_(plan_ids)))
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除月度计划时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
async def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
"""
|
||||
更新计划的使用统计信息。
|
||||
|
||||
@@ -143,44 +146,47 @@ def update_plan_usage(plan_ids: List[int], used_date: str):
|
||||
if not plan_ids:
|
||||
return
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 获取完成阈值配置,如果不存在则使用默认值
|
||||
completion_threshold = getattr(global_config.planning_system, "completion_threshold", 3)
|
||||
|
||||
# 批量更新使用次数和最后使用日期
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(plan_ids)).update(
|
||||
{"usage_count": MonthlyPlan.usage_count + 1, "last_used_date": used_date}, synchronize_session=False
|
||||
await session.execute(
|
||||
update(MonthlyPlan)
|
||||
.where(MonthlyPlan.id.in_(plan_ids))
|
||||
.values(usage_count=MonthlyPlan.usage_count + 1, last_used_date=used_date)
|
||||
)
|
||||
|
||||
# 检查是否有计划达到完成阈值
|
||||
plans_to_complete = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(
|
||||
result = await session.execute(
|
||||
select(MonthlyPlan).where(
|
||||
MonthlyPlan.id.in_(plan_ids),
|
||||
MonthlyPlan.usage_count >= completion_threshold,
|
||||
MonthlyPlan.status == "active",
|
||||
)
|
||||
.all()
|
||||
)
|
||||
plans_to_complete = result.scalars().all()
|
||||
|
||||
if plans_to_complete:
|
||||
completed_ids = [plan.id for plan in plans_to_complete]
|
||||
session.query(MonthlyPlan).filter(MonthlyPlan.id.in_(completed_ids)).update(
|
||||
{"status": "completed"}, synchronize_session=False
|
||||
await session.execute(
|
||||
update(MonthlyPlan).where(MonthlyPlan.id.in_(completed_ids)).values(status="completed")
|
||||
)
|
||||
|
||||
logger.info(f"计划 {completed_ids} 已达到使用阈值 ({completion_threshold}),标记为已完成。")
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
logger.info(f"成功更新了 {len(plan_ids)} 条月度计划的使用统计。")
|
||||
except Exception as e:
|
||||
logger.error(f"更新月度计划使用统计时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_days: int = 7) -> List[MonthlyPlan]:
|
||||
async def get_smart_plans_for_daily_schedule(
|
||||
month: str, max_count: int = 3, avoid_days: int = 7
|
||||
) -> List[MonthlyPlan]:
|
||||
"""
|
||||
智能抽取月度计划用于每日日程生成。
|
||||
|
||||
@@ -196,19 +202,24 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
# 计算避免重复的日期阈值
|
||||
avoid_date = (datetime.now() - timedelta(days=avoid_days)).strftime("%Y-%m-%d")
|
||||
|
||||
# 查询符合条件的计划
|
||||
query = session.query(MonthlyPlan).filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
query = select(MonthlyPlan).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
|
||||
)
|
||||
|
||||
# 排除最近使用过的计划
|
||||
query = query.filter((MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date))
|
||||
query = query.where(
|
||||
(MonthlyPlan.last_used_date.is_(None)) | (MonthlyPlan.last_used_date < avoid_date)
|
||||
)
|
||||
|
||||
# 按使用次数升序排列,优先选择使用次数少的
|
||||
plans = query.order_by(MonthlyPlan.usage_count.asc()).all()
|
||||
result = await session.execute(query.order_by(MonthlyPlan.usage_count.asc()))
|
||||
plans = result.scalars().all()
|
||||
|
||||
if not plans:
|
||||
logger.info(f"没有找到符合条件的 {month} 月度计划。")
|
||||
@@ -228,31 +239,31 @@ def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avoid_day
|
||||
return []
|
||||
|
||||
|
||||
def archive_active_plans_for_month(month: str):
|
||||
async def archive_active_plans_for_month(month: str):
|
||||
"""
|
||||
将指定月份所有状态为 'active' 的计划归档为 'archived'。
|
||||
通常在月底调用。
|
||||
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
updated_count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.update({"status": "archived"}, synchronize_session=False)
|
||||
result = await session.execute(
|
||||
update(MonthlyPlan)
|
||||
.where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.values(status="archived")
|
||||
)
|
||||
|
||||
session.commit()
|
||||
updated_count = result.rowcount
|
||||
await session.commit()
|
||||
logger.info(f"成功将 {updated_count} 条 {month} 的活跃月度计划归档。")
|
||||
return updated_count
|
||||
except Exception as e:
|
||||
logger.error(f"归档 {month} 的月度计划时发生错误: {e}")
|
||||
session.rollback()
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
async def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
"""
|
||||
获取指定月份所有状态为 'archived' 的计划。
|
||||
用于生成下个月计划时的参考。
|
||||
@@ -260,34 +271,34 @@ def get_archived_plans_for_month(month: str) -> List[MonthlyPlan]:
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
:return: MonthlyPlan 对象列表。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
plans = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived")
|
||||
.all()
|
||||
result = await session.execute(
|
||||
select(MonthlyPlan).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "archived"
|
||||
)
|
||||
)
|
||||
return plans
|
||||
return result.scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def has_active_plans(month: str) -> bool:
|
||||
async def has_active_plans(month: str) -> bool:
|
||||
"""
|
||||
检查指定月份是否存在任何状态为 'active' 的计划。
|
||||
|
||||
:param month: 目标月份,格式为 "YYYY-MM"。
|
||||
:return: 如果存在则返回 True,否则返回 False。
|
||||
"""
|
||||
with get_db_session() as session:
|
||||
async with get_db_session() as session:
|
||||
try:
|
||||
count = (
|
||||
session.query(MonthlyPlan)
|
||||
.filter(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
|
||||
.count()
|
||||
result = await session.execute(
|
||||
select(func.count(MonthlyPlan.id)).where(
|
||||
MonthlyPlan.target_month == month, MonthlyPlan.status == "active"
|
||||
)
|
||||
)
|
||||
return count > 0
|
||||
return result.scalar_one() > 0
|
||||
except Exception as e:
|
||||
logger.error(f"检查 {month} 的有效月度计划时发生错误: {e}")
|
||||
return False
|
||||
@@ -14,6 +14,11 @@ class MonthlyPlanManager:
|
||||
self.plan_manager = PlanManager()
|
||||
self.monthly_task_started = False
|
||||
|
||||
async def initialize(self):
|
||||
logger.info("正在初始化月度计划管理器...")
|
||||
await self.start_monthly_plan_generation()
|
||||
logger.info("月度计划管理器初始化成功")
|
||||
|
||||
async def start_monthly_plan_generation(self):
|
||||
if not self.monthly_task_started:
|
||||
logger.info(" 正在启动每月月度计划生成任务...")
|
||||
|
||||
@@ -28,20 +28,20 @@ class PlanManager:
|
||||
if target_month is None:
|
||||
target_month = datetime.now().strftime("%Y-%m")
|
||||
|
||||
if not has_active_plans(target_month):
|
||||
if not await has_active_plans(target_month):
|
||||
logger.info(f" {target_month} 没有任何有效的月度计划,将触发同步生成。")
|
||||
generation_successful = await self._generate_monthly_plans_logic(target_month)
|
||||
return generation_successful
|
||||
else:
|
||||
logger.info(f"{target_month} 已存在有效的月度计划。")
|
||||
plans = get_active_plans_for_month(target_month)
|
||||
plans = await get_active_plans_for_month(target_month)
|
||||
max_plans = global_config.planning_system.max_plans_per_month
|
||||
if len(plans) > max_plans:
|
||||
logger.warning(f"当前月度计划数量 ({len(plans)}) 超出上限 ({max_plans}),将自动删除多余的计划。")
|
||||
plans_to_delete = plans[: len(plans) - max_plans]
|
||||
delete_ids = [p.id for p in plans_to_delete]
|
||||
delete_plans_by_ids(delete_ids) # type: ignore
|
||||
plans = get_active_plans_for_month(target_month)
|
||||
await delete_plans_by_ids(delete_ids) # type: ignore
|
||||
plans = await get_active_plans_for_month(target_month)
|
||||
|
||||
if plans:
|
||||
plan_texts = "\n".join([f" {i + 1}. {plan.plan_text}" for i, plan in enumerate(plans)])
|
||||
@@ -64,11 +64,11 @@ class PlanManager:
|
||||
return False
|
||||
|
||||
last_month = self._get_previous_month(target_month)
|
||||
archived_plans = get_archived_plans_for_month(last_month)
|
||||
archived_plans = await get_archived_plans_for_month(last_month)
|
||||
plans = await self.llm_generator.generate_plans_with_llm(target_month, archived_plans)
|
||||
|
||||
if plans:
|
||||
add_new_plans(plans, target_month)
|
||||
await add_new_plans(plans, target_month)
|
||||
logger.info(f"成功为 {target_month} 生成并保存了 {len(plans)} 条月度计划。")
|
||||
return True
|
||||
else:
|
||||
@@ -95,11 +95,11 @@ class PlanManager:
|
||||
if target_month is None:
|
||||
target_month = datetime.now().strftime("%Y-%m")
|
||||
logger.info(f" 开始归档 {target_month} 的活跃月度计划...")
|
||||
archived_count = archive_active_plans_for_month(target_month)
|
||||
archived_count = await archive_active_plans_for_month(target_month)
|
||||
logger.info(f" 成功归档了 {archived_count} 条 {target_month} 的月度计划。")
|
||||
except Exception as e:
|
||||
logger.error(f" 归档 {target_month} 月度计划时发生错误: {e}")
|
||||
|
||||
def get_plans_for_schedule(self, month: str, max_count: int) -> List:
|
||||
async def get_plans_for_schedule(self, month: str, max_count: int) -> List:
|
||||
avoid_days = global_config.planning_system.avoid_repetition_days
|
||||
return get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
||||
return await get_smart_plans_for_daily_schedule(month, max_count=max_count, avoid_days=avoid_days)
|
||||
@@ -23,6 +23,13 @@ class ScheduleManager:
|
||||
self.daily_task_started = False
|
||||
self.schedule_generation_running = False
|
||||
|
||||
async def initialize(self):
|
||||
if global_config.planning_system.schedule_enable:
|
||||
logger.info("日程表功能已启用,正在初始化管理器...")
|
||||
await self.load_or_generate_today_schedule()
|
||||
await self.start_daily_schedule_generation()
|
||||
logger.info("日程表管理器初始化成功。")
|
||||
|
||||
async def start_daily_schedule_generation(self):
|
||||
if not self.daily_task_started:
|
||||
logger.info("正在启动每日日程生成任务...")
|
||||
@@ -40,7 +47,7 @@ class ScheduleManager:
|
||||
|
||||
today_str = datetime.now().strftime("%Y-%m-%d")
|
||||
try:
|
||||
schedule_data = self._load_schedule_from_db(today_str)
|
||||
schedule_data = await self._load_schedule_from_db(today_str)
|
||||
if schedule_data:
|
||||
self.today_schedule = schedule_data
|
||||
self._log_loaded_schedule(today_str)
|
||||
@@ -54,9 +61,10 @@ class ScheduleManager:
|
||||
logger.info("尝试生成日程作为备用方案...")
|
||||
await self.generate_and_save_schedule()
|
||||
|
||||
def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]:
|
||||
with get_db_session() as session:
|
||||
schedule_record = session.query(Schedule).filter(Schedule.date == date_str).first()
|
||||
async def _load_schedule_from_db(self, date_str: str) -> Optional[List[Dict[str, Any]]]:
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(select(Schedule).filter(Schedule.date == date_str))
|
||||
schedule_record = result.scalars().first()
|
||||
if schedule_record:
|
||||
logger.info(f"从数据库加载今天的日程 ({date_str})。")
|
||||
schedule_data = orjson.loads(str(schedule_record.schedule_data))
|
||||
@@ -90,35 +98,35 @@ class ScheduleManager:
|
||||
sampled_plans = []
|
||||
if global_config.planning_system.monthly_plan_enable:
|
||||
await self.plan_manager.ensure_and_generate_plans_if_needed(current_month_str)
|
||||
sampled_plans = self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3)
|
||||
sampled_plans = await self.plan_manager.get_plans_for_schedule(current_month_str, max_count=3)
|
||||
|
||||
schedule_data = await self.llm_generator.generate_schedule_with_llm(sampled_plans)
|
||||
|
||||
if schedule_data:
|
||||
self._save_schedule_to_db(today_str, schedule_data)
|
||||
await self._save_schedule_to_db(today_str, schedule_data)
|
||||
self.today_schedule = schedule_data
|
||||
self._log_generated_schedule(today_str, schedule_data)
|
||||
|
||||
if sampled_plans:
|
||||
used_plan_ids = [plan.id for plan in sampled_plans]
|
||||
logger.info(f"更新使用过的月度计划 {used_plan_ids} 的统计信息。")
|
||||
update_plan_usage(used_plan_ids, today_str)
|
||||
await update_plan_usage(used_plan_ids, today_str)
|
||||
finally:
|
||||
self.schedule_generation_running = False
|
||||
logger.info("日程生成任务结束")
|
||||
|
||||
def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]):
|
||||
with get_db_session() as session:
|
||||
async def _save_schedule_to_db(self, date_str: str, schedule_data: List[Dict[str, Any]]):
|
||||
async with get_db_session() as session:
|
||||
schedule_json = orjson.dumps(schedule_data).decode("utf-8")
|
||||
existing_schedule = session.query(Schedule).filter(Schedule.date == date_str).first()
|
||||
result = await session.execute(select(Schedule).filter(Schedule.date == date_str))
|
||||
existing_schedule = result.scalars().first()
|
||||
if existing_schedule:
|
||||
session.query(Schedule).filter(Schedule.date == date_str).update(
|
||||
{Schedule.schedule_data: schedule_json, Schedule.updated_at: datetime.now()}
|
||||
)
|
||||
existing_schedule.schedule_data = schedule_json
|
||||
existing_schedule.updated_at = datetime.now()
|
||||
else:
|
||||
new_schedule = Schedule(date=date_str, schedule_data=schedule_json)
|
||||
session.add(new_schedule)
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
def _log_generated_schedule(self, date_str: str, schedule_data: List[Dict[str, Any]]):
|
||||
schedule_str = f"✅ 成功生成并保存今天的日程 ({date_str}):\n"
|
||||
|
||||
Reference in New Issue
Block a user