ruff
This commit is contained in:
@@ -17,19 +17,19 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Tuple
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('embedding_cleanup.log', encoding='utf-8')
|
||||
logging.FileHandler("embedding_cleanup.log", encoding="utf-8")
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -49,13 +49,13 @@ class EmbeddingCleaner:
|
||||
self.cleaned_files = []
|
||||
self.errors = []
|
||||
self.stats = {
|
||||
'files_processed': 0,
|
||||
'embedings_removed': 0,
|
||||
'bytes_saved': 0,
|
||||
'nodes_processed': 0
|
||||
"files_processed": 0,
|
||||
"embedings_removed": 0,
|
||||
"bytes_saved": 0,
|
||||
"nodes_processed": 0
|
||||
}
|
||||
|
||||
def find_json_files(self) -> List[Path]:
|
||||
def find_json_files(self) -> list[Path]:
|
||||
"""查找可能包含向量数据的 JSON 文件"""
|
||||
json_files = []
|
||||
|
||||
@@ -65,7 +65,7 @@ class EmbeddingCleaner:
|
||||
json_files.append(memory_graph_file)
|
||||
|
||||
# 测试数据文件
|
||||
test_dir = self.data_dir / "test_*"
|
||||
self.data_dir / "test_*"
|
||||
for test_path in self.data_dir.glob("test_*/memory_graph.json"):
|
||||
if test_path.exists():
|
||||
json_files.append(test_path)
|
||||
@@ -82,7 +82,7 @@ class EmbeddingCleaner:
|
||||
logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件")
|
||||
return json_files
|
||||
|
||||
def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int:
|
||||
def analyze_embedding_in_data(self, data: dict[str, Any]) -> int:
|
||||
"""
|
||||
分析数据中的 embedding 字段数量
|
||||
|
||||
@@ -97,7 +97,7 @@ class EmbeddingCleaner:
|
||||
def count_embeddings(obj):
|
||||
nonlocal embedding_count
|
||||
if isinstance(obj, dict):
|
||||
if 'embedding' in obj:
|
||||
if "embedding" in obj:
|
||||
embedding_count += 1
|
||||
for value in obj.values():
|
||||
count_embeddings(value)
|
||||
@@ -108,7 +108,7 @@ class EmbeddingCleaner:
|
||||
count_embeddings(data)
|
||||
return embedding_count
|
||||
|
||||
def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
|
||||
def clean_embedding_from_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
|
||||
"""
|
||||
从数据中移除 embedding 字段
|
||||
|
||||
@@ -123,8 +123,8 @@ class EmbeddingCleaner:
|
||||
def remove_embeddings(obj):
|
||||
nonlocal removed_count
|
||||
if isinstance(obj, dict):
|
||||
if 'embedding' in obj:
|
||||
del obj['embedding']
|
||||
if "embedding" in obj:
|
||||
del obj["embedding"]
|
||||
removed_count += 1
|
||||
for value in obj.values():
|
||||
remove_embeddings(value)
|
||||
@@ -162,14 +162,14 @@ class EmbeddingCleaner:
|
||||
data = orjson.loads(original_content)
|
||||
except orjson.JSONDecodeError:
|
||||
# 回退到标准 json
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 分析 embedding 数据
|
||||
embedding_count = self.analyze_embedding_in_data(data)
|
||||
|
||||
if embedding_count == 0:
|
||||
logger.info(f" ✓ 文件中没有 embedding 数据,跳过")
|
||||
logger.info(" ✓ 文件中没有 embedding 数据,跳过")
|
||||
return True
|
||||
|
||||
logger.info(f" 发现 {embedding_count} 个 embedding 字段")
|
||||
@@ -193,30 +193,30 @@ class EmbeddingCleaner:
|
||||
cleaned_data,
|
||||
indent=2,
|
||||
ensure_ascii=False
|
||||
).encode('utf-8')
|
||||
).encode("utf-8")
|
||||
|
||||
cleaned_size = len(cleaned_content)
|
||||
bytes_saved = original_size - cleaned_size
|
||||
|
||||
# 原子写入
|
||||
temp_file = file_path.with_suffix('.tmp')
|
||||
temp_file = file_path.with_suffix(".tmp")
|
||||
temp_file.write_bytes(cleaned_content)
|
||||
temp_file.replace(file_path)
|
||||
|
||||
logger.info(f" ✓ 清理完成:")
|
||||
logger.info(" ✓ 清理完成:")
|
||||
logger.info(f" - 移除 embedding 字段: {removed_count}")
|
||||
logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)")
|
||||
logger.info(f" - 新文件大小: {cleaned_size:,} 字节")
|
||||
|
||||
# 更新统计
|
||||
self.stats['embedings_removed'] += removed_count
|
||||
self.stats['bytes_saved'] += bytes_saved
|
||||
self.stats["embedings_removed"] += removed_count
|
||||
self.stats["bytes_saved"] += bytes_saved
|
||||
|
||||
else:
|
||||
logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段")
|
||||
self.stats['embedings_removed'] += embedding_count
|
||||
self.stats["embedings_removed"] += embedding_count
|
||||
|
||||
self.stats['files_processed'] += 1
|
||||
self.stats["files_processed"] += 1
|
||||
self.cleaned_files.append(file_path)
|
||||
return True
|
||||
|
||||
@@ -236,12 +236,12 @@ class EmbeddingCleaner:
|
||||
节点数量
|
||||
"""
|
||||
try:
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
with open(file_path, encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
node_count = 0
|
||||
if 'nodes' in data and isinstance(data['nodes'], list):
|
||||
node_count = len(data['nodes'])
|
||||
if "nodes" in data and isinstance(data["nodes"], list):
|
||||
node_count = len(data["nodes"])
|
||||
|
||||
return node_count
|
||||
|
||||
@@ -268,7 +268,7 @@ class EmbeddingCleaner:
|
||||
|
||||
# 统计总节点数
|
||||
total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files)
|
||||
self.stats['nodes_processed'] = total_nodes
|
||||
self.stats["nodes_processed"] = total_nodes
|
||||
|
||||
logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点")
|
||||
|
||||
@@ -295,8 +295,8 @@ class EmbeddingCleaner:
|
||||
|
||||
if not dry_run:
|
||||
logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节")
|
||||
if self.stats['bytes_saved'] > 0:
|
||||
mb_saved = self.stats['bytes_saved'] / 1024 / 1024
|
||||
if self.stats["bytes_saved"] > 0:
|
||||
mb_saved = self.stats["bytes_saved"] / 1024 / 1024
|
||||
logger.info(f"节省空间: {mb_saved:.2f} MB")
|
||||
|
||||
if self.errors:
|
||||
@@ -342,7 +342,7 @@ def main():
|
||||
print(" 请确保向量数据库正在正常工作。")
|
||||
print()
|
||||
response = input("确认继续?(yes/no): ")
|
||||
if response.lower() not in ['yes', 'y', '是']:
|
||||
if response.lower() not in ["yes", "y", "是"]:
|
||||
print("操作已取消")
|
||||
return
|
||||
|
||||
@@ -352,4 +352,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user