""" 为现有节点生成嵌入向量 批量为图存储中缺少嵌入向量的节点生成并索引嵌入向量 使用场景: 1. 历史记忆节点没有嵌入向量 2. 嵌入生成器之前未配置,现在需要补充生成 3. 向量索引损坏需要重建 使用方法: python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50] 参数说明: --node-types: 需要生成嵌入的节点类型,默认为 TOPIC,OBJECT --batch-size: 批量处理大小,默认为 50 """ import asyncio import sys from pathlib import Path from typing import List # 添加项目根目录到路径 sys.path.insert(0, str(Path(__file__).parent.parent)) async def generate_missing_embeddings( target_node_types: List[str] = None, batch_size: int = 50, ): """ 为缺失嵌入向量的节点生成嵌入 Args: target_node_types: 需要处理的节点类型列表(如 ["主题", "客体"]) batch_size: 批处理大小 """ from src.common.logger import get_logger from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager from src.memory_graph.models import NodeType logger = get_logger("generate_missing_embeddings") if target_node_types is None: target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value] print(f"\n{'='*80}") print(f"🔧 为节点生成嵌入向量") print(f"{'='*80}\n") print(f"目标节点类型: {', '.join(target_node_types)}") print(f"批处理大小: {batch_size}\n") # 1. 初始化记忆管理器 print(f"🔧 正在初始化记忆管理器...") await initialize_memory_manager() manager = get_memory_manager() if manager is None: print("❌ 记忆管理器初始化失败") return print(f"✅ 记忆管理器已初始化\n") # 2. 获取已索引的节点ID print(f"🔍 检查现有向量索引...") existing_node_ids = set() try: vector_count = manager.vector_store.collection.count() if vector_count > 0: # 分批获取所有已索引的ID batch_size_check = 1000 for offset in range(0, vector_count, batch_size_check): limit = min(batch_size_check, vector_count - offset) result = manager.vector_store.collection.get( limit=limit, offset=offset, ) if result and "ids" in result: existing_node_ids.update(result["ids"]) print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n") except Exception as e: logger.warning(f"获取已索引节点ID失败: {e}") print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n") # 3. 收集需要生成嵌入的节点 print(f"🔍 扫描需要生成嵌入的节点...") all_memories = manager.graph_store.get_all_memories() nodes_to_process = [] total_target_nodes = 0 type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types} for memory in all_memories: for node in memory.nodes: if node.node_type.value in target_node_types: total_target_nodes += 1 type_stats[node.node_type.value]["total"] += 1 # 检查是否已在向量索引中 if node.id in existing_node_ids: type_stats[node.node_type.value]["already_indexed"] += 1 continue if not node.has_embedding(): nodes_to_process.append({ "node": node, "memory_id": memory.id, }) type_stats[node.node_type.value]["need_emb"] += 1 print(f"\n📊 扫描结果:") for node_type in target_node_types: stats = type_stats[node_type] already_ok = stats["already_indexed"] coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0 print(f" - {node_type}: {stats['total']} 个节点, {stats['need_emb']} 个缺失嵌入, " f"{already_ok} 个已索引 (覆盖率: {coverage:.1f}%)") print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n") if len(nodes_to_process) == 0: print(f"✅ 所有节点已有嵌入向量,无需生成") return # 3. 
    # 4. Generate embeddings in batches
    print("🚀 Generating embeddings...\n")
    total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
    success_count = 0
    failed_count = 0
    indexed_count = 0

    for i in range(0, len(nodes_to_process), batch_size):
        batch = nodes_to_process[i : i + batch_size]
        batch_num = i // batch_size + 1
        print(f"📦 Batch {batch_num}/{total_batches} ({len(batch)} nodes)...")

        try:
            # Extract the text content
            texts = [item["node"].content for item in batch]

            # Generate embeddings for the whole batch
            embeddings = await manager.embedding_generator.generate_batch(texts)

            # Attach the embeddings to the nodes and collect them for indexing
            batch_nodes_for_index = []
            for item, embedding in zip(batch, embeddings):
                node = item["node"]
                if embedding is not None:
                    node.embedding = embedding
                    batch_nodes_for_index.append(node)
                    success_count += 1
                else:
                    failed_count += 1
                    logger.warning(f"  ⚠️ Embedding generation failed for node {node.id[:8]}... '{node.content[:30]}'")

            # Index the batch into the vector store
            if batch_nodes_for_index:
                try:
                    await manager.vector_store.add_nodes_batch(batch_nodes_for_index)
                    indexed_count += len(batch_nodes_for_index)
                    print(f"  ✅ Success: {len(batch_nodes_for_index)}/{len(batch)} nodes generated and indexed")
                except Exception as e:
                    # If the batch insert fails, fall back to per-node inserts (skipping duplicates)
                    logger.warning(f"  Batch indexing failed, adding nodes one by one: {e}")
                    individual_success = 0
                    for node in batch_nodes_for_index:
                        try:
                            await manager.vector_store.add_node(node)
                            individual_success += 1
                            indexed_count += 1
                        except Exception as e2:
                            if "Expected IDs to be unique" in str(e2):
                                logger.debug(f"  Skipping node that already exists: {node.id}")
                            else:
                                logger.error(f"  Failed to index node {node.id}: {e2}")
                    print(f"  ⚠️ Per-node indexing: {individual_success}/{len(batch_nodes_for_index)} succeeded")

        except Exception as e:
            failed_count += len(batch)
            logger.error(f"Batch {batch_num} failed", exc_info=True)
            print(f"  ❌ Batch failed: {e}")

        # Report progress
        total_processed = min(i + batch_size, len(nodes_to_process))
        progress = total_processed / len(nodes_to_process) * 100
        print(f"  📊 Overall progress: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")

    # 5. Persist the graph data (the nodes' embedding fields were updated in place)
    print("💾 Saving graph data...")
    try:
        await manager.persistence.save_graph_store(manager.graph_store)
        print("✅ Graph data saved\n")
    except Exception as e:
        logger.error("Failed to save graph data", exc_info=True)
        print(f"❌ Save failed: {e}\n")

    # 6. Verify the result
    print("🔍 Verifying the vector index...")
    final_vector_count = manager.vector_store.collection.count()
    stats = manager.graph_store.get_statistics()
    total_nodes = stats["total_nodes"]

    print(f"\n{'='*80}")
    print("📊 Generation complete")
    print(f"{'='*80}")
    print(f"Nodes processed: {len(nodes_to_process)}")
    print(f"Generated: {success_count}")
    print(f"Failed: {failed_count}")
    print(f"Indexed: {indexed_count}")
    print(f"Vector index size: {final_vector_count}")
    print(f"Graph store nodes: {total_nodes}")
    index_coverage = final_vector_count / total_nodes * 100 if total_nodes > 0 else 0
    print(f"Index coverage: {index_coverage:.1f}%\n")
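    # NOTE: the index-coverage figure above compares the vector index against
    # *all* graph nodes, while only the target node types are embedded, so a
    # value below 100% is expected whenever other node types exist in the graph.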
    # 7. Smoke-test the search path (the queries are proper nouns from this
    # project's own dataset; substitute terms that exist in your graph)
    print("🧪 Testing search...")
    test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]

    for query in test_queries:
        results = await manager.search_memories(query=query, top_k=3)
        if results:
            print(f"\n✅ Query '{query}' matched {len(results)} memories:")
            for i, memory in enumerate(results[:2], 1):
                subject_node = memory.get_subject_node()
                # Find the topic node by scanning all nodes for the TOPIC type
                topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC]
                subject = subject_node.content if subject_node else "?"
                topic = topic_nodes[0].content if topic_nodes else "?"
                print(f"  {i}. {subject} - {topic} (importance: {memory.importance:.2f})")
        else:
            print(f"\n⚠️ Query '{query}' returned 0 results")


async def main():
    import argparse

    parser = argparse.ArgumentParser(description="Generate embeddings for nodes")
    parser.add_argument(
        "--node-types",
        type=str,
        default="主题,客体",
        help="Node types to embed, comma-separated (default: 主题,客体)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="Batch size (default: 50)",
    )

    args = parser.parse_args()
    target_types = [t.strip() for t in args.node_types.split(",")]

    await generate_missing_embeddings(
        target_node_types=target_types,
        batch_size=args.batch_size,
    )


if __name__ == "__main__":
    asyncio.run(main())
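
# A minimal programmatic invocation sketch (an assumption, not part of the CLI:
# it requires scripts/ to be importable, e.g. when run from the project root):
#
#   import asyncio
#   from scripts.generate_missing_embeddings import generate_missing_embeddings
#
#   asyncio.run(generate_missing_embeddings(target_node_types=["主题"], batch_size=100))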