From 1e2cdeeea536bdee8212a7bb8de837fef940c2a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Wed, 7 May 2025 00:21:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=BC=BA=E5=88=B6?= =?UTF-8?q?=E5=81=9C=E6=AD=A2MAI=20Bot=E7=9A=84API=E6=8E=A5=E5=8F=A3(?= =?UTF-8?q?=E5=8D=8A=E6=88=90=E5=93=81)=EF=BC=8C=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E5=B5=8C=E5=85=A5=E6=95=B0=E6=8D=AE=E7=9B=AE=E5=BD=95=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bot.py | 23 +++++++++++++++++++- src/api/apiforgui.py | 2 ++ src/api/main.py | 18 +++++++++++++-- src/plugins/knowledge/src/embedding_store.py | 8 ++++--- src/plugins/knowledge/src/kg_manager.py | 7 +++--- 5 files changed, 49 insertions(+), 9 deletions(-) diff --git a/bot.py b/bot.py index 010416294..e3c3cb5d4 100644 --- a/bot.py +++ b/bot.py @@ -33,6 +33,22 @@ driver = None app = None loop = None +# shutdown_requested = False # 新增全局变量 + +async def request_shutdown() -> bool: + """请求关闭程序""" + try: + if loop and not loop.is_closed(): + try: + loop.run_until_complete(graceful_shutdown()) + except Exception as ge: # 捕捉优雅关闭时可能发生的错误 + logger.error(f"优雅关闭时发生错误: {ge}") + return False + return True + except Exception as e: + logger.error(f"请求关闭程序时发生错误: {e}") + return False + def easter_egg(): # 彩蛋 @@ -230,6 +246,9 @@ def raw_main(): return MainSystem() + + + if __name__ == "__main__": exit_code = 0 # 用于记录程序最终的退出状态 try: @@ -252,6 +271,8 @@ if __name__ == "__main__": loop.run_until_complete(graceful_shutdown()) except Exception as ge: # 捕捉优雅关闭时可能发生的错误 logger.error(f"优雅关闭时发生错误: {ge}") + # 新增:检测外部请求关闭 + # except Exception as e: # 将主异常捕获移到外层 try...except # logger.error(f"事件循环内发生错误: {str(e)} {str(traceback.format_exc())}") # exit_code = 1 @@ -271,5 +292,5 @@ if __name__ == "__main__": loop.close() logger.info("事件循环已关闭") # 在程序退出前暂停,让你有机会看到输出 - input("按 Enter 键退出...") # <--- 添加这行 + # input("按 Enter 键退出...") # <--- 添加这行 sys.exit(exit_code) # <--- 使用记录的退出码 diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py index 75ef2f8d1..7e2460b05 100644 --- a/src/api/apiforgui.py +++ b/src/api/apiforgui.py @@ -1,5 +1,7 @@ from src.heart_flow.heartflow import heartflow from src.heart_flow.sub_heartflow import ChatState +from src.common.logger_manager import get_logger +logger = get_logger("api") async def get_all_subheartflow_ids() -> list: diff --git a/src/api/main.py b/src/api/main.py index 6d7e3c1e2..a39dafd5a 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -1,12 +1,15 @@ from fastapi import APIRouter from strawberry.fastapi import GraphQLRouter - +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) # from src.config.config import BotConfig from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func from src.common.server import global_server -from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status +from src.api.apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status from src.heart_flow.sub_heartflow import ChatState + # import uvicorn # import os @@ -50,6 +53,17 @@ async def forced_change_subheartflow_status_api(subheartflow_id: str, status: Ch logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败") return {"status": "failed"} +@router.get("/stop") +async def force_stop_maibot(): + """强制停止MAI Bot""" + from bot import request_shutdown + success = await request_shutdown() + if success: + logger.info("MAI Bot已强制停止") + return {"status": "success"} + else: + logger.error("MAI Bot强制停止失败") + return {"status": "failed"} def start_api_server(): """启动API服务器""" diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index d1eb7f90f..2a27c5396 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -28,6 +28,8 @@ from rich.progress import ( install(extra_lines=3) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) +EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") if global_config["persistence"]["embedding_data_dir"] is None else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"]) +EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 # 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录 @@ -288,17 +290,17 @@ class EmbeddingManager: self.paragraphs_embedding_store = EmbeddingStore( llm_client, PG_NAMESPACE, - global_config["persistence"]["embedding_data_dir"], + EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( llm_client, ENT_NAMESPACE, - global_config["persistence"]["embedding_data_dir"], + EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( llm_client, REL_NAMESPACE, - global_config["persistence"]["embedding_data_dir"], + EMBEDDING_DATA_DIR_STR, ) self.stored_pg_hashes = set() diff --git a/src/plugins/knowledge/src/kg_manager.py b/src/plugins/knowledge/src/kg_manager.py index fd922af48..19403f9ba 100644 --- a/src/plugins/knowledge/src/kg_manager.py +++ b/src/plugins/knowledge/src/kg_manager.py @@ -30,8 +30,9 @@ from .lpmmconfig import ( ) from .global_logger import logger - - +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) +KG_DIR = os.path.join(ROOT_PATH, "data/rag") if global_config["persistence"]["rag_data_dir"] is None else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"]) +KG_DIR_STR = str(KG_DIR).replace("\\", "/") class KGManager: def __init__(self): # 会被保存的字段 @@ -43,7 +44,7 @@ class KGManager: self.graph = di_graph.DiGraph() # 持久化相关 - self.dir_path = global_config["persistence"]["rag_data_dir"] + self.dir_path = KG_DIR_STR self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml" self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"