diff --git a/src/memory_graph/storage/persistence.py b/src/memory_graph/storage/persistence.py index 70e745099..46ed90ba1 100644 --- a/src/memory_graph/storage/persistence.py +++ b/src/memory_graph/storage/persistence.py @@ -6,6 +6,9 @@ from __future__ import annotations import asyncio import json +import os +import sys +from collections.abc import Callable from datetime import datetime from pathlib import Path @@ -18,6 +21,106 @@ from src.memory_graph.storage.graph_store import GraphStore logger = get_logger(__name__) +# Windows 平台检测 +IS_WINDOWS = sys.platform == "win32" + +# Windows 平台检测 +IS_WINDOWS = sys.platform == "win32" + + +async def safe_atomic_write(temp_path: Path, target_path: Path, max_retries: int = 5) -> None: + """ + 安全的原子写入操作(针对 Windows 文件锁优化) + + Args: + temp_path: 临时文件路径 + target_path: 目标文件路径 + max_retries: 最大重试次数 + + Raises: + OSError: 所有重试都失败时抛出 + """ + last_error: Exception | None = None + + for attempt in range(max_retries): + try: + if IS_WINDOWS: + # Windows 特殊处理:多步骤原子替换 + if target_path.exists(): + # 策略1: 尝试直接删除 + try: + os.unlink(target_path) + except OSError: + # 策略2: 重命名为 .old 文件 + old_file = target_path.with_suffix(".old") + try: + if old_file.exists(): + os.unlink(old_file) + target_path.rename(old_file) + except OSError: + # 策略3: 使用时间戳后缀 + from datetime import datetime + backup_file = target_path.with_suffix(f".bak_{datetime.now().strftime('%H%M%S')}") + target_path.rename(backup_file) + # 标记稍后清理 + asyncio.create_task(_cleanup_backup_files(target_path.parent, target_path.stem)) + + # 执行重命名 + temp_path.rename(target_path) + else: + # Unix/Linux: 直接使用 replace (原子操作) + temp_path.replace(target_path) + + # 成功 + return + + except OSError as e: + last_error = e + if attempt < max_retries - 1: + # 指数退避重试 + wait_time = 0.05 * (2 ** attempt) + logger.warning( + f"文件替换失败 (尝试 {attempt + 1}/{max_retries}), " + f"等待 {wait_time:.3f}s 后重试: {e}" + ) + await asyncio.sleep(wait_time) + else: + logger.error(f"文件替换失败,已达到最大重试次数 ({max_retries})") + + # 所有重试失败 + if last_error: + raise last_error + raise OSError(f"文件替换失败: {temp_path} -> {target_path}") + + +async def _cleanup_backup_files(directory: Path, file_stem: str, keep_recent: int = 3) -> None: + """ + 清理临时备份文件(后台异步任务) + + Args: + directory: 目录路径 + file_stem: 文件主名(不含扩展名) + keep_recent: 保留最近的文件数量 + """ + try: + # 延迟执行,避免立即清理可能仍在使用的文件 + await asyncio.sleep(5) + + # 查找所有备份文件 + pattern = f"{file_stem}.bak_*" + backup_files = sorted(directory.glob(pattern), key=lambda p: p.stat().st_mtime, reverse=True) + + # 删除超出保留数量的文件 + for old_file in backup_files[keep_recent:]: + try: + old_file.unlink() + logger.debug(f"已清理旧备份文件: {old_file.name}") + except OSError as e: + logger.debug(f"清理备份文件失败: {old_file.name}, {e}") + + except Exception as e: + logger.debug(f"清理备份文件任务失败: {e}") + class PersistenceManager: """ @@ -90,26 +193,10 @@ class PersistenceManager: async with aiofiles.open(temp_file, "wb") as f: await f.write(json_data) - # 在Windows上,确保目标文件没有被占用 - if self.graph_file.exists(): - import os - try: - os.unlink(self.graph_file) - except OSError: - # 如果无法删除,等待一小段时间再重试 - await asyncio.sleep(0.1) - try: - os.unlink(self.graph_file) - except OSError: - # 如果还是失败,使用备用策略 - backup_file = self.graph_file.with_suffix(".bak") - if backup_file.exists(): - os.unlink(backup_file) - self.graph_file.rename(backup_file) + # 使用安全的原子写入 + await safe_atomic_write(temp_file, self.graph_file) - temp_file.replace(self.graph_file) - - logger.info(f"图数据已保存: {self.graph_file}, 大小: {len(json_data) / 1024:.2f} KB") + logger.debug(f"图数据已保存: {self.graph_file}, 大小: {len(json_data) / 1024:.2f} KB") except Exception as e: logger.error(f"保存图数据失败: {e}", exc_info=True) @@ -127,36 +214,41 @@ class PersistenceManager: return None async with self._file_lock: # 使用文件锁防止并发访问 - try: - # 读取文件,添加重试机制处理可能的文件锁定 - max_retries = 3 - for attempt in range(max_retries): - try: - async with aiofiles.open(self.graph_file, "rb") as f: - json_data = await f.read() - data = orjson.loads(json_data) - break - except OSError as e: - if attempt == max_retries - 1: - raise - logger.warning(f"读取图数据文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") - await asyncio.sleep(0.1 * (attempt + 1)) + try: + # 读取文件,添加重试机制处理可能的文件锁定 + data = None + max_retries = 3 + for attempt in range(max_retries): + try: + async with aiofiles.open(self.graph_file, "rb") as f: + json_data = await f.read() + data = orjson.loads(json_data) + break + except OSError as e: + if attempt == max_retries - 1: + raise + logger.warning(f"读取图数据文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") + await asyncio.sleep(0.1 * (attempt + 1)) - # 检查版本(未来可能需要数据迁移) - version = data.get("metadata", {}).get("version", "unknown") - logger.info(f"加载图数据: version={version}") - - # 恢复图存储 - graph_store = GraphStore.from_dict(data) - - logger.info(f"图数据加载完成: {graph_store.get_statistics()}") - return graph_store - - except Exception as e: - logger.error(f"加载图数据失败: {e}", exc_info=True) - # 尝试加载备份 + if data is None: + logger.error("无法读取图数据文件") return await self._load_from_backup() + # 检查版本(未来可能需要数据迁移) + version = data.get("metadata", {}).get("version", "unknown") + logger.info(f"加载图数据: version={version}") + + # 恢复图存储 + graph_store = GraphStore.from_dict(data) + + logger.info(f"图数据加载完成: {graph_store.get_statistics()}") + return graph_store + + except Exception as e: + logger.error(f"加载图数据失败: {e}", exc_info=True) + # 尝试加载备份 + return await self._load_from_backup() + async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None: """ 保存临时记忆列表 @@ -181,24 +273,8 @@ class PersistenceManager: async with aiofiles.open(temp_file, "wb") as f: await f.write(json_data) - # 在Windows上,确保目标文件没有被占用 - if self.staged_file.exists(): - import os - try: - os.unlink(self.staged_file) - except OSError: - # 如果无法删除,等待一小段时间再重试 - await asyncio.sleep(0.1) - try: - os.unlink(self.staged_file) - except OSError: - # 如果还是失败,使用备用策略 - backup_file = self.staged_file.with_suffix(".bak") - if backup_file.exists(): - os.unlink(backup_file) - self.staged_file.rename(backup_file) - - temp_file.replace(self.staged_file) + # 使用安全的原子写入 + await safe_atomic_write(temp_file, self.staged_file) logger.info(f"临时记忆已保存: {len(staged_memories)} 条") @@ -218,30 +294,35 @@ class PersistenceManager: return [] async with self._file_lock: # 使用文件锁防止并发访问 - try: - # 读取文件,添加重试机制处理可能的文件锁定 - max_retries = 3 - for attempt in range(max_retries): - try: - async with aiofiles.open(self.staged_file, "rb") as f: - json_data = await f.read() - data = orjson.loads(json_data) - break - except OSError as e: - if attempt == max_retries - 1: - raise - logger.warning(f"读取临时记忆文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") - await asyncio.sleep(0.1 * (attempt + 1)) + try: + # 读取文件,添加重试机制处理可能的文件锁定 + data = None + max_retries = 3 + for attempt in range(max_retries): + try: + async with aiofiles.open(self.staged_file, "rb") as f: + json_data = await f.read() + data = orjson.loads(json_data) + break + except OSError as e: + if attempt == max_retries - 1: + raise + logger.warning(f"读取临时记忆文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") + await asyncio.sleep(0.1 * (attempt + 1)) - staged_memories = [StagedMemory.from_dict(sm) for sm in data.get("staged_memories", [])] - - logger.info(f"临时记忆加载完成: {len(staged_memories)} 条") - return staged_memories - - except Exception as e: - logger.error(f"加载临时记忆失败: {e}", exc_info=True) + if data is None: + logger.error("无法读取临时记忆文件") return [] + staged_memories = [StagedMemory.from_dict(sm) for sm in data.get("staged_memories", [])] + + logger.info(f"临时记忆加载完成: {len(staged_memories)} 条") + return staged_memories + + except Exception as e: + logger.error(f"加载临时记忆失败: {e}", exc_info=True) + return [] + async def create_backup(self) -> Path | None: """ 创建当前数据的备份 @@ -286,6 +367,7 @@ class PersistenceManager: logger.warning(f"尝试从备份恢复: {latest_backup}") # 读取备份文件,添加重试机制 + data = None max_retries = 3 for attempt in range(max_retries): try: @@ -299,6 +381,10 @@ class PersistenceManager: logger.warning(f"读取备份文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") await asyncio.sleep(0.1 * (attempt + 1)) + if data is None: + logger.error("无法从备份读取数据") + return None + graph_store = GraphStore.from_dict(data) logger.info(f"从备份恢复成功: {graph_store.get_statistics()}") @@ -329,7 +415,7 @@ class PersistenceManager: async def start_auto_save( self, graph_store: GraphStore, - staged_memories_getter: callable | None = None, + staged_memories_getter: Callable[[], list[StagedMemory]] | None = None, ) -> None: """ 启动自动保存任务