优化 Windows 平台的文件替换操作,增加安全的原子写入功能,改进备份文件清理机制

This commit is contained in:
Windpicker-owo
2025-11-13 21:42:46 +08:00
parent c47678fa12
commit 90a8c472b4

View File

@@ -6,6 +6,9 @@ from __future__ import annotations
import asyncio
import json
import os
import sys
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
@@ -18,6 +21,106 @@ from src.memory_graph.storage.graph_store import GraphStore
logger = get_logger(__name__)
# Windows platform detection
IS_WINDOWS = sys.platform == "win32"

# Strong references to fire-and-forget cleanup tasks. asyncio's event loop
# only keeps a weak reference to tasks, so without this set a scheduled
# cleanup task could be garbage-collected before it runs (see asyncio docs).
_background_tasks: set[asyncio.Task] = set()


async def safe_atomic_write(temp_path: Path, target_path: Path, max_retries: int = 5) -> None:
    """
    Atomically replace ``target_path`` with ``temp_path`` (Windows file-lock aware).

    On POSIX, ``Path.replace`` (rename onto the target) is already atomic.
    On Windows the target may be held open by another handle, so we fall
    through three strategies before renaming the temp file into place:
    delete the target, rename it to ``*.old``, or rename it to a
    timestamped ``*.bak_*`` file that a background task cleans up later.
    Each failed attempt is retried with exponential backoff.

    Args:
        temp_path: Fully-written temporary file to move into place.
        target_path: Final destination path.
        max_retries: Maximum number of replace attempts.

    Raises:
        OSError: When every retry fails.
    """
    last_error: Exception | None = None

    for attempt in range(max_retries):
        try:
            if IS_WINDOWS:
                # Windows: renaming over an open file fails, so clear the
                # target out of the way first.
                if target_path.exists():
                    # Strategy 1: plain delete.
                    try:
                        os.unlink(target_path)
                    except OSError:
                        # Strategy 2: shunt the locked file aside as "*.old".
                        old_file = target_path.with_suffix(".old")
                        try:
                            if old_file.exists():
                                os.unlink(old_file)
                            target_path.rename(old_file)
                        except OSError:
                            # Strategy 3: timestamped backup name. Include
                            # microseconds so two replacements within the
                            # same second cannot collide on the backup name.
                            backup_file = target_path.with_suffix(
                                f".bak_{datetime.now().strftime('%H%M%S_%f')}"
                            )
                            target_path.rename(backup_file)
                            # Schedule background cleanup; keep a strong
                            # reference so the task survives until done.
                            task = asyncio.create_task(
                                _cleanup_backup_files(target_path.parent, target_path.stem)
                            )
                            _background_tasks.add(task)
                            task.add_done_callback(_background_tasks.discard)

                # Move the new file into place.
                temp_path.rename(target_path)
            else:
                # POSIX: rename onto the target is a single atomic operation.
                temp_path.replace(target_path)

            # Success.
            return

        except OSError as e:
            last_error = e
            if attempt < max_retries - 1:
                # Exponential backoff: 0.05s, 0.1s, 0.2s, ...
                wait_time = 0.05 * (2**attempt)
                logger.warning(
                    f"文件替换失败 (尝试 {attempt + 1}/{max_retries}), "
                    f"等待 {wait_time:.3f}s 后重试: {e}"
                )
                await asyncio.sleep(wait_time)
            else:
                logger.error(f"文件替换失败,已达到最大重试次数 ({max_retries})")

    # All retries exhausted.
    if last_error:
        raise last_error
    raise OSError(f"文件替换失败: {temp_path} -> {target_path}")
