feat: 添加获取子心流循环信息和所有状态的API接口
This commit is contained in:
@@ -17,3 +17,20 @@ async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatSt
|
|||||||
if subheartflow:
|
if subheartflow:
|
||||||
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
|
return await heartflow.force_change_subheartflow_status(subheartflow_id, status)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def get_subheartflow_cycle_info(subheartflow_id: str, history_len: int) -> dict:
|
||||||
|
"""获取子心流的循环信息"""
|
||||||
|
subheartflow_cycle_info = await heartflow.api_get_subheartflow_cycle_info(subheartflow_id, history_len)
|
||||||
|
logger.debug(f"子心流 {subheartflow_id} 循环信息: {subheartflow_cycle_info}")
|
||||||
|
if subheartflow_cycle_info:
|
||||||
|
return subheartflow_cycle_info
|
||||||
|
else:
|
||||||
|
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all_states():
|
||||||
|
"""获取所有状态"""
|
||||||
|
all_states = await heartflow.api_get_all_states()
|
||||||
|
logger.debug(f"所有状态: {all_states}")
|
||||||
|
return all_states
|
||||||
|
|||||||
@@ -2,13 +2,18 @@ from fastapi import APIRouter
|
|||||||
from strawberry.fastapi import GraphQLRouter
|
from strawberry.fastapi import GraphQLRouter
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
# from src.heart_flow.heartflow import heartflow
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||||
# from src.config.config import BotConfig
|
# from src.config.config import BotConfig
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.api.reload_config import reload_config as reload_config_func
|
from src.api.reload_config import reload_config as reload_config_func
|
||||||
from src.common.server import global_server
|
from src.common.server import global_server
|
||||||
from src.api.apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status
|
from src.api.apiforgui import (
|
||||||
|
get_all_subheartflow_ids,
|
||||||
|
forced_change_subheartflow_status,
|
||||||
|
get_subheartflow_cycle_info,
|
||||||
|
get_all_states,
|
||||||
|
)
|
||||||
from src.heart_flow.sub_heartflow import ChatState
|
from src.heart_flow.sub_heartflow import ChatState
|
||||||
|
|
||||||
# import uvicorn
|
# import uvicorn
|
||||||
@@ -67,7 +72,26 @@ async def force_stop_maibot():
|
|||||||
else:
|
else:
|
||||||
logger.error("MAI Bot强制停止失败")
|
logger.error("MAI Bot强制停止失败")
|
||||||
return {"status": "failed"}
|
return {"status": "failed"}
|
||||||
|
|
||||||
|
@router.get("/gui/subheartflow/cycleinfo")
|
||||||
|
async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int):
|
||||||
|
"""获取子心流的循环信息"""
|
||||||
|
cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len)
|
||||||
|
if cycle_info:
|
||||||
|
return {"status": "success", "data": cycle_info}
|
||||||
|
else:
|
||||||
|
logger.warning(f"子心流 {subheartflow_id} 循环信息未找到")
|
||||||
|
return {"status": "failed", "reason": "subheartflow not found"}
|
||||||
|
|
||||||
|
@router.get("/gui/get_all_states")
|
||||||
|
async def get_all_states_api():
|
||||||
|
"""获取所有状态"""
|
||||||
|
all_states = await get_all_states()
|
||||||
|
if all_states:
|
||||||
|
return {"status": "success", "data": all_states}
|
||||||
|
else:
|
||||||
|
logger.warning("获取所有状态失败")
|
||||||
|
return {"status": "failed", "reason": "failed to get all states"}
|
||||||
|
|
||||||
def start_api_server():
|
def start_api_server():
|
||||||
"""启动API服务器"""
|
"""启动API服务器"""
|
||||||
|
|||||||
@@ -66,6 +66,24 @@ class Heartflow:
|
|||||||
"""强制改变子心流的状态"""
|
"""强制改变子心流的状态"""
|
||||||
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
|
# 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据
|
||||||
return await self.subheartflow_manager.force_change_state(subheartflow_id, status)
|
return await self.subheartflow_manager.force_change_state(subheartflow_id, status)
|
||||||
|
|
||||||
|
async def api_get_all_states(self):
|
||||||
|
"""获取所有状态"""
|
||||||
|
return await self.interest_logger.api_get_all_states()
|
||||||
|
|
||||||
|
|
||||||
|
async def api_get_subheartflow_cycle_info(self, subheartflow_id: str, history_len: int) -> Optional[dict]:
|
||||||
|
"""获取子心流的循环信息"""
|
||||||
|
subheartflow = await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id)
|
||||||
|
if not subheartflow:
|
||||||
|
logger.warning(f"尝试获取不存在的子心流 {subheartflow_id} 的周期信息")
|
||||||
|
return None
|
||||||
|
heartfc_instance = subheartflow.heart_fc_instance
|
||||||
|
if not heartfc_instance:
|
||||||
|
logger.warning(f"子心流 {subheartflow_id} 没有心流实例,无法获取周期信息")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return heartfc_instance.get_cycle_history(last_n=history_len)
|
||||||
|
|
||||||
async def heartflow_start_working(self):
|
async def heartflow_start_working(self):
|
||||||
"""启动后台任务"""
|
"""启动后台任务"""
|
||||||
|
|||||||
@@ -158,3 +158,55 @@ class InterestLogger:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录状态时发生意外错误: {e}")
|
logger.error(f"记录状态时发生意外错误: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
async def api_get_all_states(self):
|
||||||
|
"""获取主心流和所有子心流的状态。"""
|
||||||
|
try:
|
||||||
|
current_timestamp = time.time()
|
||||||
|
|
||||||
|
# main_mind = self.heartflow.current_mind
|
||||||
|
# 获取 Mai 状态名称
|
||||||
|
mai_state_name = self.heartflow.current_state.get_current_state().name
|
||||||
|
|
||||||
|
all_subflow_states = await self.get_all_subflow_states()
|
||||||
|
|
||||||
|
log_entry_base = {
|
||||||
|
"timestamp": round(current_timestamp, 2),
|
||||||
|
# "main_mind": main_mind,
|
||||||
|
"mai_state": mai_state_name,
|
||||||
|
"subflow_count": len(all_subflow_states),
|
||||||
|
"subflows": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
subflow_details = []
|
||||||
|
items_snapshot = list(all_subflow_states.items())
|
||||||
|
for stream_id, state in items_snapshot:
|
||||||
|
group_name = stream_id
|
||||||
|
try:
|
||||||
|
chat_stream = chat_manager.get_stream(stream_id)
|
||||||
|
if chat_stream:
|
||||||
|
if chat_stream.group_info:
|
||||||
|
group_name = chat_stream.group_info.group_name
|
||||||
|
elif chat_stream.user_info:
|
||||||
|
group_name = f"私聊_{chat_stream.user_info.user_nickname}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}")
|
||||||
|
|
||||||
|
interest_state = state.get("interest_state", {})
|
||||||
|
|
||||||
|
subflow_entry = {
|
||||||
|
"stream_id": stream_id,
|
||||||
|
"group_name": group_name,
|
||||||
|
"sub_mind": state.get("current_mind", "未知"),
|
||||||
|
"sub_chat_state": state.get("chat_state", "未知"),
|
||||||
|
"interest_level": interest_state.get("interest_level", 0.0),
|
||||||
|
"start_hfc_probability": interest_state.get("start_hfc_probability", 0.0),
|
||||||
|
# "is_above_threshold": interest_state.get("is_above_threshold", False),
|
||||||
|
}
|
||||||
|
subflow_details.append(subflow_entry)
|
||||||
|
|
||||||
|
log_entry_base["subflows"] = subflow_details
|
||||||
|
return subflow_details
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记录状态时发生意外错误: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
Reference in New Issue
Block a user