From 162dc49acd76e3abeec86c38054a6310de80f4ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 7 May 2025 22:08:16 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E8=8E=B7=E5=8F=96?= =?UTF-8?q?=E5=AD=90=E5=BF=83=E6=B5=81=E5=BE=AA=E7=8E=AF=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E5=92=8C=E6=89=80=E6=9C=89=E7=8A=B6=E6=80=81=E7=9A=84API?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/apiforgui.py | 17 ++++++++++ src/api/main.py | 28 +++++++++++++++-- src/heart_flow/heartflow.py | 18 +++++++++++ src/heart_flow/interest_logger.py | 52 +++++++++++++++++++++++++++++++ 4 files changed, 113 insertions(+), 2 deletions(-) diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py index a8027c480..1860aef7d 100644 --- a/src/api/apiforgui.py +++ b/src/api/apiforgui.py @@ -17,3 +17,20 @@ async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatSt if subheartflow: return await heartflow.force_change_subheartflow_status(subheartflow_id, status) return False + +async def get_subheartflow_cycle_info(subheartflow_id: str, history_len: int) -> dict: + """获取子心流的循环信息""" + subheartflow_cycle_info = await heartflow.api_get_subheartflow_cycle_info(subheartflow_id, history_len) + logger.debug(f"子心流 {subheartflow_id} 循环信息: {subheartflow_cycle_info}") + if subheartflow_cycle_info: + return subheartflow_cycle_info + else: + logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") + return None + + +async def get_all_states(): + """获取所有状态""" + all_states = await heartflow.api_get_all_states() + logger.debug(f"所有状态: {all_states}") + return all_states diff --git a/src/api/main.py b/src/api/main.py index 1f47a57cb..f5d299d85 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -2,13 +2,18 @@ from fastapi import APIRouter from strawberry.fastapi import GraphQLRouter import os import sys - +# from src.heart_flow.heartflow import heartflow sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) # from src.config.config import BotConfig from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func from src.common.server import global_server -from src.api.apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status +from src.api.apiforgui import ( + get_all_subheartflow_ids, + forced_change_subheartflow_status, + get_subheartflow_cycle_info, + get_all_states, +) from src.heart_flow.sub_heartflow import ChatState # import uvicorn @@ -67,7 +72,26 @@ async def force_stop_maibot(): else: logger.error("MAI Bot强制停止失败") return {"status": "failed"} + +@router.get("/gui/subheartflow/cycleinfo") +async def get_subheartflow_cycle_info_api(subheartflow_id: str, history_len: int): + """获取子心流的循环信息""" + cycle_info = await get_subheartflow_cycle_info(subheartflow_id, history_len) + if cycle_info: + return {"status": "success", "data": cycle_info} + else: + logger.warning(f"子心流 {subheartflow_id} 循环信息未找到") + return {"status": "failed", "reason": "subheartflow not found"} +@router.get("/gui/get_all_states") +async def get_all_states_api(): + """获取所有状态""" + all_states = await get_all_states() + if all_states: + return {"status": "success", "data": all_states} + else: + logger.warning("获取所有状态失败") + return {"status": "failed", "reason": "failed to get all states"} def start_api_server(): """启动API服务器""" diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index 894247ce4..dd58f5cdf 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -66,6 +66,24 @@ class Heartflow: """强制改变子心流的状态""" # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 return await self.subheartflow_manager.force_change_state(subheartflow_id, status) + + async def api_get_all_states(self): + """获取所有状态""" + return await self.interest_logger.api_get_all_states() + + + async def api_get_subheartflow_cycle_info(self, subheartflow_id: str, history_len: int) -> Optional[dict]: + """获取子心流的循环信息""" + subheartflow = await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id) + if not subheartflow: + logger.warning(f"尝试获取不存在的子心流 {subheartflow_id} 的周期信息") + return None + heartfc_instance = subheartflow.heart_fc_instance + if not heartfc_instance: + logger.warning(f"子心流 {subheartflow_id} 没有心流实例,无法获取周期信息") + return None + + return heartfc_instance.get_cycle_history(last_n=history_len) async def heartflow_start_working(self): """启动后台任务""" diff --git a/src/heart_flow/interest_logger.py b/src/heart_flow/interest_logger.py index 1fe289b89..9b5621569 100644 --- a/src/heart_flow/interest_logger.py +++ b/src/heart_flow/interest_logger.py @@ -158,3 +158,55 @@ class InterestLogger: except Exception as e: logger.error(f"记录状态时发生意外错误: {e}") logger.error(traceback.format_exc()) + + async def api_get_all_states(self): + """获取主心流和所有子心流的状态。""" + try: + current_timestamp = time.time() + + # main_mind = self.heartflow.current_mind + # 获取 Mai 状态名称 + mai_state_name = self.heartflow.current_state.get_current_state().name + + all_subflow_states = await self.get_all_subflow_states() + + log_entry_base = { + "timestamp": round(current_timestamp, 2), + # "main_mind": main_mind, + "mai_state": mai_state_name, + "subflow_count": len(all_subflow_states), + "subflows": [], + } + + subflow_details = [] + items_snapshot = list(all_subflow_states.items()) + for stream_id, state in items_snapshot: + group_name = stream_id + try: + chat_stream = chat_manager.get_stream(stream_id) + if chat_stream: + if chat_stream.group_info: + group_name = chat_stream.group_info.group_name + elif chat_stream.user_info: + group_name = f"私聊_{chat_stream.user_info.user_nickname}" + except Exception as e: + logger.trace(f"无法获取 stream_id {stream_id} 的群组名: {e}") + + interest_state = state.get("interest_state", {}) + + subflow_entry = { + "stream_id": stream_id, + "group_name": group_name, + "sub_mind": state.get("current_mind", "未知"), + "sub_chat_state": state.get("chat_state", "未知"), + "interest_level": interest_state.get("interest_level", 0.0), + "start_hfc_probability": interest_state.get("start_hfc_probability", 0.0), + # "is_above_threshold": interest_state.get("is_above_threshold", False), + } + subflow_details.append(subflow_entry) + + log_entry_base["subflows"] = subflow_details + return subflow_details + except Exception as e: + logger.error(f"记录状态时发生意外错误: {e}") + logger.error(traceback.format_exc()) \ No newline at end of file