diff --git a/clean_embedding_data.py b/clean_embedding_data.py new file mode 100644 index 000000000..1cac5c870 --- /dev/null +++ b/clean_embedding_data.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +""" +清理记忆数据中的向量数据 + +此脚本用于清理现有 JSON 文件中的 embedding 字段,确保向量数据只存储在专门的向量数据库中。 +这样可以: +1. 减少 JSON 文件大小 +2. 提高读写性能 +3. 避免数据冗余 +4. 确保数据一致性 + +使用方法: + python clean_embedding_data.py [--dry-run] + + --dry-run: 仅显示将要清理的统计信息,不实际修改文件 +""" + +import argparse +import json +from pathlib import Path +from typing import Dict, Any, List, Tuple +import logging + +import orjson + +# 配置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), + logging.FileHandler('embedding_cleanup.log', encoding='utf-8') + ] +) +logger = logging.getLogger(__name__) + + +class EmbeddingCleaner: + """向量数据清理器""" + + def __init__(self, data_dir: str = "data"): + """ + 初始化清理器 + + Args: + data_dir: 数据目录路径 + """ + self.data_dir = Path(data_dir) + self.cleaned_files = [] + self.errors = [] + self.stats = { + 'files_processed': 0, + 'embedings_removed': 0, + 'bytes_saved': 0, + 'nodes_processed': 0 + } + + def find_json_files(self) -> List[Path]: + """查找可能包含向量数据的 JSON 文件""" + json_files = [] + + # 记忆图数据文件 + memory_graph_file = self.data_dir / "memory_graph" / "memory_graph.json" + if memory_graph_file.exists(): + json_files.append(memory_graph_file) + + # 测试数据文件 + test_dir = self.data_dir / "test_*" + for test_path in self.data_dir.glob("test_*/memory_graph.json"): + if test_path.exists(): + json_files.append(test_path) + + # 其他可能的记忆相关文件 + potential_files = [ + self.data_dir / "memory_metadata_index.json", + ] + + for file_path in potential_files: + if file_path.exists(): + json_files.append(file_path) + + logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件") + return json_files + + def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int: + """ + 分析数据中的 embedding 字段数量 + + Args: + data: 要分析的数据 + + Returns: + embedding 字段的数量 + """ + embedding_count = 0 + + def count_embeddings(obj): + nonlocal embedding_count + if isinstance(obj, dict): + if 'embedding' in obj: + embedding_count += 1 + for value in obj.values(): + count_embeddings(value) + elif isinstance(obj, list): + for item in obj: + count_embeddings(item) + + count_embeddings(data) + return embedding_count + + def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]: + """ + 从数据中移除 embedding 字段 + + Args: + data: 要清理的数据 + + Returns: + (清理后的数据, 移除的 embedding 数量) + """ + removed_count = 0 + + def remove_embeddings(obj): + nonlocal removed_count + if isinstance(obj, dict): + if 'embedding' in obj: + del obj['embedding'] + removed_count += 1 + for value in obj.values(): + remove_embeddings(value) + elif isinstance(obj, list): + for item in obj: + remove_embeddings(item) + + # 创建深拷贝以避免修改原数据 + import copy + cleaned_data = copy.deepcopy(data) + remove_embeddings(cleaned_data) + + return cleaned_data, removed_count + + def process_file(self, file_path: Path, dry_run: bool = False) -> bool: + """ + 处理单个文件 + + Args: + file_path: 文件路径 + dry_run: 是否为试运行模式 + + Returns: + 是否处理成功 + """ + try: + logger.info(f"处理文件: {file_path}") + + # 读取原文件 + original_content = file_path.read_bytes() + original_size = len(original_content) + + # 解析 JSON 数据 + try: + data = orjson.loads(original_content) + except orjson.JSONDecodeError: + # 回退到标准 json + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + # 分析 embedding 数据 + embedding_count = self.analyze_embedding_in_data(data) + + if embedding_count == 0: + logger.info(f" ✓ 文件中没有 embedding 数据,跳过") + return True + + logger.info(f" 发现 {embedding_count} 个 embedding 字段") + + if not dry_run: + # 清理 embedding 数据 + cleaned_data, removed_count = self.clean_embedding_from_data(data) + + if removed_count != embedding_count: + logger.warning(f" ⚠️ 清理数量不一致: 分析发现 {embedding_count}, 实际清理 {removed_count}") + + # 序列化清理后的数据 + try: + cleaned_content = orjson.dumps( + cleaned_data, + option=orjson.OPT_INDENT_2 | orjson.OPT_SERIALIZE_NUMPY + ) + except Exception: + # 回退到标准 json + cleaned_content = json.dumps( + cleaned_data, + indent=2, + ensure_ascii=False + ).encode('utf-8') + + cleaned_size = len(cleaned_content) + bytes_saved = original_size - cleaned_size + + # 原子写入 + temp_file = file_path.with_suffix('.tmp') + temp_file.write_bytes(cleaned_content) + temp_file.replace(file_path) + + logger.info(f" ✓ 清理完成:") + logger.info(f" - 移除 embedding 字段: {removed_count}") + logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)") + logger.info(f" - 新文件大小: {cleaned_size:,} 字节") + + # 更新统计 + self.stats['embedings_removed'] += removed_count + self.stats['bytes_saved'] += bytes_saved + + else: + logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段") + self.stats['embedings_removed'] += embedding_count + + self.stats['files_processed'] += 1 + self.cleaned_files.append(file_path) + return True + + except Exception as e: + logger.error(f" ❌ 处理失败: {e}") + self.errors.append((str(file_path), str(e))) + return False + + def analyze_nodes_in_file(self, file_path: Path) -> int: + """ + 分析文件中的节点数量 + + Args: + file_path: 文件路径 + + Returns: + 节点数量 + """ + try: + with open(file_path, 'r', encoding='utf-8') as f: + data = json.load(f) + + node_count = 0 + if 'nodes' in data and isinstance(data['nodes'], list): + node_count = len(data['nodes']) + + return node_count + + except Exception as e: + logger.warning(f"分析节点数量失败: {e}") + return 0 + + def run(self, dry_run: bool = False): + """ + 运行清理过程 + + Args: + dry_run: 是否为试运行模式 + """ + logger.info("开始向量数据清理") + logger.info(f"模式: {'试运行' if dry_run else '正式执行'}") + + # 查找要处理的文件 + json_files = self.find_json_files() + + if not json_files: + logger.info("没有找到需要处理的文件") + return + + # 统计总节点数 + total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files) + self.stats['nodes_processed'] = total_nodes + + logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点") + + # 处理每个文件 + success_count = 0 + for file_path in json_files: + if self.process_file(file_path, dry_run): + success_count += 1 + + # 输出统计信息 + self.print_summary(dry_run, success_count, len(json_files)) + + def print_summary(self, dry_run: bool, success_count: int, total_files: int): + """打印清理摘要""" + logger.info("=" * 60) + logger.info("清理摘要") + logger.info("=" * 60) + + mode = "试运行" if dry_run else "正式执行" + logger.info(f"执行模式: {mode}") + logger.info(f"处理文件: {success_count}/{total_files}") + logger.info(f"处理节点: {self.stats['nodes_processed']}") + logger.info(f"清理 embedding 字段: {self.stats['embedings_removed']}") + + if not dry_run: + logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节") + if self.stats['bytes_saved'] > 0: + mb_saved = self.stats['bytes_saved'] / 1024 / 1024 + logger.info(f"节省空间: {mb_saved:.2f} MB") + + if self.errors: + logger.warning(f"遇到 {len(self.errors)} 个错误:") + for file_path, error in self.errors: + logger.warning(f" {file_path}: {error}") + + if success_count == total_files and not self.errors: + logger.info("所有文件处理成功!") + + logger.info("=" * 60) + + +def main(): + """主函数""" + parser = argparse.ArgumentParser( + description="清理记忆数据中的向量数据", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +示例用法: + python clean_embedding_data.py --dry-run # 试运行,查看统计信息 + python clean_embedding_data.py # 正式执行清理 + """ + ) + + parser.add_argument( + "--dry-run", + action="store_true", + help="试运行模式,不实际修改文件" + ) + + parser.add_argument( + "--data-dir", + default="data", + help="数据目录路径 (默认: data)" + ) + + args = parser.parse_args() + + # 确认操作 + if not args.dry_run: + print("警告:此操作将永久删除 JSON 文件中的 embedding 数据!") + print(" 请确保向量数据库正在正常工作。") + print() + response = input("确认继续?(yes/no): ") + if response.lower() not in ['yes', 'y', '是']: + print("操作已取消") + return + + # 执行清理 + cleaner = EmbeddingCleaner(args.data_dir) + cleaner.run(dry_run=args.dry_run) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/memory_graph/core/builder.py b/src/memory_graph/core/builder.py index 5c93dd3dc..26a94d019 100644 --- a/src/memory_graph/core/builder.py +++ b/src/memory_graph/core/builder.py @@ -355,12 +355,12 @@ class MemoryBuilder: for node_id in self.graph_store.graph.nodes(): node_data = self.graph_store.graph.nodes[node_id] if node_data.get("content") == content and node_data.get("node_type") == node_type.value: - # 重建 MemoryNode 对象 + # 重建 MemoryNode 对象(embedding 数据从向量数据库单独获取) return MemoryNode( id=node_id, content=node_data["content"], node_type=NodeType(node_data["node_type"]), - embedding=node_data.get("embedding"), + embedding=None, # 图存储不包含 embedding,需要从向量数据库获取 metadata=node_data.get("metadata", {}), ) @@ -405,7 +405,7 @@ class MemoryBuilder: id=node_id, content=node_data["content"], node_type=NodeType(node_data["node_type"]), - embedding=node_data.get("embedding"), + embedding=None, # 图存储不包含 embedding,需要从向量数据库获取 metadata=node_data.get("metadata", {}), ) # 添加当前记忆ID到元数据 @@ -455,7 +455,7 @@ class MemoryBuilder: id=node_id, content=node_data["content"], node_type=NodeType(node_data["node_type"]), - embedding=node_data.get("embedding"), + embedding=None, # 图存储不包含 embedding,需要从向量数据库获取 metadata=node_data.get("metadata", {}), ) return existing_node diff --git a/src/memory_graph/models.py b/src/memory_graph/models.py index c01716aff..6f7f1109d 100644 --- a/src/memory_graph/models.py +++ b/src/memory_graph/models.py @@ -71,11 +71,11 @@ class MemoryNode: def to_dict(self) -> dict[str, Any]: """转换为字典(用于序列化)""" + # 不序列化 embedding 数据,向量数据由专门的向量数据库管理 return { "id": self.id, "content": self.content, "node_type": self.node_type.value, - "embedding": self.embedding.tolist() if self.embedding is not None else None, "metadata": self.metadata, "created_at": self.created_at.isoformat(), } @@ -83,15 +83,12 @@ class MemoryNode: @classmethod def from_dict(cls, data: dict[str, Any]) -> MemoryNode: """从字典创建节点""" - embedding = None - if data.get("embedding") is not None: - embedding = np.array(data["embedding"]) - + # 不从 JSON 中读取 embedding 数据,向量数据由专门的向量数据库管理 return cls( id=data["id"], content=data["content"], node_type=NodeType(data["node_type"]), - embedding=embedding, + embedding=None, # 向量数据需要从向量数据库中单独加载 metadata=data.get("metadata", {}), created_at=datetime.fromisoformat(data["created_at"]), )