diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py new file mode 100644 index 000000000..04fe37bb9 --- /dev/null +++ b/src/api/apiforgui.py @@ -0,0 +1,15 @@ +from src.heart_flow.heartflow import heartflow +from src.heart_flow.sub_heartflow import ChatState + +async def get_all_subheartflow_ids() -> list: + """获取所有子心流的ID列表""" + all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows() + return [subheartflow.subheartflow_id for subheartflow in all_subheartflows] + +async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool: + """强制改变子心流的状态""" + subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id) + if subheartflow: + return await heartflow.force_change_subheartflow_status(subheartflow_id, status) + return False + diff --git a/src/api/main.py b/src/api/main.py index d4d3c62e7..6c2009972 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -5,9 +5,12 @@ from strawberry.fastapi import GraphQLRouter from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func from src.common.server import global_server +from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status +from src.heart_flow.sub_heartflow import ChatState # import uvicorn # import os + router = APIRouter() @@ -24,6 +27,28 @@ router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) async def reload_config(): return await reload_config_func() +@router.get("/gui/subheartflow/get/all") +async def get_subheartflow_ids(): + """获取所有子心流的ID列表""" + return await get_all_subheartflow_ids() + +@router.post("/gui/subheartflow/forced_change_status") +async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): #noqa + """强制改变子心流的状态""" + # 参数检查 + if not isinstance(status, ChatState): + logger.warning(f"无效的状态参数: {status}") + return {"status": "failed", "reason": "invalid status"} + logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}") + success = await forced_change_subheartflow_status(subheartflow_id, status) + if success: + logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功") + return {"status": "success"} + else: + logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败") + return {"status": "failed"} + + def start_api_server(): """启动API服务器""" diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index bd8bc6ff4..5d9400880 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -1,4 +1,4 @@ -from src.heart_flow.sub_heartflow import SubHeartflow +from src.heart_flow.sub_heartflow import SubHeartflow, ChatState from src.plugins.models.utils_model import LLMRequest from src.config.config import global_config from src.plugins.schedule.schedule_generator import bot_schedule @@ -62,6 +62,13 @@ class Heartflow: # 不再需要传入 self.current_state return await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id) + async def force_change_subheartflow_status( + self, subheartflow_id: str, status: ChatState + ) -> None: + """强制改变子心流的状态""" + # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 + return await self.subheartflow_manager.force_change_state(subheartflow_id, status) + async def heartflow_start_working(self): """启动后台任务""" await self.background_task_manager.start_tasks() diff --git a/src/heart_flow/subheartflow_manager.py b/src/heart_flow/subheartflow_manager.py index f06a68c87..057d6cca3 100644 --- a/src/heart_flow/subheartflow_manager.py +++ b/src/heart_flow/subheartflow_manager.py @@ -82,6 +82,17 @@ class SubHeartflowManager: max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) request_type="subheartflow_state_eval", # 保留特定的请求类型 ) + + async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool: + """强制改变指定子心流的状态""" + async with self._lock: + subflow = self.subheartflows.get(subflow_id) + if not subflow: + logger.warning(f"[强制状态转换]尝试转换不存在的子心流{subflow_id} 到 {target_state.value}") + return False + await subflow.change_chat_state(target_state) + logger.info(f"[强制状态转换]子心流 {subflow_id} 已转换到 {target_state.value}") + return True def get_all_subheartflows(self) -> List["SubHeartflow"]: """获取所有当前管理的 SubHeartflow 实例列表 (快照)。""" @@ -92,7 +103,7 @@ class SubHeartflowManager: Args: subheartflow_id: 子心流唯一标识符 - # mai_states 参数已被移除,使用 self.mai_state_info + mai_states 参数已被移除,使用 self.mai_state_info Returns: 成功返回SubHeartflow实例,失败返回None @@ -174,8 +185,7 @@ class SubHeartflowManager: continue subheartflow.update_last_chat_state_time() absent_last_time = subheartflow.chat_state_last_time - if max_age_seconds and (current_time - absent_last_time) > max_age_seconds: - flows_to_stop.append(subheartflow_id) + flows_to_stop.append(subheartflow_id) return flows_to_stop