# Mofox-Core/scripts/generate_missing_embeddings.py

"""
为现有节点生成嵌入向量
批量为图存储中缺少嵌入向量的节点生成并索引嵌入向量
使用场景:
1. 历史记忆节点没有嵌入向量
2. 嵌入生成器之前未配置,现在需要补充生成
3. 向量索引损坏需要重建
使用方法:
python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50]
参数说明:
--node-types: 需要生成嵌入的节点类型,默认为 TOPIC,OBJECT
--batch-size: 批量处理大小,默认为 50
"""
import asyncio
import sys
from pathlib import Path
# Add the project root to sys.path
sys.path.insert(0, str(Path(__file__).parent.parent))
async def generate_missing_embeddings(
target_node_types: list[str] | None = None,
batch_size: int = 50,
):
"""
为缺失嵌入向量的节点生成嵌入
Args:
target_node_types: 需要处理的节点类型列表(如 ["主题", "客体"]
batch_size: 批处理大小
"""
from src.common.logger import get_logger
from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager
from src.memory_graph.models import NodeType
logger = get_logger("generate_missing_embeddings")
if target_node_types is None:
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
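    # (NodeType values in this codebase are Chinese labels, so the defaults
    # above resolve to "主题" for TOPIC and "客体" for OBJECT.)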
print(f"\n{'='*80}")
print("🔧 为节点生成嵌入向量")
print(f"{'='*80}\n")
print(f"目标节点类型: {', '.join(target_node_types)}")
print(f"批处理大小: {batch_size}\n")
    # 1. Initialize the memory manager
    print("🔧 Initializing memory manager...")
    await initialize_memory_manager()
    manager = get_memory_manager()
    if manager is None:
        print("❌ Failed to initialize the memory manager")
        return
    print("✅ Memory manager initialized\n")
    # 2. Collect the IDs of nodes that are already indexed
    print("🔍 Checking the existing vector index...")
existing_node_ids = set()
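    # The lookup below assumes a ChromaDB-style collection API: count()
    # returns the number of stored vectors, and get(limit=..., offset=...)
    # pages through them, returning a dict that includes an "ids" list.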
try:
vector_count = manager.vector_store.collection.count()
if vector_count > 0:
            # Fetch all indexed IDs in pages
batch_size_check = 1000
for offset in range(0, vector_count, batch_size_check):
limit = min(batch_size_check, vector_count - offset)
result = manager.vector_store.collection.get(
limit=limit,
offset=offset,
)
if result and "ids" in result:
existing_node_ids.update(result["ids"])
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
except Exception as e:
        logger.warning(f"Failed to fetch indexed node IDs: {e}")
        print("⚠️ Could not read the existing index; duplicates will be skipped during insertion\n")
    # 3. Collect the nodes that need embeddings
    print("🔍 Scanning for nodes that need embeddings...")
all_memories = manager.graph_store.get_all_memories()
nodes_to_process = []
total_target_nodes = 0
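    # Per-type counters: nodes seen, nodes missing an embedding, and nodes
    # the vector index already contains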
type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types}
for memory in all_memories:
for node in memory.nodes:
if node.node_type.value in target_node_types:
total_target_nodes += 1
type_stats[node.node_type.value]["total"] += 1
                # Skip nodes that are already in the vector index
if node.id in existing_node_ids:
type_stats[node.node_type.value]["already_indexed"] += 1
continue
if not node.has_embedding():
nodes_to_process.append({
"node": node,
"memory_id": memory.id,
})
type_stats[node.node_type.value]["need_emb"] += 1
print("\n📊 扫描结果:")
for node_type in target_node_types:
stats = type_stats[node_type]
already_ok = stats["already_indexed"]
coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0
print(f" - {node_type}: {stats['total']} 个节点, {stats['need_emb']} 个缺失嵌入, "
f"{already_ok} 个已索引 (覆盖率: {coverage:.1f}%)")
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
if len(nodes_to_process) == 0:
print("✅ 所有节点已有嵌入向量,无需生成")
return
    # 4. Batch-generate embeddings
    print("🚀 Generating embeddings...\n")
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
success_count = 0
failed_count = 0
indexed_count = 0
for i in range(0, len(nodes_to_process), batch_size):
batch = nodes_to_process[i : i + batch_size]
batch_num = i // batch_size + 1
print(f"📦 批次 {batch_num}/{total_batches} ({len(batch)} 个节点)...")
try:
            # Extract the text content to embed
texts = [item["node"].content for item in batch]
            # Batch-generate embeddings
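            # (generate_batch is assumed to return one embedding per input
            # text, in order, with None for entries that failed to embed.)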
embeddings = await manager.embedding_generator.generate_batch(texts)
            # Attach embeddings to nodes and queue them for indexing
batch_nodes_for_index = []
            for item, embedding in zip(batch, embeddings):
node = item["node"]
if embedding is not None:
                    # Set the embedding vector
node.embedding = embedding
batch_nodes_for_index.append(node)
success_count += 1
else:
failed_count += 1
logger.warning(f" ⚠️ 节点 {node.id[:8]}... '{node.content[:30]}' 嵌入生成失败")
            # Batch-index into the vector store
if batch_nodes_for_index:
try:
await manager.vector_store.add_nodes_batch(batch_nodes_for_index)
indexed_count += len(batch_nodes_for_index)
print(f" ✅ 成功: {len(batch_nodes_for_index)}/{len(batch)} 个节点已生成并索引")
except Exception as e:
                    # If the batch insert fails, fall back to adding nodes
                    # one by one (skipping duplicates)
                    logger.warning(f"  Batch indexing failed, adding nodes individually: {e}")
individual_success = 0
for node in batch_nodes_for_index:
try:
await manager.vector_store.add_node(node)
individual_success += 1
indexed_count += 1
except Exception as e2:
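                            # (Matching on the error message is brittle, but
                            # this string is what a ChromaDB collection raises
                            # for duplicate IDs.)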
if "Expected IDs to be unique" in str(e2):
logger.debug(f" 跳过已存在节点: {node.id}")
else:
logger.error(f" 节点 {node.id} 索引失败: {e2}")
print(f" ⚠️ 逐个索引: {individual_success}/{len(batch_nodes_for_index)} 个成功")
        except Exception as e:
            failed_count += len(batch)
            logger.error(f"Batch {batch_num} failed: {e}")
            print(f"  ❌ Batch failed: {e}")
        # Report progress
total_processed = min(i + batch_size, len(nodes_to_process))
progress = total_processed / len(nodes_to_process) * 100
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
    # 5. Save the graph data (persists the updated node embedding fields)
    print("💾 Saving graph data...")
try:
await manager.persistence.save_graph_store(manager.graph_store)
print("✅ 图数据已保存\n")
except Exception as e:
logger.error("保存图数据失败")
print(f"❌ 保存失败: {e}\n")
    # 6. Verify the result
    print("🔍 Verifying the vector index...")
final_vector_count = manager.vector_store.collection.count()
stats = manager.graph_store.get_statistics()
total_nodes = stats["total_nodes"]
print(f"\n{'='*80}")
print("📊 生成完成")
print(f"{'='*80}")
print(f"处理节点数: {len(nodes_to_process)}")
print(f"成功生成: {success_count}")
print(f"失败数量: {failed_count}")
print(f"成功索引: {indexed_count}")
print(f"向量索引节点数: {final_vector_count}")
print(f"图存储节点数: {total_nodes}")
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
    # 7. Smoke-test search
    print("🧪 Testing search...")
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
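    # (These queries are sample entity names from this deployment's data;
    # substitute names that actually appear in your memory graph.)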
for query in test_queries:
results = await manager.search_memories(query=query, top_k=3)
if results:
print(f"\n✅ 查询 '{query}' 找到 {len(results)} 条记忆:")
for i, memory in enumerate(results[:2], 1):
subject_node = memory.get_subject_node()
                # Find topic nodes by scanning all nodes for the TOPIC type
                # (NodeType is already imported at the top of this function)
                topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC]
                subject = subject_node.content if subject_node else "?"
                topic = topic_nodes[0].content if topic_nodes else "?"
                print(f"  {i}. {subject} - {topic} (importance: {memory.importance:.2f})")
else:
print(f"\n⚠️ 查询 '{query}' 返回 0 条结果")
async def main():
import argparse
    parser = argparse.ArgumentParser(description="Generate embedding vectors for nodes")
parser.add_argument(
"--node-types",
type=str,
default="主题,客体",
help="需要生成嵌入的节点类型,逗号分隔(默认:主题,客体)",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="批处理大小默认50",
)
args = parser.parse_args()
target_types = [t.strip() for t in args.node_types.split(",")]
await generate_missing_embeddings(
target_node_types=target_types,
batch_size=args.batch_size,
)
if __name__ == "__main__":
asyncio.run(main())
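# Example invocations (run from the repository root):
#   python scripts/generate_missing_embeddings.py
#   python scripts/generate_missing_embeddings.py --node-types 主题 --batch-size 100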