diff --git a/src/chat/interest_system/bot_interest_manager.py b/src/chat/interest_system/bot_interest_manager.py index 66965c421..959796a51 100644 --- a/src/chat/interest_system/bot_interest_manager.py +++ b/src/chat/interest_system/bot_interest_manager.py @@ -1023,8 +1023,10 @@ class BotInterestManager: return None # 读取缓存文件 - with open(cache_file, "rb") as f: - cache_data = orjson.loads(f.read()) + import aiofiles + async with aiofiles.open(cache_file, "rb") as f: + content = await f.read() + cache_data = orjson.loads(content) # 验证缓存版本和embedding模型 cache_version = cache_data.get("version", 1) @@ -1074,8 +1076,9 @@ class BotInterestManager: } # 写入文件 - with open(cache_file, "wb") as f: - f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2)) + import aiofiles + async with aiofiles.open(cache_file, "wb") as f: + await f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2)) logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}") diff --git a/src/memory_graph/storage/persistence.py b/src/memory_graph/storage/persistence.py index bb6dc2946..23b6a0735 100644 --- a/src/memory_graph/storage/persistence.py +++ b/src/memory_graph/storage/persistence.py @@ -9,6 +9,7 @@ import json from datetime import datetime from pathlib import Path +import aiofiles import orjson from src.common.logger import get_logger @@ -84,7 +85,8 @@ class PersistenceManager: # 原子写入(先写临时文件,再重命名) temp_file = self.graph_file.with_suffix(".tmp") - temp_file.write_bytes(json_data) + async with aiofiles.open(temp_file, "wb") as f: + await f.write(json_data) temp_file.replace(self.graph_file) logger.info(f"图数据已保存: {self.graph_file}, 大小: {len(json_data) / 1024:.2f} KB") @@ -106,7 +108,8 @@ class PersistenceManager: try: # 读取文件 - json_data = self.graph_file.read_bytes() + async with aiofiles.open(self.graph_file, "rb") as f: + json_data = await f.read() data = orjson.loads(json_data) # 检查版本(未来可能需要数据迁移) @@ -144,7 +147,8 @@ class PersistenceManager: json_data = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_SERIALIZE_NUMPY) temp_file = self.staged_file.with_suffix(".tmp") - temp_file.write_bytes(json_data) + async with aiofiles.open(temp_file, "wb") as f: + await f.write(json_data) temp_file.replace(self.staged_file) logger.info(f"临时记忆已保存: {len(staged_memories)} 条") @@ -165,7 +169,8 @@ class PersistenceManager: return [] try: - json_data = self.staged_file.read_bytes() + async with aiofiles.open(self.staged_file, "rb") as f: + json_data = await f.read() data = orjson.loads(json_data) staged_memories = [StagedMemory.from_dict(sm) for sm in data.get("staged_memories", [])] @@ -190,9 +195,10 @@ class PersistenceManager: if self.graph_file.exists(): # 复制图数据文件 - import shutil - - shutil.copy2(self.graph_file, backup_file) + async with aiofiles.open(self.graph_file, "rb") as src: + async with aiofiles.open(backup_file, "wb") as dst: + while chunk := await src.read(8192): + await dst.write(chunk) # 清理旧备份(只保留最近10个) await self._cleanup_old_backups(keep=10) @@ -219,7 +225,8 @@ class PersistenceManager: latest_backup = backup_files[0] logger.warning(f"尝试从备份恢复: {latest_backup}") - json_data = latest_backup.read_bytes() + async with aiofiles.open(latest_backup, "rb") as f: + json_data = await f.read() data = orjson.loads(json_data) graph_store = GraphStore.from_dict(data) @@ -323,8 +330,9 @@ class PersistenceManager: # 使用标准 json 以获得更好的可读性 output_file.parent.mkdir(parents=True, exist_ok=True) - with output_file.open("w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=2) + async with aiofiles.open(output_file, "w", encoding="utf-8") as f: + json_str = json.dumps(data, ensure_ascii=False, indent=2) + await f.write(json_str) logger.info(f"图数据已导出: {output_file}") @@ -343,8 +351,9 @@ class PersistenceManager: GraphStore 对象 """ try: - with input_file.open("r", encoding="utf-8") as f: - data = json.load(f) + async with aiofiles.open(input_file, "r", encoding="utf-8") as f: + content = await f.read() + data = json.loads(content) graph_store = GraphStore.from_dict(data) logger.info(f"图数据已导入: {graph_store.get_statistics()}")