style: 统一代码风格并采用现代化类型注解
对整个代码库进行了一次全面的代码风格清理和现代化改造,主要包括: - 移除了所有文件中多余的行尾空格。 - 将类型提示更新为 PEP 585 和 PEP 604 引入的现代语法(例如,使用 `list` 代替 `List`,使用 `|` 代替 `Optional`)。 - 清理了多个模块中未被使用的导入语句。 - 移除了不含插值变量的冗余 f-string。 - 调整了部分 `__init__.py` 文件中的 `__all__` 导出顺序,以保持一致性。 这些改动旨在提升代码的可读性和可维护性,使其与现代 Python 最佳实践保持一致,但未修改任何核心逻辑。
This commit is contained in:
@@ -19,14 +19,13 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
|
||||
async def generate_missing_embeddings(
|
||||
target_node_types: List[str] = None,
|
||||
target_node_types: list[str] = None,
|
||||
batch_size: int = 50,
|
||||
):
|
||||
"""
|
||||
@@ -46,13 +45,13 @@ async def generate_missing_embeddings(
|
||||
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"🔧 为节点生成嵌入向量")
|
||||
print("🔧 为节点生成嵌入向量")
|
||||
print(f"{'='*80}\n")
|
||||
print(f"目标节点类型: {', '.join(target_node_types)}")
|
||||
print(f"批处理大小: {batch_size}\n")
|
||||
|
||||
# 1. 初始化记忆管理器
|
||||
print(f"🔧 正在初始化记忆管理器...")
|
||||
print("🔧 正在初始化记忆管理器...")
|
||||
await initialize_memory_manager()
|
||||
manager = get_memory_manager()
|
||||
|
||||
@@ -60,10 +59,10 @@ async def generate_missing_embeddings(
|
||||
print("❌ 记忆管理器初始化失败")
|
||||
return
|
||||
|
||||
print(f"✅ 记忆管理器已初始化\n")
|
||||
print("✅ 记忆管理器已初始化\n")
|
||||
|
||||
# 2. 获取已索引的节点ID
|
||||
print(f"🔍 检查现有向量索引...")
|
||||
print("🔍 检查现有向量索引...")
|
||||
existing_node_ids = set()
|
||||
try:
|
||||
vector_count = manager.vector_store.collection.count()
|
||||
@@ -78,14 +77,14 @@ async def generate_missing_embeddings(
|
||||
)
|
||||
if result and "ids" in result:
|
||||
existing_node_ids.update(result["ids"])
|
||||
|
||||
|
||||
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取已索引节点ID失败: {e}")
|
||||
print(f"⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
||||
print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
|
||||
|
||||
# 3. 收集需要生成嵌入的节点
|
||||
print(f"🔍 扫描需要生成嵌入的节点...")
|
||||
print("🔍 扫描需要生成嵌入的节点...")
|
||||
all_memories = manager.graph_store.get_all_memories()
|
||||
|
||||
nodes_to_process = []
|
||||
@@ -110,7 +109,7 @@ async def generate_missing_embeddings(
|
||||
})
|
||||
type_stats[node.node_type.value]["need_emb"] += 1
|
||||
|
||||
print(f"\n📊 扫描结果:")
|
||||
print("\n📊 扫描结果:")
|
||||
for node_type in target_node_types:
|
||||
stats = type_stats[node_type]
|
||||
already_ok = stats["already_indexed"]
|
||||
@@ -121,11 +120,11 @@ async def generate_missing_embeddings(
|
||||
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
|
||||
|
||||
if len(nodes_to_process) == 0:
|
||||
print(f"✅ 所有节点已有嵌入向量,无需生成")
|
||||
print("✅ 所有节点已有嵌入向量,无需生成")
|
||||
return
|
||||
|
||||
# 3. 批量生成嵌入
|
||||
print(f"🚀 开始生成嵌入向量...\n")
|
||||
print("🚀 开始生成嵌入向量...\n")
|
||||
|
||||
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
|
||||
success_count = 0
|
||||
@@ -193,22 +192,22 @@ async def generate_missing_embeddings(
|
||||
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
|
||||
|
||||
# 4. 保存图数据(更新节点的 embedding 字段)
|
||||
print(f"💾 保存图数据...")
|
||||
print("💾 保存图数据...")
|
||||
try:
|
||||
await manager.persistence.save_graph_store(manager.graph_store)
|
||||
print(f"✅ 图数据已保存\n")
|
||||
print("✅ 图数据已保存\n")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图数据失败", exc_info=True)
|
||||
logger.error("保存图数据失败", exc_info=True)
|
||||
print(f"❌ 保存失败: {e}\n")
|
||||
|
||||
# 5. 验证结果
|
||||
print(f"🔍 验证向量索引...")
|
||||
print("🔍 验证向量索引...")
|
||||
final_vector_count = manager.vector_store.collection.count()
|
||||
stats = manager.graph_store.get_statistics()
|
||||
total_nodes = stats["total_nodes"]
|
||||
|
||||
print(f"\n{'='*80}")
|
||||
print(f"📊 生成完成")
|
||||
print("📊 生成完成")
|
||||
print(f"{'='*80}")
|
||||
print(f"处理节点数: {len(nodes_to_process)}")
|
||||
print(f"成功生成: {success_count}")
|
||||
@@ -219,7 +218,7 @@ async def generate_missing_embeddings(
|
||||
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
|
||||
|
||||
# 6. 测试搜索
|
||||
print(f"🧪 测试搜索功能...")
|
||||
print("🧪 测试搜索功能...")
|
||||
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
|
||||
|
||||
for query in test_queries:
|
||||
|
||||
@@ -38,7 +38,7 @@ OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
|
||||
|
||||
# ========== 性能配置参数 ==========
|
||||
#
|
||||
#
|
||||
# 知识提取(步骤2:txt转json)并发控制
|
||||
# - 控制同时进行的LLM提取请求数量
|
||||
# - 推荐值: 3-10,取决于API速率限制
|
||||
@@ -184,7 +184,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
||||
tuple: (doc_item或None, failed_hash或None)
|
||||
"""
|
||||
temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
|
||||
|
||||
|
||||
# 🔧 优化:使用异步文件检查,避免阻塞
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
@@ -215,11 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
||||
"extracted_entities": extracted_data.get("entities", []),
|
||||
"extracted_triples": extracted_data.get("triples", []),
|
||||
}
|
||||
|
||||
|
||||
# 保存到缓存(异步写入)
|
||||
async with aiofiles.open(temp_file_path, "wb") as f:
|
||||
await f.write(orjson.dumps(doc_item))
|
||||
|
||||
|
||||
return doc_item, None
|
||||
except Exception as e:
|
||||
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
|
||||
@@ -249,13 +249,13 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
failed_hashes, open_ie_docs = [], []
|
||||
|
||||
|
||||
# 🔧 关键修复:创建单个 LLM 请求实例,复用连接
|
||||
llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
|
||||
|
||||
|
||||
# 🔧 并发控制:限制最大并发数,防止速率限制
|
||||
semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
|
||||
|
||||
|
||||
async def extract_with_semaphore(pg_hash, paragraph):
|
||||
"""带信号量控制的提取函数"""
|
||||
async with semaphore:
|
||||
@@ -266,7 +266,7 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
extract_with_semaphore(p_hash, paragraph)
|
||||
for p_hash, paragraph in paragraphs_dict.items()
|
||||
]
|
||||
|
||||
|
||||
total = len(tasks)
|
||||
completed = 0
|
||||
|
||||
@@ -284,7 +284,7 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
TimeRemainingColumn(),
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]正在提取信息...", total=total)
|
||||
|
||||
|
||||
# 🔧 优化:使用 asyncio.gather 并发执行所有任务
|
||||
# return_exceptions=True 确保单个失败不影响其他任务
|
||||
for coro in asyncio.as_completed(tasks):
|
||||
@@ -293,7 +293,7 @@ async def extract_information(paragraphs_dict, model_set):
|
||||
failed_hashes.append(failed_hash)
|
||||
elif doc_item:
|
||||
open_ie_docs.append(doc_item)
|
||||
|
||||
|
||||
completed += 1
|
||||
progress.update(task, advance=1)
|
||||
|
||||
@@ -415,7 +415,7 @@ def rebuild_faiss_only():
|
||||
logger.info("--- 重建 FAISS 索引 ---")
|
||||
# 重建索引不需要并发参数(不涉及 embedding 生成)
|
||||
embed_manager = EmbeddingManager()
|
||||
|
||||
|
||||
logger.info("正在加载现有的 Embedding 库...")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
|
||||
Reference in New Issue
Block a user