From cc91e1a2a5fe75619c38efe461801dc6d1212ef6 Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Thu, 4 Dec 2025 18:34:37 +0800
Subject: [PATCH 1/2] feat(exception-handling): enhance global exception
 handling, add thread and asyncio exception capture
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 bot.py               | 182 ++++++++++++++++++++++++++++++++++++++++---
 src/common/logger.py |  29 +++++--
 src/main.py          |  18 ++++-
 3 files changed, 210 insertions(+), 19 deletions(-)

diff --git a/bot.py b/bot.py
index b3c653096..7f6cc6107 100644
--- a/bot.py
+++ b/bot.py
@@ -1,12 +1,17 @@
 import asyncio
+import faulthandler
 import os
 import platform
 import sys
+import threading
 import time
 import traceback
 from contextlib import asynccontextmanager
 from pathlib import Path

+# Enable faulthandler so a traceback is printed on fatal errors such as segmentation faults
+faulthandler.enable()
+
 # Initialize basic utilities
 from colorama import Fore, init
 from dotenv import load_dotenv
@@ -20,6 +25,88 @@ initialize_logging()
 logger = get_logger("main")
 install(extra_lines=3)

+
+# ============= Global exception capture system =============
+def _global_exception_handler(exc_type, exc_value, exc_tb):
+    """Global exception handler - catches unhandled exceptions in the main thread"""
+    if issubclass(exc_type, KeyboardInterrupt):
+        # Use the default handling for Ctrl+C
+        sys.__excepthook__(exc_type, exc_value, exc_tb)
+        return
+
+    # Format the exception information
+    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
+
+    # Try to record it through the logging system
+    try:
+        logger.critical(f"Uncaught fatal exception:\n{error_msg}")
+    except Exception:
+        pass
+
+    # Also write directly to stderr so the error is visible even if the logging system fails
+    print(f"\n{'='*60}", file=sys.stderr)
+    print("Fatal error - uncaught exception:", file=sys.stderr)
+    print(f"{'='*60}", file=sys.stderr)
+    print(error_msg, file=sys.stderr)
+    print(f"{'='*60}\n", file=sys.stderr)
+    sys.stderr.flush()
+
+
+def _thread_exception_handler(args):
+    """Thread exception handler - catches unhandled exceptions in child threads"""
+    exc_type = args.exc_type
+    exc_value = args.exc_value
+    exc_tb = args.exc_traceback
+    thread = args.thread
+
+    if issubclass(exc_type, SystemExit):
+        return
+
+    thread_name = thread.name if thread else "Unknown"
+    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
+
+    # Try to record it through the logging system
+    try:
+        logger.critical(f"Uncaught exception in thread '{thread_name}':\n{error_msg}")
+    except Exception:
+        pass
+
+    # Also write directly to stderr
+    print(f"\n{'='*60}", file=sys.stderr)
+    print(f"Fatal error - uncaught exception in thread '{thread_name}':", file=sys.stderr)
+    print(f"{'='*60}", file=sys.stderr)
+    print(error_msg, file=sys.stderr)
+    print(f"{'='*60}\n", file=sys.stderr)
+    sys.stderr.flush()
+
+
+def _unraisable_exception_handler(unraisable):
+    """Unraisable exception handler - catches exceptions raised in __del__ and similar contexts"""
+    exc_type = unraisable.exc_type
+    exc_value = unraisable.exc_value
+    exc_tb = unraisable.exc_traceback
+    obj = unraisable.object
+
+    error_msg = "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
+    obj_repr = repr(obj) if obj else "N/A"
+
+    # Try to record it through the logging system
+    try:
+        logger.error(f"Unraisable exception (object: {obj_repr}):\n{error_msg}")
+    except Exception:
+        pass
+
+    # Write to stderr
+    print(f"\nWarning - unraisable exception (object: {obj_repr}):", file=sys.stderr)
+    print(error_msg, file=sys.stderr)
+    sys.stderr.flush()
+
+
+# Install the global exception handlers
+sys.excepthook = _global_exception_handler
+threading.excepthook = _thread_exception_handler
+sys.unraisablehook = _unraisable_exception_handler
+
 # Constant definitions
 SUPPORTED_DATABASES = ["sqlite", "mysql", "postgresql"]
 SHUTDOWN_TIMEOUT = 10.0
@@ -260,6 +347,31 @@ async def create_event_loop_context():
     try:
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
+
+        # Install the asyncio exception handler
+        def asyncio_exception_handler(loop, context):
+            """asyncio exception handler - catches unhandled exceptions in the event loop"""
+            exception = context.get("exception")
+            message = context.get("message", "")
+
+            if exception:
+                error_msg = "".join(traceback.format_exception(type(exception), exception, exception.__traceback__))
+                try:
+                    logger.error(f"Unhandled asyncio exception: {message}\n{error_msg}")
+                except Exception:
+                    pass
+                print(f"\nUnhandled asyncio exception: {message}", file=sys.stderr)
+                print(error_msg, file=sys.stderr)
+                sys.stderr.flush()
+            else:
+                try:
+                    logger.error(f"asyncio error: {message}")
+                except Exception:
+                    pass
+                print(f"\nasyncio error: {message}", file=sys.stderr)
+                sys.stderr.flush()
+
+        loop.set_exception_handler(asyncio_exception_handler)
         yield loop
     except Exception as e:
         logger.error(f"Failed to create event loop: {e}")
@@ -267,9 +379,13 @@ async def create_event_loop_context():
     finally:
         if loop and not loop.is_closed():
             try:
-                await ShutdownManager.graceful_shutdown(loop)
+                # Perform a graceful shutdown
+                # Note: we cannot await directly in this finally block, so use loop.run_until_complete
+                shutdown_coro = ShutdownManager.graceful_shutdown(loop)
+                loop.run_until_complete(shutdown_coro)
             except Exception as e:
                 logger.error(f"Error while closing the event loop: {e}")
+                print(f"Error while closing the event loop: {e}", file=sys.stderr)
             finally:
                 try:
                     loop.close()
@@ -627,11 +743,18 @@ async def main_async():
     exit_code = 0
     main_task = None

+    # Run basic checks before entering the event loop context
+    try:
+        ConfigManager.ensure_env_file()
+    except Exception as e:
+        logger.critical(f"Environment file check failed: {e}")
+        print(f"Fatal error - environment file check failed: {e}", file=sys.stderr)
+        traceback.print_exc(file=sys.stderr)
+        sys.stderr.flush()
+        return 1
+
     async with create_event_loop_context():
         try:
-            # Ensure the environment file exists
-            ConfigManager.ensure_env_file()
-
             # Start the WebUI dev server (continue with the next steps whether it succeeds or fails)
             webui_ok = await WebUIManager.start_dev_server(timeout=60)
             if webui_ok:
@@ -654,6 +777,14 @@ async def main_async():
             # Use wait to wait for either task to complete
             done, _pending = await asyncio.wait([main_task, user_input_done], return_when=asyncio.FIRST_COMPLETED)

+            # Check whether any completed task raised an exception
+            for task in done:
+                exc = task.exception()
+                if exc is not None:
+                    error_msg = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
+                    logger.error(f"Task exited with an exception:\n{error_msg}")
+                    exit_code = 1
+
             # If the user-input task completed (the user pressed Ctrl+C), cancel the main task
             if user_input_done in done and main_task not in done:
                 logger.info("User requested exit, cancelling the main task...")
@@ -669,9 +800,14 @@ async def main_async():
             logger.warning("Interrupt signal received, shutting down gracefully...")
             if main_task and not main_task.done():
                 main_task.cancel()
+        except asyncio.CancelledError:
+            logger.info("Main async task was cancelled")
         except Exception as e:
-            logger.error(f"Main program raised an exception: {e}")
-            logger.debug(f"Exception details: {traceback.format_exc()}")
+            error_msg = traceback.format_exc()
+            logger.critical(f"Main program raised an exception:\n{error_msg}")
+            # Also write to stderr to make sure it is visible
+            print(f"\nMain program exception:\n{error_msg}", file=sys.stderr)
+            sys.stderr.flush()
             exit_code = 1

     return exit_code
@@ -684,14 +820,40 @@ if __name__ == "__main__":
     except KeyboardInterrupt:
         logger.info("Program interrupted by the user")
         exit_code = 130
-    except Exception as e:
-        logger.error(f"Program startup failed: {e}")
+    except SystemExit as e:
+        # Preserve the exit code carried by SystemExit
+        exit_code = e.code if isinstance(e.code, int) else 1
+        if exit_code != 0:
+            logger.warning(f"Program exited via SystemExit with exit code: {exit_code}")
+    except BaseException as e:
+        # Catch everything, including Exception and other BaseException subclasses
+        error_msg = traceback.format_exc()
+        try:
+            logger.critical(f"Program startup failed - fatal exception:\n{error_msg}")
+        except Exception:
+            pass
+        # Make sure the error is written to stderr
+        print(f"\n{'='*60}", file=sys.stderr)
+        print("Fatal error - program startup failed:", file=sys.stderr)
+        print(f"{'='*60}", file=sys.stderr)
+        print(error_msg, file=sys.stderr)
+        print(f"{'='*60}\n", file=sys.stderr)
+        sys.stderr.flush()
         exit_code = 1
     finally:
-        # Make sure the logging system is shut down properly
+        # Make sure the logging system is shut down properly and give the log queue time to flush
         try:
+            # Give the log queue a moment to flush
+            time.sleep(0.2)
             shutdown_logging()
         except Exception as e:
-            print(f"Error while shutting down the logging system: {e}")
+            print(f"Error while shutting down the logging system: {e}", file=sys.stderr)
+
+        # Finally, flush stdout and stderr
+        try:
+            sys.stdout.flush()
+            sys.stderr.flush()
+        except Exception:
+            pass

     sys.exit(exit_code)
diff --git a/src/common/logger.py b/src/common/logger.py
index dd3425797..c6afb03df 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -1307,12 +1307,25 @@ def shutdown_logging():
     """Gracefully shut down the logging system and release all file handles"""
     logger = get_logger("logger")
     logger.info("Shutting down the logging system...")
+
+    # Give the log queue a moment to flush
+    import time
+    time.sleep(0.1)
+
+    # Stop the queue listener (this waits for the queue to drain)
+    _stop_queue_logging()

     # Close all handlers
     root_logger = logging.getLogger()
     for handler in root_logger.handlers[:]:
-        if hasattr(handler, "close"):
-            handler.close()
+        try:
+            # Flush the handler first
+            if hasattr(handler, "flush"):
+                handler.flush()
+            if hasattr(handler, "close"):
+                handler.close()
+        except Exception as e:
+            print(f"[logging] Error while closing handler: {e}")
         root_logger.removeHandler(handler)

     # Close global handlers
@@ -1323,8 +1336,14 @@
     for logger_obj in logger_dict.values():
         if isinstance(logger_obj, logging.Logger):
             for handler in logger_obj.handlers[:]:
-                if hasattr(handler, "close"):
-                    handler.close()
+                try:
+                    if hasattr(handler, "flush"):
+                        handler.flush()
+                    if hasattr(handler, "close"):
+                        handler.close()
+                except Exception as e:
+                    print(f"[logging] Error while closing logger handler: {e}")
                 logger_obj.removeHandler(handler)

-    logger.info("Logging system shut down")
+    # Final message (printed directly, since the logging system is already shut down)
+    print("[logging] Logging system shut down")
diff --git a/src/main.py b/src/main.py
index 3e5754647..b927cce71 100644
--- a/src/main.py
+++ b/src/main.py
@@ -678,9 +678,16 @@ async def main() -> None:
         await system.schedule_tasks()
     except KeyboardInterrupt:
         logger.info("Keyboard interrupt received")
+    except asyncio.CancelledError:
+        logger.info("Main task was cancelled")
     except Exception as e:
-        logger.error(f"Main function failed: {e}")
-        logger.error(traceback.format_exc())
+        error_msg = traceback.format_exc()
+        logger.critical(f"Main function failed:\n{error_msg}")
+        # Also write to stderr so the error is visible even if the logging system is broken
+        import sys
+        print(f"\nFatal error in main function:\n{error_msg}", file=sys.stderr)
+        sys.stderr.flush()
+        raise  # Re-raise so the caller can handle it
     finally:
         await system.shutdown()
@@ -691,6 +698,9 @@ if __name__ == "__main__":
     except KeyboardInterrupt:
         logger.info("Program interrupted by the user")
     except Exception as e:
-        logger.error(f"Program execution failed: {e}")
-        logger.error(traceback.format_exc())
+        error_msg = traceback.format_exc()
+        logger.critical(f"Program execution failed:\n{error_msg}")
+        import sys
+        print(f"\nFatal program error:\n{error_msg}", file=sys.stderr)
+        sys.stderr.flush()
         sys.exit(1)
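The first patch wires the same reporting path into every place Python lets an exception escape: the main thread, worker threads, "unraisable" contexts, and the asyncio loop. A minimal, self-contained sketch of that pattern is shown below; the helper names (_log_fatal, install_hooks) are illustrative and do not exist in the repository, and a real handler would also log through the project's logging system as the patch does.

import asyncio
import sys
import threading
import traceback


def _log_fatal(prefix, exc_type, exc_value, exc_tb):
    # Write the formatted traceback straight to stderr so it survives logging failures.
    text = "".join(traceback.format_exception(exc_type, exc_value, exc_tb))
    print(f"{prefix}\n{text}", file=sys.stderr)
    sys.stderr.flush()


def install_hooks(loop: asyncio.AbstractEventLoop) -> None:
    # Exceptions that reach the interpreter top level in the main thread.
    sys.excepthook = lambda t, v, tb: _log_fatal("Uncaught exception:", t, v, tb)
    # Exceptions escaping Thread.run() (threading.excepthook, Python 3.8+).
    threading.excepthook = lambda args: _log_fatal(
        f"Uncaught exception in thread {args.thread.name if args.thread else '?'}:",
        args.exc_type, args.exc_value, args.exc_traceback,
    )
    # "Unraisable" exceptions, e.g. raised inside __del__ (sys.unraisablehook, Python 3.8+).
    sys.unraisablehook = lambda u: _log_fatal(
        f"Unraisable exception (object: {u.object!r}):", u.exc_type, u.exc_value, u.exc_traceback
    )
    # Exceptions from tasks and callbacks that the event loop could not deliver anywhere.
    loop.set_exception_handler(
        lambda lp, ctx: print(f"asyncio error: {ctx.get('message')} {ctx.get('exception')!r}", file=sys.stderr)
    )

Unlike sys.excepthook, the thread and unraisable hooks only receive a single args object, which is why the patch unpacks exc_type, exc_value and exc_traceback from it.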
From e7cb04bfdde573096d32060c72a4b643bf2e66da Mon Sep 17 00:00:00 2001
From: Windpicker-owo <3431391539@qq.com>
Date: Thu, 4 Dec 2025 18:58:07 +0800
Subject: [PATCH 2/2] feat(chromadb): add a global lock to protect ChromaDB
 operations and ensure thread safety
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/common/vector_db/chromadb_impl.py    |  72 +++++++------
 src/memory_graph/storage/vector_store.py | 123 ++++++++++++++---------
 2 files changed, 118 insertions(+), 77 deletions(-)

diff --git a/src/common/vector_db/chromadb_impl.py b/src/common/vector_db/chromadb_impl.py
index e4a2911ae..487bf6246 100644
--- a/src/common/vector_db/chromadb_impl.py
+++ b/src/common/vector_db/chromadb_impl.py
@@ -10,11 +10,17 @@ from .base import VectorDBBase

 logger = get_logger("chromadb_impl")

+# Global operation lock protecting all ChromaDB operations
+# ChromaDB's Rust backend can trigger an access violation under concurrent multi-threaded access on Windows
+_operation_lock = threading.Lock()
+

 class ChromaDBImpl(VectorDBBase):
     """
     Concrete ChromaDB implementation following the VectorDBBase interface.
     Uses the singleton pattern to ensure a single global ChromaDB client instance.
+
+    Note: all operations are guarded by _operation_lock to avoid concurrent-access crashes on Windows.
     """

     _instance = None
@@ -36,9 +42,10 @@ class ChromaDBImpl(VectorDBBase):
         with self._lock:
             if not hasattr(self, "_initialized"):
                 try:
-                    self.client = chromadb.PersistentClient(
-                        path=path, settings=Settings(anonymized_telemetry=False)
-                    )
+                    with _operation_lock:
+                        self.client = chromadb.PersistentClient(
+                            path=path, settings=Settings(anonymized_telemetry=False)
+                        )
                     self._collections: dict[str, Any] = {}
                     self._initialized = True
                     logger.info(f"ChromaDB client initialized, database path: {path}")
@@ -56,7 +63,8 @@
             return self._collections[name]

         try:
-            collection = self.client.get_or_create_collection(name=name, **kwargs)
+            with _operation_lock:
+                collection = self.client.get_or_create_collection(name=name, **kwargs)
             self._collections[name] = collection
             logger.info(f"Successfully got or created collection: '{name}'")
             return collection
@@ -75,12 +83,13 @@
         collection = self.get_or_create_collection(collection_name)
         if collection:
             try:
-                collection.add(
-                    embeddings=embeddings,
-                    documents=documents,
-                    metadatas=metadatas,
-                    ids=ids,
-                )
+                with _operation_lock:
+                    collection.add(
+                        embeddings=embeddings,
+                        documents=documents,
+                        metadatas=metadatas,
+                        ids=ids,
+                    )
             except Exception as e:
                 logger.error(f"Failed to add data to collection '{collection_name}': {e}")
@@ -107,7 +116,8 @@
             if processed_where:
                 query_params["where"] = processed_where

-            return collection.query(**query_params)
+            with _operation_lock:
+                return collection.query(**query_params)
         except Exception as e:
             logger.error(f"Failed to query collection '{collection_name}': {e}")
             # If the query fails, retry without the where condition
@@ -117,7 +127,8 @@
                     "n_results": n_results,
                 }
                 logger.warning("Using fallback query mode (no where condition)")
-                return collection.query(**fallback_params)
+                with _operation_lock:
+                    return collection.query(**fallback_params)
             except Exception as fallback_e:
                 logger.error(f"Fallback query also failed: {fallback_e}")
                 return {}
@@ -192,26 +203,28 @@
             if where:
                 processed_where = self._process_where_condition(where)

-            return collection.get(
-                ids=ids,
-                where=processed_where,
-                limit=limit,
-                offset=offset,
-                where_document=where_document,
-                include=include or ["documents", "metadatas", "embeddings"],
-            )
-        except Exception as e:
-            logger.error(f"Failed to get data from collection '{collection_name}': {e}")
-            # If the get fails, retry without the where condition
-            try:
-                logger.warning("Using fallback get mode (no where condition)")
+            with _operation_lock:
                 return collection.get(
                     ids=ids,
+                    where=processed_where,
                     limit=limit,
                     offset=offset,
                     where_document=where_document,
                     include=include or ["documents", "metadatas", "embeddings"],
                 )
+        except Exception as e:
+            logger.error(f"Failed to get data from collection '{collection_name}': {e}")
+            # If the get fails, retry without the where condition
+            try:
+                logger.warning("Using fallback get mode (no where condition)")
+                with _operation_lock:
+                    return collection.get(
+                        ids=ids,
+                        limit=limit,
+                        offset=offset,
+                        where_document=where_document,
+                        include=include or ["documents", "metadatas", "embeddings"],
+                    )
             except Exception as fallback_e:
                 logger.error(f"Fallback get also failed: {fallback_e}")
                 return {}
@@ -225,7 +238,8 @@
         collection = self.get_or_create_collection(collection_name)
         if collection:
             try:
-                collection.delete(ids=ids, where=where)
+                with _operation_lock:
+                    collection.delete(ids=ids, where=where)
             except Exception as e:
                 logger.error(f"Failed to delete data from collection '{collection_name}': {e}")
@@ -233,7 +247,8 @@
         collection = self.get_or_create_collection(collection_name)
         if collection:
             try:
-                return collection.count()
+                with _operation_lock:
+                    return collection.count()
             except Exception as e:
                 logger.error(f"Failed to get count for collection '{collection_name}': {e}")
         return 0
@@ -243,7 +258,8 @@
             raise ConnectionError("ChromaDB client is not initialized")

         try:
-            self.client.delete_collection(name=name)
+            with _operation_lock:
+                self.client.delete_collection(name=name)
             if name in self._collections:
                 del self._collections[name]
             logger.info(f"Collection '{name}' has been deleted")
diff --git a/src/memory_graph/storage/vector_store.py b/src/memory_graph/storage/vector_store.py
index b59ea1b83..6e602a732 100644
--- a/src/memory_graph/storage/vector_store.py
+++ b/src/memory_graph/storage/vector_store.py
@@ -3,11 +3,15 @@

 Note: ChromaDB is a synchronous library; all operations must be wrapped with asyncio.to_thread()
 to avoid blocking the asyncio event loop and causing deadlocks.
+
+Important: ChromaDB's Rust backend can trigger an access violation under concurrent multi-threaded access on Windows,
+so all operations must be guarded by a global lock to ensure serial execution.
 """

 from __future__ import annotations

 import asyncio
+import threading
 from pathlib import Path
 from typing import Any
@@ -18,6 +22,10 @@ from src.memory_graph.models import MemoryNode, NodeType

 logger = get_logger(__name__)

+# Global lock protecting all ChromaDB operations
+# ChromaDB's Rust backend can trigger an access violation under concurrent multi-threaded access on Windows
+_chromadb_lock = threading.Lock()
+

 class VectorStore:
     """
@@ -57,29 +65,35 @@
             import chromadb
             from chromadb.config import Settings

-            # Create the persistent client - a synchronous operation that must run in a thread
+            # Create the persistent client - a synchronous operation that must run in a thread, guarded by the lock
             def _create_client():
-                return chromadb.PersistentClient(
-                    path=str(self.data_dir / "chroma"),
-                    settings=Settings(
-                        anonymized_telemetry=False,
-                        allow_reset=True,
-                    ),
-                )
+                with _chromadb_lock:
+                    return chromadb.PersistentClient(
+                        path=str(self.data_dir / "chroma"),
+                        settings=Settings(
+                            anonymized_telemetry=False,
+                            allow_reset=True,
+                        ),
+                    )

             self.client = await asyncio.to_thread(_create_client)

-            # Get or create the collection - a synchronous operation that must run in a thread
+            # Get or create the collection - a synchronous operation that must run in a thread, guarded by the lock
             def _get_or_create_collection():
-                return self.client.get_or_create_collection(
-                    name=self.collection_name,
-                    metadata={"description": "Memory graph node embeddings"},
-                )
+                with _chromadb_lock:
+                    return self.client.get_or_create_collection(
+                        name=self.collection_name,
+                        metadata={"description": "Memory graph node embeddings"},
+                    )

             self.collection = await asyncio.to_thread(_get_or_create_collection)

-            # count() is also a synchronous operation
-            count = await asyncio.to_thread(self.collection.count)
+            # count() is also a synchronous operation, guarded by the lock
+            def _count():
+                with _chromadb_lock:
+                    return self.collection.count()
+
+            count = await asyncio.to_thread(_count)
             logger.debug(f"ChromaDB initialized, collection contains {count} nodes")

         except Exception as e:
@@ -118,14 +132,15 @@
                 else:
                     metadata[key] = str(value)

-            # ChromaDB add() is a synchronous blocking call; run it in a thread
+            # ChromaDB add() is a synchronous blocking call; run it in a thread, guarded by the lock
            def _add_node():
-                self.collection.add(
-                    ids=[node.id],
-                    embeddings=[node.embedding.tolist()],
-                    metadatas=[metadata],
-                    documents=[node.content],
-                )
+                with _chromadb_lock:
+                    self.collection.add(
+                        ids=[node.id],
+                        embeddings=[node.embedding.tolist()],
+                        metadatas=[metadata],
+                        documents=[node.content],
+                    )

             await asyncio.to_thread(_add_node)
@@ -171,14 +186,15 @@
                     metadata[key] = str(value)
                 metadatas.append(metadata)

-            # ChromaDB add() is a synchronous blocking call; run it in a thread
+            # ChromaDB add() is a synchronous blocking call; run it in a thread, guarded by the lock
             def _add_batch():
-                self.collection.add(
-                    ids=[n.id for n in valid_nodes],
-                    embeddings=[n.embedding.tolist() for n in valid_nodes],  # type: ignore
-                    metadatas=metadatas,
-                    documents=[n.content for n in valid_nodes],
-                )
+                with _chromadb_lock:
+                    self.collection.add(
+                        ids=[n.id for n in valid_nodes],
+                        embeddings=[n.embedding.tolist() for n in valid_nodes],  # type: ignore
+                        metadatas=metadatas,
+                        documents=[n.content for n in valid_nodes],
+                    )

             await asyncio.to_thread(_add_batch)
@@ -214,13 +230,14 @@
             if node_types:
                 where_filter = {"node_type": {"$in": [nt.value for nt in node_types]}}

-            # ChromaDB query() is a synchronous blocking call; run it in a thread
+            # ChromaDB query() is a synchronous blocking call; run it in a thread, guarded by the lock
             def _query():
-                return self.collection.query(
-                    query_embeddings=[query_embedding.tolist()],
-                    n_results=limit,
-                    where=where_filter,
-                )
+                with _chromadb_lock:
+                    return self.collection.query(
+                        query_embeddings=[query_embedding.tolist()],
+                        n_results=limit,
+                        where=where_filter,
+                    )

             results = await asyncio.to_thread(_query)
@@ -383,9 +400,10 @@
             raise RuntimeError("Vector store is not initialized")

         try:
-            # ChromaDB get() is a synchronous blocking call; run it in a thread
+            # ChromaDB get() is a synchronous blocking call; run it in a thread, guarded by the lock
             def _get():
-                return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])
+                with _chromadb_lock:
+                    return self.collection.get(ids=[node_id], include=["metadatas", "embeddings"])

             result = await asyncio.to_thread(_get)
@@ -420,9 +438,10 @@
             raise RuntimeError("Vector store is not initialized")

         try:
-            # ChromaDB delete() is a synchronous blocking call; run it in a thread
+            # ChromaDB delete() is a synchronous blocking call; run it in a thread, guarded by the lock
             def _delete():
-                self.collection.delete(ids=[node_id])
+                with _chromadb_lock:
+                    self.collection.delete(ids=[node_id])

             await asyncio.to_thread(_delete)
             logger.debug(f"Deleted node: {node_id}")
@@ -443,9 +462,10 @@
             raise RuntimeError("Vector store is not initialized")

         try:
-            # ChromaDB update() is a synchronous blocking call; run it in a thread
+            # ChromaDB update() is a synchronous blocking call; run it in a thread, guarded by the lock
             def _update():
-                self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])
+                with _chromadb_lock:
+                    self.collection.update(ids=[node_id], embeddings=[embedding.tolist()])

             await asyncio.to_thread(_update)
             logger.debug(f"Updated node embedding: {node_id}")
@@ -458,13 +478,17 @@
         """Get the total number of nodes in the vector store (synchronous; use with care in async contexts)"""
         if not self.collection:
             return 0
-        return self.collection.count()
+        with _chromadb_lock:
+            return self.collection.count()

     async def get_total_count_async(self) -> int:
         """Asynchronously get the total number of nodes in the vector store"""
         if not self.collection:
             return 0
-        return await asyncio.to_thread(self.collection.count)
+        def _count():
+            with _chromadb_lock:
+                return self.collection.count()
+        return await asyncio.to_thread(_count)

     async def clear(self) -> None:
         """Clear the vector store (dangerous; for testing only)"""
@@ -472,13 +496,14 @@
             return

         try:
-            # ChromaDB delete_collection and get_or_create_collection are both synchronous blocking calls
+            # ChromaDB delete_collection and get_or_create_collection are both synchronous blocking calls; guard with the lock
             def _clear():
-                self.client.delete_collection(self.collection_name)
-                return self.client.get_or_create_collection(
-                    name=self.collection_name,
-                    metadata={"description": "Memory graph node embeddings"},
-                )
+                with _chromadb_lock:
+                    self.client.delete_collection(self.collection_name)
+                    return self.client.get_or_create_collection(
+                        name=self.collection_name,
+                        metadata={"description": "Memory graph node embeddings"},
+                    )

             self.collection = await asyncio.to_thread(_clear)
             logger.warning(f"Vector store cleared: {self.collection_name}")
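The second patch repeats one pattern at every call site: the synchronous ChromaDB call is wrapped in a small inner function that takes a process-wide lock, and that function is handed to asyncio.to_thread so the event loop stays responsive while the calls themselves execute one at a time. A generic sketch of the pattern is shown below, assuming a hypothetical helper named locked_call and lock named _db_lock that are not part of the repository.

import asyncio
import threading

_db_lock = threading.Lock()  # one process-wide lock serializing all client calls


async def locked_call(func, /, *args, **kwargs):
    """Run a synchronous, non-thread-safe client call in a worker thread,
    serialized behind the global lock so the event loop is never blocked."""
    def _call():
        with _db_lock:
            return func(*args, **kwargs)
    return await asyncio.to_thread(_call)

# Usage with a hypothetical collection object:
#     count = await locked_call(collection.count)
#     await locked_call(collection.add, ids=ids, embeddings=embeddings)

Because the lock is held inside the worker thread rather than in the coroutine, concurrent coroutines still overlap their awaits; only the ChromaDB calls themselves are forced to run serially.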