feat: 添加获取所有子心流ID和强制改变子心流状态的API接口
This commit is contained in:
15
src/api/apiforgui.py
Normal file
15
src/api/apiforgui.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from src.heart_flow.heartflow import heartflow
|
||||||
|
from src.heart_flow.sub_heartflow import ChatState
|
||||||
|
|
||||||
|
async def get_all_subheartflow_ids() -> list:
|
||||||
|
"""获取所有子心流的ID列表"""
|
||||||
|
all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows()
|
||||||
|
return [subheartflow.subheartflow_id for subheartflow in all_subheartflows]
|
||||||
|
|
||||||
|
async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool:
|
||||||
|
"""强制改变子心流的状态"""
|
||||||
|
subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id)
|
||||||
|
if subheartflow:
|
||||||
|
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
|
||||||
|
return False
|
||||||
|
|
||||||
@@ -5,9 +5,12 @@ from strawberry.fastapi import GraphQLRouter
|
|||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.api.reload_config import reload_config as reload_config_func
|
from src.api.reload_config import reload_config as reload_config_func
|
||||||
from src.common.server import global_server
|
from src.common.server import global_server
|
||||||
|
from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status
|
||||||
|
from src.heart_flow.sub_heartflow import ChatState
|
||||||
# import uvicorn
|
# import uvicorn
|
||||||
# import os
|
# import os
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@@ -24,6 +27,28 @@ router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"])
|
|||||||
async def reload_config():
|
async def reload_config():
|
||||||
return await reload_config_func()
|
return await reload_config_func()
|
||||||
|
|
||||||
|
@router.get("/gui/subheartflow/get/all")
|
||||||
|
async def get_subheartflow_ids():
|
||||||
|
"""获取所有子心流的ID列表"""
|
||||||
|
return await get_all_subheartflow_ids()
|
||||||
|
|
||||||
|
@router.post("/gui/subheartflow/forced_change_status")
|
||||||
|
async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): #noqa
|
||||||
|
"""强制改变子心流的状态"""
|
||||||
|
# 参数检查
|
||||||
|
if not isinstance(status, ChatState):
|
||||||
|
logger.warning(f"无效的状态参数: {status}")
|
||||||
|
return {"status": "failed", "reason": "invalid status"}
|
||||||
|
logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}")
|
||||||
|
success = await forced_change_subheartflow_status(subheartflow_id, status)
|
||||||
|
if success:
|
||||||
|
logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功")
|
||||||
|
return {"status": "success"}
|
||||||
|
else:
|
||||||
|
logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败")
|
||||||
|
return {"status": "failed"}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def start_api_server():
|
def start_api_server():
|
||||||
"""启动API服务器"""
|
"""启动API服务器"""
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from src.heart_flow.sub_heartflow import SubHeartflow
|
from src.heart_flow.sub_heartflow import SubHeartflow, ChatState
|
||||||
from src.plugins.models.utils_model import LLMRequest
|
from src.plugins.models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.plugins.schedule.schedule_generator import bot_schedule
|
from src.plugins.schedule.schedule_generator import bot_schedule
|
||||||
@@ -62,6 +62,13 @@ class Heartflow:
|
|||||||
# 不再需要传入 self.current_state
|
# 不再需要传入 self.current_state
|
||||||
return await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id)
|
return await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id)
|
||||||
|
|
||||||
|
async def force_change_subheartflow_status(
|
||||||
|
self, subheartflow_id: str, status: ChatState
|
||||||
|
) -> None:
|
||||||
|
"""强制改变子心流的状态"""
|
||||||
|
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
|
||||||
|
return await self.subheartflow_manager.force_change_state(subheartflow_id, status)
|
||||||
|
|
||||||
async def heartflow_start_working(self):
|
async def heartflow_start_working(self):
|
||||||
"""启动后台任务"""
|
"""启动后台任务"""
|
||||||
await self.background_task_manager.start_tasks()
|
await self.background_task_manager.start_tasks()
|
||||||
|
|||||||
@@ -83,6 +83,17 @@ class SubHeartflowManager:
|
|||||||
request_type="subheartflow_state_eval", # 保留特定的请求类型
|
request_type="subheartflow_state_eval", # 保留特定的请求类型
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool:
|
||||||
|
"""强制改变指定子心流的状态"""
|
||||||
|
async with self._lock:
|
||||||
|
subflow = self.subheartflows.get(subflow_id)
|
||||||
|
if not subflow:
|
||||||
|
logger.warning(f"[强制状态转换]尝试转换不存在的子心流{subflow_id} 到 {target_state.value}")
|
||||||
|
return False
|
||||||
|
await subflow.change_chat_state(target_state)
|
||||||
|
logger.info(f"[强制状态转换]子心流 {subflow_id} 已转换到 {target_state.value}")
|
||||||
|
return True
|
||||||
|
|
||||||
def get_all_subheartflows(self) -> List["SubHeartflow"]:
|
def get_all_subheartflows(self) -> List["SubHeartflow"]:
|
||||||
"""获取所有当前管理的 SubHeartflow 实例列表 (快照)。"""
|
"""获取所有当前管理的 SubHeartflow 实例列表 (快照)。"""
|
||||||
return list(self.subheartflows.values())
|
return list(self.subheartflows.values())
|
||||||
@@ -92,7 +103,7 @@ class SubHeartflowManager:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
subheartflow_id: 子心流唯一标识符
|
subheartflow_id: 子心流唯一标识符
|
||||||
# mai_states 参数已被移除,使用 self.mai_state_info
|
mai_states 参数已被移除,使用 self.mai_state_info
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
成功返回SubHeartflow实例,失败返回None
|
成功返回SubHeartflow实例,失败返回None
|
||||||
@@ -174,7 +185,6 @@ class SubHeartflowManager:
|
|||||||
continue
|
continue
|
||||||
subheartflow.update_last_chat_state_time()
|
subheartflow.update_last_chat_state_time()
|
||||||
absent_last_time = subheartflow.chat_state_last_time
|
absent_last_time = subheartflow.chat_state_last_time
|
||||||
if max_age_seconds and (current_time - absent_last_time) > max_age_seconds:
|
|
||||||
flows_to_stop.append(subheartflow_id)
|
flows_to_stop.append(subheartflow_id)
|
||||||
|
|
||||||
return flows_to_stop
|
return flows_to_stop
|
||||||
|
|||||||
Reference in New Issue
Block a user