This commit is contained in:
明天好像没什么
2025-11-07 21:01:45 +08:00
parent 80b040da2f
commit c8d7c09625
49 changed files with 854 additions and 872 deletions

View File

@@ -17,19 +17,19 @@
import argparse
import json
from pathlib import Path
from typing import Dict, Any, List, Tuple
import logging
from pathlib import Path
from typing import Any
import orjson
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.StreamHandler(),
logging.FileHandler('embedding_cleanup.log', encoding='utf-8')
logging.FileHandler("embedding_cleanup.log", encoding="utf-8")
]
)
logger = logging.getLogger(__name__)
@@ -49,13 +49,13 @@ class EmbeddingCleaner:
self.cleaned_files = []
self.errors = []
self.stats = {
'files_processed': 0,
'embedings_removed': 0,
'bytes_saved': 0,
'nodes_processed': 0
"files_processed": 0,
"embedings_removed": 0,
"bytes_saved": 0,
"nodes_processed": 0
}
def find_json_files(self) -> List[Path]:
def find_json_files(self) -> list[Path]:
"""查找可能包含向量数据的 JSON 文件"""
json_files = []
@@ -65,7 +65,7 @@ class EmbeddingCleaner:
json_files.append(memory_graph_file)
# 测试数据文件
test_dir = self.data_dir / "test_*"
self.data_dir / "test_*"
for test_path in self.data_dir.glob("test_*/memory_graph.json"):
if test_path.exists():
json_files.append(test_path)
@@ -82,7 +82,7 @@ class EmbeddingCleaner:
logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件")
return json_files
def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int:
def analyze_embedding_in_data(self, data: dict[str, Any]) -> int:
"""
分析数据中的 embedding 字段数量
@@ -97,7 +97,7 @@ class EmbeddingCleaner:
def count_embeddings(obj):
nonlocal embedding_count
if isinstance(obj, dict):
if 'embedding' in obj:
if "embedding" in obj:
embedding_count += 1
for value in obj.values():
count_embeddings(value)
@@ -108,7 +108,7 @@ class EmbeddingCleaner:
count_embeddings(data)
return embedding_count
def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
def clean_embedding_from_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
"""
从数据中移除 embedding 字段
@@ -123,8 +123,8 @@ class EmbeddingCleaner:
def remove_embeddings(obj):
nonlocal removed_count
if isinstance(obj, dict):
if 'embedding' in obj:
del obj['embedding']
if "embedding" in obj:
del obj["embedding"]
removed_count += 1
for value in obj.values():
remove_embeddings(value)
@@ -162,14 +162,14 @@ class EmbeddingCleaner:
data = orjson.loads(original_content)
except orjson.JSONDecodeError:
# 回退到标准 json
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, encoding="utf-8") as f:
data = json.load(f)
# 分析 embedding 数据
embedding_count = self.analyze_embedding_in_data(data)
if embedding_count == 0:
logger.info(f" ✓ 文件中没有 embedding 数据,跳过")
logger.info(" ✓ 文件中没有 embedding 数据,跳过")
return True
logger.info(f" 发现 {embedding_count} 个 embedding 字段")
@@ -193,30 +193,30 @@ class EmbeddingCleaner:
cleaned_data,
indent=2,
ensure_ascii=False
).encode('utf-8')
).encode("utf-8")
cleaned_size = len(cleaned_content)
bytes_saved = original_size - cleaned_size
# 原子写入
temp_file = file_path.with_suffix('.tmp')
temp_file = file_path.with_suffix(".tmp")
temp_file.write_bytes(cleaned_content)
temp_file.replace(file_path)
logger.info(f" ✓ 清理完成:")
logger.info(" ✓ 清理完成:")
logger.info(f" - 移除 embedding 字段: {removed_count}")
logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)")
logger.info(f" - 新文件大小: {cleaned_size:,} 字节")
# 更新统计
self.stats['embedings_removed'] += removed_count
self.stats['bytes_saved'] += bytes_saved
self.stats["embedings_removed"] += removed_count
self.stats["bytes_saved"] += bytes_saved
else:
logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段")
self.stats['embedings_removed'] += embedding_count
self.stats["embedings_removed"] += embedding_count
self.stats['files_processed'] += 1
self.stats["files_processed"] += 1
self.cleaned_files.append(file_path)
return True
@@ -236,12 +236,12 @@ class EmbeddingCleaner:
节点数量
"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
with open(file_path, encoding="utf-8") as f:
data = json.load(f)
node_count = 0
if 'nodes' in data and isinstance(data['nodes'], list):
node_count = len(data['nodes'])
if "nodes" in data and isinstance(data["nodes"], list):
node_count = len(data["nodes"])
return node_count
@@ -268,7 +268,7 @@ class EmbeddingCleaner:
# 统计总节点数
total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files)
self.stats['nodes_processed'] = total_nodes
self.stats["nodes_processed"] = total_nodes
logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点")
@@ -295,8 +295,8 @@ class EmbeddingCleaner:
if not dry_run:
logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节")
if self.stats['bytes_saved'] > 0:
mb_saved = self.stats['bytes_saved'] / 1024 / 1024
if self.stats["bytes_saved"] > 0:
mb_saved = self.stats["bytes_saved"] / 1024 / 1024
logger.info(f"节省空间: {mb_saved:.2f} MB")
if self.errors:
@@ -342,7 +342,7 @@ def main():
print(" 请确保向量数据库正在正常工作。")
print()
response = input("确认继续?(yes/no): ")
if response.lower() not in ['yes', 'y', '']:
if response.lower() not in ["yes", "y", ""]:
print("操作已取消")
return
@@ -352,4 +352,4 @@ def main():
if __name__ == "__main__":
main()
main()