async def _cleanup_backup_files(directory: Path, file_stem: str, keep_recent: int = 3) -> None:
"""
清理临时备份文件(后台异步任务)
Args:
directory: 目录路径
file_stem: 文件主名(不含扩展名)
keep_recent: 保留最近的文件数量
"""
try:
# 延迟执行,避免立即清理可能仍在使用的文件
await asyncio.sleep(5)
# 查找所有备份文件
pattern = f"{file_stem}.bak_*"
backup_files = sorted(directory.glob(pattern), key=lambda p: p.stat().st_mtime, reverse=True)
# 删除超出保留数量的文件
for old_file in backup_files[keep_recent:]:
try:
old_file.unlink()
logger.debug(f"已清理旧备份文件: {old_file.name}")
except OSError as e:
logger.debug(f"清理备份文件失败: {old_file.name}, {e}")
except Exception as e:
logger.debug(f"清理备份文件任务失败: {e}")
class PersistenceManager: class PersistenceManager:
""" """
@@ -90,26 +193,10 @@ class PersistenceManager:
async with aiofiles.open(temp_file, "wb") as f: async with aiofiles.open(temp_file, "wb") as f:
await f.write(json_data) await f.write(json_data)
# 在Windows上确保目标文件没有被占用 # 使用安全的原子写入
if self.graph_file.exists(): await safe_atomic_write(temp_file, self.graph_file)
import os
try:
os.unlink(self.graph_file)
except OSError:
# 如果无法删除,等待一小段时间再重试
await asyncio.sleep(0.1)
try:
os.unlink(self.graph_file)
except OSError:
# 如果还是失败,使用备用策略
backup_file = self.graph_file.with_suffix(".bak")
if backup_file.exists():
os.unlink(backup_file)
self.graph_file.rename(backup_file)
temp_file.replace(self.graph_file) logger.debug(f"图数据已保存: {self.graph_file}, 大小: {len(json_data) / 1024:.2f} KB")
logger.info(f"图数据已保存: {self.graph_file}, 大小: {len(json_data) / 1024:.2f} KB")
except Exception as e: except Exception as e:
logger.error(f"保存图数据失败: {e}", exc_info=True) logger.error(f"保存图数据失败: {e}", exc_info=True)
@@ -127,36 +214,41 @@ class PersistenceManager:
return None return None
async with self._file_lock: # 使用文件锁防止并发访问 async with self._file_lock: # 使用文件锁防止并发访问
try: try:
# 读取文件,添加重试机制处理可能的文件锁定 # 读取文件,添加重试机制处理可能的文件锁定
max_retries = 3 data = None
for attempt in range(max_retries): max_retries = 3
try: for attempt in range(max_retries):
async with aiofiles.open(self.graph_file, "rb") as f: try:
json_data = await f.read() async with aiofiles.open(self.graph_file, "rb") as f:
data = orjson.loads(json_data) json_data = await f.read()
break data = orjson.loads(json_data)
except OSError as e: break
if attempt == max_retries - 1: except OSError as e:
raise if attempt == max_retries - 1:
logger.warning(f"读取图数据文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") raise
await asyncio.sleep(0.1 * (attempt + 1)) logger.warning(f"读取图数据文件失败 (尝试 {attempt + 1}/{max_retries}): {e}")
await asyncio.sleep(0.1 * (attempt + 1))
# 检查版本(未来可能需要数据迁移) if data is None:
version = data.get("metadata", {}).get("version", "unknown") logger.error("无法读取图数据文件")
logger.info(f"加载图数据: version={version}")
# 恢复图存储
graph_store = GraphStore.from_dict(data)
logger.info(f"图数据加载完成: {graph_store.get_statistics()}")
return graph_store
except Exception as e:
logger.error(f"加载图数据失败: {e}", exc_info=True)
# 尝试加载备份
return await self._load_from_backup() return await self._load_from_backup()
# 检查版本(未来可能需要数据迁移)
version = data.get("metadata", {}).get("version", "unknown")
logger.info(f"加载图数据: version={version}")
# 恢复图存储
graph_store = GraphStore.from_dict(data)
logger.info(f"图数据加载完成: {graph_store.get_statistics()}")
return graph_store
except Exception as e:
logger.error(f"加载图数据失败: {e}", exc_info=True)
# 尝试加载备份
return await self._load_from_backup()
async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None: async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None:
""" """
保存临时记忆列表 保存临时记忆列表
@@ -181,24 +273,8 @@ class PersistenceManager:
async with aiofiles.open(temp_file, "wb") as f: async with aiofiles.open(temp_file, "wb") as f:
await f.write(json_data) await f.write(json_data)
# 在Windows上确保目标文件没有被占用 # 使用安全的原子写入
if self.staged_file.exists(): await safe_atomic_write(temp_file, self.staged_file)
import os
try:
os.unlink(self.staged_file)
except OSError:
# 如果无法删除,等待一小段时间再重试
await asyncio.sleep(0.1)
try:
os.unlink(self.staged_file)
except OSError:
# 如果还是失败,使用备用策略
backup_file = self.staged_file.with_suffix(".bak")
if backup_file.exists():
os.unlink(backup_file)
self.staged_file.rename(backup_file)
temp_file.replace(self.staged_file)
logger.info(f"临时记忆已保存: {len(staged_memories)}") logger.info(f"临时记忆已保存: {len(staged_memories)}")
@@ -218,30 +294,35 @@ class PersistenceManager:
return [] return []
async with self._file_lock: # 使用文件锁防止并发访问 async with self._file_lock: # 使用文件锁防止并发访问
try: try:
# 读取文件,添加重试机制处理可能的文件锁定 # 读取文件,添加重试机制处理可能的文件锁定
max_retries = 3 data = None
for attempt in range(max_retries): max_retries = 3
try: for attempt in range(max_retries):
async with aiofiles.open(self.staged_file, "rb") as f: try:
json_data = await f.read() async with aiofiles.open(self.staged_file, "rb") as f:
data = orjson.loads(json_data) json_data = await f.read()
break data = orjson.loads(json_data)
except OSError as e: break
if attempt == max_retries - 1: except OSError as e:
raise if attempt == max_retries - 1:
logger.warning(f"读取临时记忆文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") raise
await asyncio.sleep(0.1 * (attempt + 1)) logger.warning(f"读取临时记忆文件失败 (尝试 {attempt + 1}/{max_retries}): {e}")
await asyncio.sleep(0.1 * (attempt + 1))
staged_memories = [StagedMemory.from_dict(sm) for sm in data.get("staged_memories", [])] if data is None:
logger.error("无法读取临时记忆文件")
logger.info(f"临时记忆加载完成: {len(staged_memories)}")
return staged_memories
except Exception as e:
logger.error(f"加载临时记忆失败: {e}", exc_info=True)
return [] return []
staged_memories = [StagedMemory.from_dict(sm) for sm in data.get("staged_memories", [])]
logger.info(f"临时记忆加载完成: {len(staged_memories)}")
return staged_memories
except Exception as e:
logger.error(f"加载临时记忆失败: {e}", exc_info=True)
return []
async def create_backup(self) -> Path | None: async def create_backup(self) -> Path | None:
""" """
创建当前数据的备份 创建当前数据的备份
@@ -286,6 +367,7 @@ class PersistenceManager:
logger.warning(f"尝试从备份恢复: {latest_backup}") logger.warning(f"尝试从备份恢复: {latest_backup}")
# 读取备份文件,添加重试机制 # 读取备份文件,添加重试机制
data = None
max_retries = 3 max_retries = 3
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
@@ -299,6 +381,10 @@ class PersistenceManager:
logger.warning(f"读取备份文件失败 (尝试 {attempt + 1}/{max_retries}): {e}") logger.warning(f"读取备份文件失败 (尝试 {attempt + 1}/{max_retries}): {e}")
await asyncio.sleep(0.1 * (attempt + 1)) await asyncio.sleep(0.1 * (attempt + 1))
if data is None:
logger.error("无法从备份读取数据")
return None
graph_store = GraphStore.from_dict(data) graph_store = GraphStore.from_dict(data)
logger.info(f"从备份恢复成功: {graph_store.get_statistics()}") logger.info(f"从备份恢复成功: {graph_store.get_statistics()}")
@@ -329,7 +415,7 @@ class PersistenceManager:
async def start_auto_save( async def start_auto_save(
self, self,
graph_store: GraphStore, graph_store: GraphStore,
staged_memories_getter: callable | None = None, staged_memories_getter: Callable[[], list[StagedMemory]] | None = None,
) -> None: ) -> None:
""" """
启动自动保存任务 启动自动保存任务