From b8d14add91a957999332490b48f88cf40edd8981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sun, 4 May 2025 00:32:10 +0800 Subject: [PATCH 01/11] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=B5=8C?= =?UTF-8?q?=E5=85=A5=E6=A8=A1=E5=9E=8B=E4=B8=80=E8=87=B4=E6=80=A7=E6=A0=A1?= =?UTF-8?q?=E9=AA=8C=E5=8A=9F=E8=83=BD=EF=BC=8C=E4=BC=98=E5=8C=96=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 7 ++ src/plugins/knowledge/src/embedding_store.py | 88 +++++++++++++++++++- 2 files changed, 94 insertions(+), 1 deletion(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 2a6e09b73..b7fc9f307 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -30,6 +30,7 @@ OPENIE_DIR = ( logger = get_module_logger("OpenIE导入") + def hash_deduplicate( raw_paragraphs: dict[str, str], triple_list_data: dict[str, list[list[str]]], @@ -167,6 +168,7 @@ def main(): global_config["llm_providers"][key]["api_key"], ) + # 初始化Embedding库 embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) logger.info("正在从文件加载Embedding库") @@ -174,6 +176,11 @@ def main(): embed_manager.load_from_file() except Exception as e: logger.error("从文件加载Embedding库时发生错误:{}".format(e)) + if "嵌入模型与本地存储不一致" in str(e): + logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") + logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型") + # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。") + sys.exit(1) logger.error("如果你是第一次导入知识,请忽略此错误") logger.info("Embedding库加载完成") # 初始化KG diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index e734f4e9a..c68886fd0 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import json import os +import math from typing import Dict, List, Tuple import numpy as np @@ -25,9 +26,39 @@ from rich.progress import ( ) install(extra_lines=3) - +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 +# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录 +# 这些字符串的嵌入结果应该是固定的,不能随时间变化 +EMBEDDING_TEST_STRINGS = [ + "阿卡伊真的太好玩了,神秘性感大女同等着你", + "你怎么知道我arc12.64了", + "我是蕾缪乐小姐的狗", + "关注Oct谢谢喵", + "不是w6我不草", + "关注千石可乐谢谢喵", + "来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵", + "关注墨梓柒谢谢喵", + "Ciallo~", + "来玩巧克甜恋谢谢喵", + "水印", + "我也在纠结晚饭,铁锅炒鸡听着就香!", + "test你妈喵" +] +EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json") +EMBEDDING_SIM_THRESHOLD = 0.99 + + +def cosine_similarity(a, b): + # 计算余弦相似度 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = math.sqrt(sum(x * x for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) + @dataclass class EmbeddingStoreItem: @@ -64,6 +95,46 @@ class EmbeddingStore: def _get_embedding(self, s: str) -> List[float]: return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) + def get_test_file_path(self): + return EMBEDDING_TEST_FILE + + def save_embedding_test_vectors(self): + """保存测试字符串的嵌入到本地""" + test_vectors = {} + for idx, s in enumerate(EMBEDDING_TEST_STRINGS): + test_vectors[str(idx)] = self._get_embedding(s) + with open(self.get_test_file_path(), "w", encoding="utf-8") as f: + json.dump(test_vectors, f, ensure_ascii=False, indent=2) + + def load_embedding_test_vectors(self): + """加载本地保存的测试字符串嵌入""" + path = self.get_test_file_path() + if not os.path.exists(path): + return None + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + def check_embedding_model_consistency(self): + """校验当前模型与本地嵌入模型是否一致""" + local_vectors = self.load_embedding_test_vectors() + if local_vectors is None: + logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") + self.save_embedding_test_vectors() + return True + for idx, s in enumerate(EMBEDDING_TEST_STRINGS): + local_emb = local_vectors.get(str(idx)) + if local_emb is None: + logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") + self.save_embedding_test_vectors() + return True + new_emb = self._get_embedding(s) + sim = cosine_similarity(local_emb, new_emb) + if sim < EMBEDDING_SIM_THRESHOLD: + logger.error("嵌入模型一致性校验失败") + return False + logger.info("嵌入模型一致性校验通过。") + return True + def batch_insert_strs(self, strs: List[str], times: int) -> None: """向库中存入字符串""" total = len(strs) @@ -216,6 +287,17 @@ class EmbeddingManager: ) self.stored_pg_hashes = set() + def check_all_embedding_model_consistency(self): + """对所有嵌入库做模型一致性校验""" + for store in [ + self.paragraphs_embedding_store, + self.entities_embedding_store, + self.relation_embedding_store, + ]: + if not store.check_embedding_model_consistency(): + return False + return True + def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): """将段落编码存入Embedding库""" self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) @@ -239,6 +321,8 @@ class EmbeddingManager: def load_from_file(self): """从文件加载""" + if not self.check_all_embedding_model_consistency(): + raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") self.paragraphs_embedding_store.load_from_file() self.entities_embedding_store.load_from_file() self.relation_embedding_store.load_from_file() @@ -250,6 +334,8 @@ class EmbeddingManager: raw_paragraphs: Dict[str, str], triple_list_data: Dict[str, List[List[str]]], ): + if not self.check_all_embedding_model_consistency(): + raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。") """存储新的数据集""" self._store_pg_into_embedding(raw_paragraphs) self._store_ent_into_embedding(triple_list_data) From cea176d63dae115e7ce21d3080b060fd668c121d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 3 May 2025 16:32:25 +0000 Subject: [PATCH 02/11] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/import_openie.py | 2 -- src/plugins/knowledge/src/embedding_store.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/scripts/import_openie.py b/scripts/import_openie.py index b7fc9f307..851cc8b31 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -30,7 +30,6 @@ OPENIE_DIR = ( logger = get_module_logger("OpenIE导入") - def hash_deduplicate( raw_paragraphs: dict[str, str], triple_list_data: dict[str, list[list[str]]], @@ -168,7 +167,6 @@ def main(): global_config["llm_providers"][key]["api_key"], ) - # 初始化Embedding库 embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) logger.info("正在从文件加载Embedding库") diff --git a/src/plugins/knowledge/src/embedding_store.py b/src/plugins/knowledge/src/embedding_store.py index c68886fd0..5ee92a869 100644 --- a/src/plugins/knowledge/src/embedding_store.py +++ b/src/plugins/knowledge/src/embedding_store.py @@ -44,7 +44,7 @@ EMBEDDING_TEST_STRINGS = [ "来玩巧克甜恋谢谢喵", "水印", "我也在纠结晚饭,铁锅炒鸡听着就香!", - "test你妈喵" + "test你妈喵", ] EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json") EMBEDDING_SIM_THRESHOLD = 0.99 From 668c9bbad6fc5da798179513a7a17e2cb00db2bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sun, 4 May 2025 01:41:49 +0800 Subject: [PATCH 03/11] =?UTF-8?q?feat:=20=E9=87=8D=E6=9E=84API=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8=EF=BC=8C=E6=B7=BB=E5=8A=A0=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E9=87=8D=E8=BD=BD=E5=8A=9F=E8=83=BD=E5=B9=B6=E7=A7=BB=E9=99=A4?= =?UTF-8?q?=E5=86=97=E4=BD=99=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/__init__.py | 8 -------- src/api/main.py | 28 +++++++++++++++++++++++++++ src/api/reload_config.py | 18 +++++++++++++++++ src/main.py | 5 +++++ src/plugins/config_reload/__init__.py | 1 - src/plugins/config_reload/api.py | 19 ------------------ src/plugins/config_reload/test.py | 4 ---- 7 files changed, 51 insertions(+), 32 deletions(-) create mode 100644 src/api/main.py create mode 100644 src/api/reload_config.py delete mode 100644 src/plugins/config_reload/__init__.py delete mode 100644 src/plugins/config_reload/api.py delete mode 100644 src/plugins/config_reload/test.py diff --git a/src/api/__init__.py b/src/api/__init__.py index f5bc08a6e..e69de29bb 100644 --- a/src/api/__init__.py +++ b/src/api/__init__.py @@ -1,8 +0,0 @@ -from fastapi import FastAPI -from strawberry.fastapi import GraphQLRouter - -app = FastAPI() - -graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema - -app.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) diff --git a/src/api/main.py b/src/api/main.py new file mode 100644 index 000000000..e8d2054f1 --- /dev/null +++ b/src/api/main.py @@ -0,0 +1,28 @@ +from fastapi import APIRouter +from strawberry.fastapi import GraphQLRouter +# from src.config.config import BotConfig +from src.common.logger_manager import get_logger +from src.api.reload_config import reload_config as reload_config_func +from src.common.server import global_server +# import uvicorn +# import os + +router = APIRouter() + + +logger = get_logger("api") + +# maiapi = FastAPI() +logger.info("API server started.") +graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema + +router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) + +@router.post("/config/reload") +async def reload_config(): + return await reload_config_func() + +def start_api_server(): + """启动API服务器""" + global_server.register_router(router, prefix="/api/v1") + diff --git a/src/api/reload_config.py b/src/api/reload_config.py new file mode 100644 index 000000000..33a02e731 --- /dev/null +++ b/src/api/reload_config.py @@ -0,0 +1,18 @@ +from fastapi import HTTPException +from rich.traceback import install +from src.config.config import BotConfig +import os +install(extra_lines=3) + + + +async def reload_config(): + try: + from src.config import config as config_module + bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") + config_module.global_config = BotConfig.load_config(config_path=bot_config_path) + return {"status": "reloaded"} + except FileNotFoundError as e: + raise HTTPException(status_code=404, detail=str(e)) from e + except Exception as e: + raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e \ No newline at end of file diff --git a/src/main.py b/src/main.py index 26a56ca2c..343028429 100644 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,6 @@ import asyncio import time +import os from .plugins.utils.statistic import LLMStatistics from .plugins.moods.moods import MoodManager from .plugins.schedule.schedule_generator import bot_schedule @@ -18,6 +19,7 @@ from .plugins.remote import heartbeat_thread # noqa: F401 from .individuality.individuality import Individuality from .common.server import global_server from rich.traceback import install +from .api.main import start_api_server install(extra_lines=3) @@ -54,6 +56,9 @@ class MainSystem: self.llm_stats.start() logger.success("LLM统计功能启动成功") + # 启动API服务器 + start_api_server() + logger.success("API服务器启动成功") # 初始化表情管理器 emoji_manager.initialize() logger.success("表情包管理器初始化成功") diff --git a/src/plugins/config_reload/__init__.py b/src/plugins/config_reload/__init__.py deleted file mode 100644 index 8b1378917..000000000 --- a/src/plugins/config_reload/__init__.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/plugins/config_reload/api.py b/src/plugins/config_reload/api.py deleted file mode 100644 index 56240b88e..000000000 --- a/src/plugins/config_reload/api.py +++ /dev/null @@ -1,19 +0,0 @@ -from fastapi import APIRouter, HTTPException -from rich.traceback import install - -install(extra_lines=3) - -# 创建APIRouter而不是FastAPI实例 -router = APIRouter() - - -@router.post("/reload-config") -async def reload_config(): - try: # TODO: 实现配置重载 - # bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") - # BotConfig.reload_config(config_path=bot_config_path) - return {"message": "TODO: 实现配置重载", "status": "unimplemented"} - except FileNotFoundError as e: - raise HTTPException(status_code=404, detail=str(e)) from e - except Exception as e: - raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e diff --git a/src/plugins/config_reload/test.py b/src/plugins/config_reload/test.py deleted file mode 100644 index fc4fc1e8c..000000000 --- a/src/plugins/config_reload/test.py +++ /dev/null @@ -1,4 +0,0 @@ -import requests - -response = requests.post("http://localhost:8080/api/reload-config") -print(response.json()) From 80ff6e81549413b3fe366fb4ba06bb51552831bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sun, 4 May 2025 01:43:44 +0800 Subject: [PATCH 04/11] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0API=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8=E6=97=A5=E5=BF=97=E6=A0=B7=E5=BC=8F=E9=85=8D?= =?UTF-8?q?=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/common/logger.py | 20 +++++++++++++++++++- src/common/logger_manager.py | 2 ++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/src/common/logger.py b/src/common/logger.py index a82c6d883..432b1bdca 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -808,6 +808,22 @@ INIT_STYLE_CONFIG = { }, } +API_SERVER_STYLE_CONFIG = { + "advanced": { + "console_format": ( + "{time:YYYY-MM-DD HH:mm:ss} | " + "{level: <8} | " + "API服务 | " + "{message}" + ), + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}", + }, + "simple": { + "console_format": "{time:MM-DD HH:mm} | API服务 | {message}", + "file_format": "{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {extra[module]: <15} | API服务 | {message}", + }, +} + # 根据SIMPLE_OUTPUT选择配置 MAIN_STYLE_CONFIG = MAIN_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else MAIN_STYLE_CONFIG["advanced"] @@ -878,7 +894,9 @@ CHAT_MESSAGE_STYLE_CONFIG = ( ) CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"] INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"] - +API_SERVER_STYLE_CONFIG = ( + API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"] +) def is_registered_module(record: dict) -> bool: """检查是否为已注册的模块""" diff --git a/src/common/logger_manager.py b/src/common/logger_manager.py index 5c5538385..4c28f82f8 100644 --- a/src/common/logger_manager.py +++ b/src/common/logger_manager.py @@ -41,6 +41,7 @@ from src.common.logger import ( CHAT_MESSAGE_STYLE_CONFIG, CHAT_IMAGE_STYLE_CONFIG, INIT_STYLE_CONFIG, + API_SERVER_STYLE_CONFIG, ) # 可根据实际需要补充更多模块配置 @@ -86,6 +87,7 @@ MODULE_LOGGER_CONFIGS = { "chat_message": CHAT_MESSAGE_STYLE_CONFIG, # 聊天消息 "chat_image": CHAT_IMAGE_STYLE_CONFIG, # 聊天图片 "init": INIT_STYLE_CONFIG, # 初始化 + "api": API_SERVER_STYLE_CONFIG, # API服务器 # ...如有更多模块,继续添加... } From 96f33ee086e3352fbe2b33af99919cd67dfb5e67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sun, 4 May 2025 01:45:00 +0800 Subject: [PATCH 05/11] fix: Ruff --- src/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/main.py b/src/main.py index 343028429..be71524e2 100644 --- a/src/main.py +++ b/src/main.py @@ -1,6 +1,5 @@ import asyncio import time -import os from .plugins.utils.statistic import LLMStatistics from .plugins.moods.moods import MoodManager from .plugins.schedule.schedule_generator import bot_schedule From aa86387f36ca9fbf8dbc8eddf2f9218bffb9db8c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sat, 3 May 2025 17:45:14 +0000 Subject: [PATCH 06/11] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/main.py | 6 ++++-- src/api/reload_config.py | 5 +++-- src/common/logger.py | 5 ++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/api/main.py b/src/api/main.py index e8d2054f1..be2259404 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -1,5 +1,6 @@ from fastapi import APIRouter from strawberry.fastapi import GraphQLRouter + # from src.config.config import BotConfig from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func @@ -18,11 +19,12 @@ graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with you router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) + @router.post("/config/reload") async def reload_config(): return await reload_config_func() + def start_api_server(): """启动API服务器""" - global_server.register_router(router, prefix="/api/v1") - + global_server.register_router(router, prefix="/api/v1") diff --git a/src/api/reload_config.py b/src/api/reload_config.py index 33a02e731..d77cb536b 100644 --- a/src/api/reload_config.py +++ b/src/api/reload_config.py @@ -2,17 +2,18 @@ from fastapi import HTTPException from rich.traceback import install from src.config.config import BotConfig import os -install(extra_lines=3) +install(extra_lines=3) async def reload_config(): try: from src.config import config as config_module + bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") config_module.global_config = BotConfig.load_config(config_path=bot_config_path) return {"status": "reloaded"} except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e except Exception as e: - raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e \ No newline at end of file + raise HTTPException(status_code=500, detail=f"重载配置时发生错误: {str(e)}") from e diff --git a/src/common/logger.py b/src/common/logger.py index 432b1bdca..88fc427f2 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -894,9 +894,8 @@ CHAT_MESSAGE_STYLE_CONFIG = ( ) CHAT_IMAGE_STYLE_CONFIG = CHAT_IMAGE_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else CHAT_IMAGE_STYLE_CONFIG["advanced"] INIT_STYLE_CONFIG = INIT_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else INIT_STYLE_CONFIG["advanced"] -API_SERVER_STYLE_CONFIG = ( - API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"] -) +API_SERVER_STYLE_CONFIG = API_SERVER_STYLE_CONFIG["simple"] if SIMPLE_OUTPUT else API_SERVER_STYLE_CONFIG["advanced"] + def is_registered_module(record: dict) -> bool: """检查是否为已注册的模块""" From 88a2b9d2ee6072de8ac3952558cde5fb244f38fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sun, 4 May 2025 13:43:30 +0800 Subject: [PATCH 07/11] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0API=E6=9C=8D?= =?UTF-8?q?=E5=8A=A1=E5=99=A8=E9=85=8D=E7=BD=AE=E5=92=8CGraphQL=E8=B7=AF?= =?UTF-8?q?=E7=94=B1=EF=BC=8C=E9=87=8D=E8=BD=BD=E9=85=8D=E7=BD=AE=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=EF=BC=8C=E6=9B=B4=E6=96=B0=E6=97=A5=E5=BF=97=E4=BF=A1?= =?UTF-8?q?=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/config_api.py | 230 +++++++++++--------- src/api/{graphql => maigraphql}/__init__.py | 0 src/api/{graphql => maigraphql}/schema.py | 0 src/api/main.py | 2 +- src/api/reload_config.py | 5 +- 5 files changed, 136 insertions(+), 101 deletions(-) rename src/api/{graphql => maigraphql}/__init__.py (100%) rename src/api/{graphql => maigraphql}/schema.py (100%) diff --git a/src/api/config_api.py b/src/api/config_api.py index 6ecd4e6db..7279ac63a 100644 --- a/src/api/config_api.py +++ b/src/api/config_api.py @@ -1,155 +1,187 @@ -from typing import List, Optional +from typing import List, Optional, Dict, Any import strawberry - -# from packaging.version import Version, InvalidVersion -# from packaging.specifiers import SpecifierSet, InvalidSpecifier -# from ..config.config import global_config -# import os from packaging.version import Version +import os + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) @strawberry.type -class BotConfig: +class APIBotConfig: """机器人配置类""" - INNER_VERSION: Version + INNER_VERSION: Version # 配置文件内部版本号 MAI_VERSION: str # 硬编码的版本信息 # bot - BOT_QQ: Optional[int] - BOT_NICKNAME: Optional[str] - BOT_ALIAS_NAMES: List[str] # 别名,可以通过这个叫它 + BOT_QQ: Optional[int] # 机器人QQ号 + BOT_NICKNAME: Optional[str] # 机器人昵称 + BOT_ALIAS_NAMES: List[str] # 机器人别名列表 # group - talk_allowed_groups: set - talk_frequency_down_groups: set - ban_user_id: set + talk_allowed_groups: List[int] # 允许回复消息的群号列表 + talk_frequency_down_groups: List[int] # 降低回复频率的群号列表 + ban_user_id: List[int] # 禁止回复和读取消息的QQ号列表 # personality - personality_core: str # 建议20字以内,谁再写3000字小作文敲谁脑袋 - personality_sides: List[str] + personality_core: str # 人格核心特点描述 + personality_sides: List[str] # 人格细节描述列表 + # identity - identity_detail: List[str] - height: int # 身高 单位厘米 - weight: int # 体重 单位千克 - age: int # 年龄 单位岁 + identity_detail: List[str] # 身份特点列表 + age: int # 年龄(岁) gender: str # 性别 - appearance: str # 外貌特征 + appearance: str # 外貌特征描述 # schedule ENABLE_SCHEDULE_GEN: bool # 是否启用日程生成 - PROMPT_SCHEDULE_GEN: str - SCHEDULE_DOING_UPDATE_INTERVAL: int # 日程表更新间隔 单位秒 - SCHEDULE_TEMPERATURE: float # 日程表温度,建议0.5-1.0 + ENABLE_SCHEDULE_INTERACTION: bool # 是否启用日程交互 + PROMPT_SCHEDULE_GEN: str # 日程生成提示词 + SCHEDULE_DOING_UPDATE_INTERVAL: int # 日程进行中更新间隔 + SCHEDULE_TEMPERATURE: float # 日程生成温度 TIME_ZONE: str # 时区 - # message - MAX_CONTEXT_SIZE: int # 上下文最大消息数 - emoji_chance: float # 发送表情包的基础概率 - thinking_timeout: int # 思考时间 - model_max_output_length: int # 最大回复长度 - message_buffer: bool # 消息缓冲器 + # platforms + platforms: Dict[str, str] # 平台信息 - ban_words: set - ban_msgs_regex: set - # heartflow - # enable_heartflow: bool = False # 是否启用心流 - sub_heart_flow_update_interval: int # 子心流更新频率,间隔 单位秒 - sub_heart_flow_freeze_time: int # 子心流冻结时间,超过这个时间没有回复,子心流会冻结,间隔 单位秒 - sub_heart_flow_stop_time: int # 子心流停止时间,超过这个时间没有回复,子心流会停止,间隔 单位秒 - heart_flow_update_interval: int # 心流更新频率,间隔 单位秒 - observation_context_size: int # 心流观察到的最长上下文大小,超过这个值的上下文会被压缩 - compressed_length: int # 不能大于observation_context_size,心流上下文压缩的最短压缩长度,超过心流观察到的上下文长度,会压缩,最短压缩长度为5 - compress_length_limit: int # 最多压缩份数,超过该数值的压缩上下文会被删除 + # chat + allow_focus_mode: bool # 是否允许专注模式 + base_normal_chat_num: int # 基础普通聊天次数 + base_focused_chat_num: int # 基础专注聊天次数 + observation_context_size: int # 观察上下文大小 + message_buffer: bool # 是否启用消息缓冲 + ban_words: List[str] # 禁止词列表 + ban_msgs_regex: List[str] # 禁止消息的正则表达式列表 - # willing + # normal_chat + MODEL_R1_PROBABILITY: float # 模型推理概率 + MODEL_V3_PROBABILITY: float # 模型普通概率 + emoji_chance: float # 表情符号出现概率 + thinking_timeout: int # 思考超时时间 willing_mode: str # 意愿模式 - response_willing_amplifier: float # 回复意愿放大系数 - response_interested_rate_amplifier: float # 回复兴趣度放大系数 - down_frequency_rate: float # 降低回复频率的群组回复意愿降低系数 - emoji_response_penalty: float # 表情包回复惩罚 - mentioned_bot_inevitable_reply: bool # 提及 bot 必然回复 - at_bot_inevitable_reply: bool # @bot 必然回复 + response_willing_amplifier: float # 回复意愿放大器 + response_interested_rate_amplifier: float # 回复兴趣率放大器 + down_frequency_rate: float # 降低频率率 + emoji_response_penalty: float # 表情回复惩罚 + mentioned_bot_inevitable_reply: bool # 提到机器人时是否必定回复 + at_bot_inevitable_reply: bool # @机器人时是否必定回复 - # response - response_mode: str # 回复策略 - MODEL_R1_PROBABILITY: float # R1模型概率 - MODEL_V3_PROBABILITY: float # V3模型概率 - # MODEL_R1_DISTILL_PROBABILITY: float # R1蒸馏模型概率 + # focus_chat + reply_trigger_threshold: float # 回复触发阈值 + default_decay_rate_per_second: float # 默认每秒衰减率 + consecutive_no_reply_threshold: int # 连续不回复阈值 + + # compressed + compressed_length: int # 压缩长度 + compress_length_limit: int # 压缩长度限制 # emoji - max_emoji_num: int # 表情包最大数量 - max_reach_deletion: bool # 开启则在达到最大数量时删除表情包,关闭则不会继续收集表情包 - EMOJI_CHECK_INTERVAL: int # 表情包检查间隔(分钟) - EMOJI_REGISTER_INTERVAL: int # 表情包注册间隔(分钟) - EMOJI_SAVE: bool # 偷表情包 - EMOJI_CHECK: bool # 是否开启过滤 - EMOJI_CHECK_PROMPT: str # 表情包过滤要求 + max_emoji_num: int # 最大表情符号数量 + max_reach_deletion: bool # 达到最大数量时是否删除 + EMOJI_CHECK_INTERVAL: int # 表情检查间隔 + EMOJI_REGISTER_INTERVAL: Optional[int] # 表情注册间隔(兼容性保留) + EMOJI_SAVE: bool # 是否保存表情 + EMOJI_CHECK: bool # 是否检查表情 + EMOJI_CHECK_PROMPT: str # 表情检查提示词 # memory - build_memory_interval: int # 记忆构建间隔(秒) - memory_build_distribution: list # 记忆构建分布,参数:分布1均值,标准差,权重,分布2均值,标准差,权重 - build_memory_sample_num: int # 记忆构建采样数量 - build_memory_sample_length: int # 记忆构建采样长度 + build_memory_interval: int # 构建记忆间隔 + memory_build_distribution: List[float] # 记忆构建分布 + build_memory_sample_num: int # 构建记忆样本数量 + build_memory_sample_length: int # 构建记忆样本长度 memory_compress_rate: float # 记忆压缩率 - - forget_memory_interval: int # 记忆遗忘间隔(秒) - memory_forget_time: int # 记忆遗忘时间(小时) - memory_forget_percentage: float # 记忆遗忘比例 - - memory_ban_words: list # 添加新的配置项默认值 + forget_memory_interval: int # 忘记记忆间隔 + memory_forget_time: int # 记忆忘记时间 + memory_forget_percentage: float # 记忆忘记百分比 + consolidate_memory_interval: int # 巩固记忆间隔 + consolidation_similarity_threshold: float # 巩固相似度阈值 + consolidation_check_percentage: float # 巩固检查百分比 + memory_ban_words: List[str] # 记忆禁止词列表 # mood - mood_update_interval: float # 情绪更新间隔 单位秒 + mood_update_interval: float # 情绪更新间隔 mood_decay_rate: float # 情绪衰减率 mood_intensity_factor: float # 情绪强度因子 - # keywords - keywords_reaction_rules: list # 关键词回复规则 + # keywords_reaction + keywords_reaction_enable: bool # 是否启用关键词反应 + keywords_reaction_rules: List[Dict[str, Any]] # 关键词反应规则 # chinese_typo - chinese_typo_enable: bool # 是否启用中文错别字生成器 - chinese_typo_error_rate: float # 单字替换概率 - chinese_typo_min_freq: int # 最小字频阈值 - chinese_typo_tone_error_rate: float # 声调错误概率 - chinese_typo_word_replace_rate: float # 整词替换概率 + chinese_typo_enable: bool # 是否启用中文错别字 + chinese_typo_error_rate: float # 中文错别字错误率 + chinese_typo_min_freq: int # 中文错别字最小频率 + chinese_typo_tone_error_rate: float # 中文错别字声调错误率 + chinese_typo_word_replace_rate: float # 中文错别字单词替换率 # response_splitter enable_response_splitter: bool # 是否启用回复分割器 - response_max_length: int # 回复允许的最大长度 - response_max_sentence_num: int # 回复允许的最大句子数 + response_max_length: int # 回复最大长度 + response_max_sentence_num: int # 回复最大句子数 + enable_kaomoji_protection: bool # 是否启用颜文字保护 + + model_max_output_length: int # 模型最大输出长度 # remote - remote_enable: bool # 是否启用远程控制 + remote_enable: bool # 是否启用远程功能 # experimental enable_friend_chat: bool # 是否启用好友聊天 - # enable_think_flow: bool # 是否启用思考流程 + talk_allowed_private: List[int] # 允许私聊的QQ号列表 enable_pfc_chatting: bool # 是否启用PFC聊天 # 模型配置 - llm_reasoning: dict[str, str] # LLM推理 - # llm_reasoning_minor: dict[str, str] - llm_normal: dict[str, str] # LLM普通 - llm_topic_judge: dict[str, str] # LLM话题判断 - llm_summary: dict[str, str] # LLM话题总结 - llm_emotion_judge: dict[str, str] # LLM情感判断 - embedding: dict[str, str] # 嵌入 - vlm: dict[str, str] # VLM - moderation: dict[str, str] # 审核 + llm_reasoning: Dict[str, Any] # 推理模型配置 + llm_normal: Dict[str, Any] # 普通模型配置 + llm_topic_judge: Dict[str, Any] # 主题判断模型配置 + llm_summary: Dict[str, Any] # 总结模型配置 + llm_emotion_judge: Optional[Dict[str, Any]] # 情绪判断模型配置(兼容性保留) + embedding: Dict[str, Any] # 嵌入模型配置 + vlm: Dict[str, Any] # VLM模型配置 + moderation: Optional[Dict[str, Any]] # 审核模型配置(兼容性保留) + llm_observation: Dict[str, Any] # 观察模型配置 + llm_sub_heartflow: Dict[str, Any] # 子心流模型配置 + llm_heartflow: Dict[str, Any] # 心流模型配置 + llm_plan: Optional[Dict[str, Any]] # 计划模型配置 + llm_PFC_action_planner: Optional[Dict[str, Any]] # PFC行动计划模型配置 + llm_PFC_chat: Optional[Dict[str, Any]] # PFC聊天模型配置 + llm_PFC_reply_checker: Optional[Dict[str, Any]] # PFC回复检查模型配置 + llm_tool_use: Optional[Dict[str, Any]] # 工具使用模型配置 - # 实验性 - llm_observation: dict[str, str] # LLM观察 - llm_sub_heartflow: dict[str, str] # LLM子心流 - llm_heartflow: dict[str, str] # LLM心流 - - api_urls: dict[str, str] # API URLs + api_urls: Optional[Dict[str, str]] # API地址配置 @strawberry.type -class EnvConfig: - pass +class APIEnvConfig: + """环境变量配置""" + + HOST: str # 服务主机地址 + PORT: int # 服务端口 + + PLUGINS: List[str] # 插件列表 + + MONGODB_HOST: str # MongoDB 主机地址 + MONGODB_PORT: int # MongoDB 端口 + DATABASE_NAME: str # 数据库名称 + + CHAT_ANY_WHERE_BASE_URL: str # ChatAnywhere 基础URL + SILICONFLOW_BASE_URL: str # SiliconFlow 基础URL + DEEP_SEEK_BASE_URL: str # DeepSeek 基础URL + + DEEP_SEEK_KEY: Optional[str] # DeepSeek API Key + CHAT_ANY_WHERE_KEY: Optional[str] # ChatAnywhere API Key + SILICONFLOW_KEY: Optional[str] # SiliconFlow API Key + + SIMPLE_OUTPUT: Optional[bool] # 是否简化输出 + CONSOLE_LOG_LEVEL: Optional[str] # 控制台日志等级 + FILE_LOG_LEVEL: Optional[str] # 文件日志等级 + DEFAULT_CONSOLE_LOG_LEVEL: Optional[str] # 默认控制台日志等级 + DEFAULT_FILE_LOG_LEVEL: Optional[str] # 默认文件日志等级 @strawberry.field def get_env(self) -> str: return "env" + + +print("当前路径:") +print(ROOT_PATH) \ No newline at end of file diff --git a/src/api/graphql/__init__.py b/src/api/maigraphql/__init__.py similarity index 100% rename from src/api/graphql/__init__.py rename to src/api/maigraphql/__init__.py diff --git a/src/api/graphql/schema.py b/src/api/maigraphql/schema.py similarity index 100% rename from src/api/graphql/schema.py rename to src/api/maigraphql/schema.py diff --git a/src/api/main.py b/src/api/main.py index be2259404..d4d3c62e7 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -14,7 +14,7 @@ router = APIRouter() logger = get_logger("api") # maiapi = FastAPI() -logger.info("API server started.") +logger.info("麦麦API服务器已启动") graphql_router = GraphQLRouter(schema=None, path="/") # Replace `None` with your actual schema router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) diff --git a/src/api/reload_config.py b/src/api/reload_config.py index d77cb536b..150194d08 100644 --- a/src/api/reload_config.py +++ b/src/api/reload_config.py @@ -1,17 +1,20 @@ from fastapi import HTTPException from rich.traceback import install from src.config.config import BotConfig +from src.common.logger_manager import get_logger import os install(extra_lines=3) +logger = get_logger("api") async def reload_config(): try: from src.config import config as config_module - + logger.debug("正在重载配置文件...") bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") config_module.global_config = BotConfig.load_config(config_path=bot_config_path) + logger.debug("配置文件重载成功") return {"status": "reloaded"} except FileNotFoundError as e: raise HTTPException(status_code=404, detail=str(e)) from e From 27212c5d43a98405c518601447d9b9bbc43a993e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sun, 4 May 2025 05:43:44 +0000 Subject: [PATCH 08/11] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/config_api.py | 2 +- src/api/reload_config.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/api/config_api.py b/src/api/config_api.py index 7279ac63a..581c05a01 100644 --- a/src/api/config_api.py +++ b/src/api/config_api.py @@ -184,4 +184,4 @@ class APIEnvConfig: print("当前路径:") -print(ROOT_PATH) \ No newline at end of file +print(ROOT_PATH) diff --git a/src/api/reload_config.py b/src/api/reload_config.py index 150194d08..a5f36e3db 100644 --- a/src/api/reload_config.py +++ b/src/api/reload_config.py @@ -8,9 +8,11 @@ install(extra_lines=3) logger = get_logger("api") + async def reload_config(): try: from src.config import config as config_module + logger.debug("正在重载配置文件...") bot_config_path = os.path.join(BotConfig.get_config_dir(), "bot_config.toml") config_module.global_config = BotConfig.load_config(config_path=bot_config_path) From 21159175807bbc87a705e7da64e34b94fef65a9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Mon, 5 May 2025 01:26:34 +0800 Subject: [PATCH 09/11] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E6=89=80=E6=9C=89=E5=AD=90=E5=BF=83=E6=B5=81ID?= =?UTF-8?q?=E5=92=8C=E5=BC=BA=E5=88=B6=E6=94=B9=E5=8F=98=E5=AD=90=E5=BF=83?= =?UTF-8?q?=E6=B5=81=E7=8A=B6=E6=80=81=E7=9A=84API=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/apiforgui.py | 15 +++++++++++++++ src/api/main.py | 25 +++++++++++++++++++++++++ src/heart_flow/heartflow.py | 9 ++++++++- src/heart_flow/subheartflow_manager.py | 16 +++++++++++++--- 4 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 src/api/apiforgui.py diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py new file mode 100644 index 000000000..04fe37bb9 --- /dev/null +++ b/src/api/apiforgui.py @@ -0,0 +1,15 @@ +from src.heart_flow.heartflow import heartflow +from src.heart_flow.sub_heartflow import ChatState + +async def get_all_subheartflow_ids() -> list: + """获取所有子心流的ID列表""" + all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows() + return [subheartflow.subheartflow_id for subheartflow in all_subheartflows] + +async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool: + """强制改变子心流的状态""" + subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id) + if subheartflow: + return await heartflow.force_change_subheartflow_status(subheartflow_id, status) + return False + diff --git a/src/api/main.py b/src/api/main.py index d4d3c62e7..6c2009972 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -5,9 +5,12 @@ from strawberry.fastapi import GraphQLRouter from src.common.logger_manager import get_logger from src.api.reload_config import reload_config as reload_config_func from src.common.server import global_server +from .apiforgui import get_all_subheartflow_ids, forced_change_subheartflow_status +from src.heart_flow.sub_heartflow import ChatState # import uvicorn # import os + router = APIRouter() @@ -24,6 +27,28 @@ router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) async def reload_config(): return await reload_config_func() +@router.get("/gui/subheartflow/get/all") +async def get_subheartflow_ids(): + """获取所有子心流的ID列表""" + return await get_all_subheartflow_ids() + +@router.post("/gui/subheartflow/forced_change_status") +async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): #noqa + """强制改变子心流的状态""" + # 参数检查 + if not isinstance(status, ChatState): + logger.warning(f"无效的状态参数: {status}") + return {"status": "failed", "reason": "invalid status"} + logger.info(f"尝试将子心流 {subheartflow_id} 状态更改为 {status.value}") + success = await forced_change_subheartflow_status(subheartflow_id, status) + if success: + logger.info(f"子心流 {subheartflow_id} 状态更改为 {status.value} 成功") + return {"status": "success"} + else: + logger.error(f"子心流 {subheartflow_id} 状态更改为 {status.value} 失败") + return {"status": "failed"} + + def start_api_server(): """启动API服务器""" diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index bd8bc6ff4..5d9400880 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -1,4 +1,4 @@ -from src.heart_flow.sub_heartflow import SubHeartflow +from src.heart_flow.sub_heartflow import SubHeartflow, ChatState from src.plugins.models.utils_model import LLMRequest from src.config.config import global_config from src.plugins.schedule.schedule_generator import bot_schedule @@ -62,6 +62,13 @@ class Heartflow: # 不再需要传入 self.current_state return await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id) + async def force_change_subheartflow_status( + self, subheartflow_id: str, status: ChatState + ) -> None: + """强制改变子心流的状态""" + # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 + return await self.subheartflow_manager.force_change_state(subheartflow_id, status) + async def heartflow_start_working(self): """启动后台任务""" await self.background_task_manager.start_tasks() diff --git a/src/heart_flow/subheartflow_manager.py b/src/heart_flow/subheartflow_manager.py index f06a68c87..057d6cca3 100644 --- a/src/heart_flow/subheartflow_manager.py +++ b/src/heart_flow/subheartflow_manager.py @@ -82,6 +82,17 @@ class SubHeartflowManager: max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) request_type="subheartflow_state_eval", # 保留特定的请求类型 ) + + async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool: + """强制改变指定子心流的状态""" + async with self._lock: + subflow = self.subheartflows.get(subflow_id) + if not subflow: + logger.warning(f"[强制状态转换]尝试转换不存在的子心流{subflow_id} 到 {target_state.value}") + return False + await subflow.change_chat_state(target_state) + logger.info(f"[强制状态转换]子心流 {subflow_id} 已转换到 {target_state.value}") + return True def get_all_subheartflows(self) -> List["SubHeartflow"]: """获取所有当前管理的 SubHeartflow 实例列表 (快照)。""" @@ -92,7 +103,7 @@ class SubHeartflowManager: Args: subheartflow_id: 子心流唯一标识符 - # mai_states 参数已被移除,使用 self.mai_state_info + mai_states 参数已被移除,使用 self.mai_state_info Returns: 成功返回SubHeartflow实例,失败返回None @@ -174,8 +185,7 @@ class SubHeartflowManager: continue subheartflow.update_last_chat_state_time() absent_last_time = subheartflow.chat_state_last_time - if max_age_seconds and (current_time - absent_last_time) > max_age_seconds: - flows_to_stop.append(subheartflow_id) + flows_to_stop.append(subheartflow_id) return flows_to_stop From 1da2b4ca701a2a0f3bd91050946387616d9627bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Mon, 5 May 2025 01:29:07 +0800 Subject: [PATCH 10/11] fix: Ruff --- src/heart_flow/subheartflow_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/heart_flow/subheartflow_manager.py b/src/heart_flow/subheartflow_manager.py index 057d6cca3..8dfdcd9f1 100644 --- a/src/heart_flow/subheartflow_manager.py +++ b/src/heart_flow/subheartflow_manager.py @@ -176,7 +176,7 @@ class SubHeartflowManager: def get_inactive_subheartflows(self, max_age_seconds=INACTIVE_THRESHOLD_SECONDS): """识别并返回需要清理的不活跃(处于ABSENT状态超过一小时)子心流(id, 原因)""" - current_time = time.time() + _current_time = time.time() flows_to_stop = [] for subheartflow_id, subheartflow in list(self.subheartflows.items()): @@ -184,7 +184,7 @@ class SubHeartflowManager: if state != ChatState.ABSENT: continue subheartflow.update_last_chat_state_time() - absent_last_time = subheartflow.chat_state_last_time + _absent_last_time = subheartflow.chat_state_last_time flows_to_stop.append(subheartflow_id) return flows_to_stop From 78e145bd564cd8c852bb2b62fd9f52b7ffafacbb Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Sun, 4 May 2025 17:29:22 +0000 Subject: [PATCH 11/11] =?UTF-8?q?=F0=9F=A4=96=20=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E6=A0=BC=E5=BC=8F=E5=8C=96=E4=BB=A3=E7=A0=81=20[skip=20ci]?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/api/apiforgui.py | 3 ++- src/api/main.py | 5 +++-- src/heart_flow/heartflow.py | 4 +--- src/heart_flow/subheartflow_manager.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/api/apiforgui.py b/src/api/apiforgui.py index 04fe37bb9..75ef2f8d1 100644 --- a/src/api/apiforgui.py +++ b/src/api/apiforgui.py @@ -1,15 +1,16 @@ from src.heart_flow.heartflow import heartflow from src.heart_flow.sub_heartflow import ChatState + async def get_all_subheartflow_ids() -> list: """获取所有子心流的ID列表""" all_subheartflows = heartflow.subheartflow_manager.get_all_subheartflows() return [subheartflow.subheartflow_id for subheartflow in all_subheartflows] + async def forced_change_subheartflow_status(subheartflow_id: str, status: ChatState) -> bool: """强制改变子心流的状态""" subheartflow = await heartflow.get_or_create_subheartflow(subheartflow_id) if subheartflow: return await heartflow.force_change_subheartflow_status(subheartflow_id, status) return False - diff --git a/src/api/main.py b/src/api/main.py index 6c2009972..6d7e3c1e2 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -27,13 +27,15 @@ router.include_router(graphql_router, prefix="/graphql", tags=["GraphQL"]) async def reload_config(): return await reload_config_func() + @router.get("/gui/subheartflow/get/all") async def get_subheartflow_ids(): """获取所有子心流的ID列表""" return await get_all_subheartflow_ids() + @router.post("/gui/subheartflow/forced_change_status") -async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): #noqa +async def forced_change_subheartflow_status_api(subheartflow_id: str, status: ChatState): # noqa """强制改变子心流的状态""" # 参数检查 if not isinstance(status, ChatState): @@ -49,7 +51,6 @@ async def forced_change_subheartflow_status_api(subheartflow_id: str, status: Ch return {"status": "failed"} - def start_api_server(): """启动API服务器""" global_server.register_router(router, prefix="/api/v1") diff --git a/src/heart_flow/heartflow.py b/src/heart_flow/heartflow.py index 5d9400880..894247ce4 100644 --- a/src/heart_flow/heartflow.py +++ b/src/heart_flow/heartflow.py @@ -62,9 +62,7 @@ class Heartflow: # 不再需要传入 self.current_state return await self.subheartflow_manager.get_or_create_subheartflow(subheartflow_id) - async def force_change_subheartflow_status( - self, subheartflow_id: str, status: ChatState - ) -> None: + async def force_change_subheartflow_status(self, subheartflow_id: str, status: ChatState) -> None: """强制改变子心流的状态""" # 这里的 message 是可选的,可能是一个消息对象,也可能是其他类型的数据 return await self.subheartflow_manager.force_change_state(subheartflow_id, status) diff --git a/src/heart_flow/subheartflow_manager.py b/src/heart_flow/subheartflow_manager.py index 8dfdcd9f1..b09f10844 100644 --- a/src/heart_flow/subheartflow_manager.py +++ b/src/heart_flow/subheartflow_manager.py @@ -82,7 +82,7 @@ class SubHeartflowManager: max_tokens=1000, # 与 Heartflow 一致 (虽然可能不需要这么多) request_type="subheartflow_state_eval", # 保留特定的请求类型 ) - + async def force_change_state(self, subflow_id: Any, target_state: ChatState) -> bool: """强制改变指定子心流的状态""" async with self._lock: