feat: 添加获取所有子心流ID和强制改变子心流状态的API接口

This commit is contained in:
墨梓柒
2025-05-05 01:26:34 +08:00
parent 27212c5d43
commit 2115917580
4 changed files with 61 additions and 4 deletions

15
src/api/apiforgui.py Normal file
View File

@@ -0,0 +1,15 @@
from src.heart_flow.heartflow import heartflow
from src.heart_flow.sub_heartflow import ChatState
async def get_all_subheartflow_ids() -> list:
"""获取所有子心流的ID列表"""
all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows()
return [subheartflow.subheartflow_id for subheartflow in all_subheartflows]
async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool:
"""强制改变子心流的状态"""
subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id)
if subheartflow:
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
return False

View File

@@ -5,9 +5,12 @@ from strawberry.fastapi import GraphQLRouter
from src.common.logger_manager import get_logger from src.common.logger_manager import get_logger
from src.api.reload_config import reload_config as reload_config_func from src.api.reload_config import reload_config as reload_config_func
from src.common.server import global_server from src.common.server import global_server
from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status
from src.heart_flow.sub_heartflow import ChatState
# import uvicorn # import uvicorn
# import os # import os
router = APIRouter() router = APIRouter()
@@ -24,6 +27,28 @@ router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
async def reload_config(): async def reload_config():
return await reload_config_func() return await reload_config_func()
@router.get("/gui/subheartflow/get/all")
async def get_subheartflow_ids():
"""获取所有子心流的ID列表"""
return await get_all_subheartflow_ids()
@router.post("/gui/subheartflow/forced_change_status")
async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): #noqa
"""强制改变子心流的状态"""
# 参数检查
if not isinstance(status, ChatState):
logger.warning(f"无效的状态参数: {status}")
return {"status": "failed", "reason": "invalid status"}
logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}")
success = await forced_change_subheartflow_status(subheartflow_id, status)
if success:
logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功")
return {"status": "success"}
else:
logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败")
return {"status": "failed"}
def start_api_server(): def start_api_server():
"""启动API服务器""" """启动API服务器"""

View File

@@ -1,4 +1,4 @@
from src.heart_flow.sub_heartflow import SubHeartflow from src.heart_flow.sub_heartflow import SubHeartflow, ChatState
from src.plugins.models.utils_model import LLMRequest from src.plugins.models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config
from src.plugins.schedule.schedule_generator import bot_schedule from src.plugins.schedule.schedule_generator import bot_schedule
@@ -62,6 +62,13 @@ class Heartflow:
# 不再需要传入 self.current_state # 不再需要传入 self.current_state
return await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id) return await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id)
async def force_change_subheartflow_status(
self, subheartflow_id: str, status: ChatState
) -> None:
"""强制改变子心流的状态"""
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
return await self.subheartflow_manager.force_change_state(subheartflow_id, status)
async def heartflow_start_working(self): async def heartflow_start_working(self):
"""启动后台任务""" """启动后台任务"""
await self.background_task_manager.start_tasks() await self.background_task_manager.start_tasks()

View File

@@ -82,6 +82,17 @@ class SubHeartflowManager:
max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多)
request_type="subheartflow_state_eval", # 保留特定的请求类型 request_type="subheartflow_state_eval", # 保留特定的请求类型
) )
async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool:
"""强制改变指定子心流的状态"""
async with self._lock:
subflow = self.subheartflows.get(subflow_id)
if not subflow:
logger.warning(f"[强制状态转换]尝试转换不存在的子心流{subflow_id}{target_state.value}")
return False
await subflow.change_chat_state(target_state)
logger.info(f"[强制状态转换]子心流 {subflow_id} 已转换到 {target_state.value}")
return True
def get_all_subheartflows(self) -> List["SubHeartflow"]: def get_all_subheartflows(self) -> List["SubHeartflow"]:
"""获取所有当前管理的 SubHeartflow 实例列表 (快照)。""" """获取所有当前管理的 SubHeartflow 实例列表 (快照)。"""
@@ -92,7 +103,7 @@ class SubHeartflowManager:
Args: Args:
subheartflow_id: 子心流唯一标识符 subheartflow_id: 子心流唯一标识符
# mai_states 参数已被移除,使用 self.mai_state_info mai_states 参数已被移除,使用 self.mai_state_info
Returns: Returns:
成功返回SubHeartflow实例失败返回None 成功返回SubHeartflow实例失败返回None
@@ -174,8 +185,7 @@ class SubHeartflowManager:
continue continue
subheartflow.update_last_chat_state_time() subheartflow.update_last_chat_state_time()
absent_last_time = subheartflow.chat_state_last_time absent_last_time = subheartflow.chat_state_last_time
if max_age_seconds and (current_time - absent_last_time) > max_age_seconds: flows_to_stop.append(subheartflow_id)
flows_to_stop.append(subheartflow_id)
return flows_to_stop return flows_to_stop