SengokuCola
2025-05-05 22:42:20 +08:00
4 changed files with 65 additions and 15 deletions

View File

@@ -85,6 +85,7 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
     logger.error("系统将于2秒后开始检查数据完整性")
     sleep(2)
     found_missing = False
+    missing_idxs = []
     for doc in getattr(openie_data, "docs", []):
         idx = doc.get("idx", "<无idx>")
         passage = doc.get("passage", "<无passage>")
@@ -104,14 +105,38 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
         # print(f"检查: idx={idx}")
         if missing:
             found_missing = True
+            missing_idxs.append(idx)
             logger.error("\n")
             logger.error("数据缺失:")
             logger.error(f"对应哈希值:{idx}")
             logger.error(f"对应文段内容:{passage}")
             logger.error(f"非法原因:{', '.join(missing)}")
+    # Make sure the summary is printed only after all invalid-data output
     if not found_missing:
-        print("所有数据均完整,没有发现缺失字段。")
+        logger.info("所有数据均完整,没有发现缺失字段。")
         return False
+    # New: ask the user whether to drop the invalid passages and continue importing.
+    # The prompt goes after all logger.error calls so it is not drowned out.
+    logger.info("\n检测到非法文段,共{}条。".format(len(missing_idxs)))
+    logger.info("是否删除所有非法文段后继续导入?(y/n): ")
+    user_choice = input().strip().lower()
+    if user_choice != "y":
+        logger.info("用户选择不删除非法文段,程序终止。")
+        sys.exit(1)
+    # Drop the invalid passages
+    logger.info("正在删除非法文段并继续导入...")
+    # Filter out the invalid passages
+    openie_data.docs = [
+        doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs
+    ]
+    # Re-extract the data
+    raw_paragraphs = openie_data.extract_raw_paragraph_dict()
+    entity_list_data = openie_data.extract_entity_dict()
+    triple_list_data = openie_data.extract_triple_dict()
+    # Validate again
+    if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
+        logger.error("删除非法文段后,数据仍不一致,程序终止。")
+        sys.exit(1)
     # Replace each index with the hash of the corresponding paragraph
     logger.info("正在进行段落去重与重索引")
     raw_paragraphs, triple_list_data = hash_deduplicate(
@@ -179,6 +204,7 @@ def main():
         logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
         # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
         sys.exit(1)
+    if "不存在" in str(e):
         logger.error("如果你是第一次导入知识,请忽略此错误")
     logger.info("Embedding库加载完成")
     # Initialize the KG
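Taken together, the new logic in handle_import_openie is a collect-confirm-filter pass: record the idx of every invalid doc, ask the user once after all errors are printed, then drop the offenders and re-validate. A minimal standalone sketch of that pattern (the required-field names here are illustrative placeholders, not the project's actual schema):

import sys

REQUIRED_FIELDS = ("idx", "passage", "extracted_entities", "extracted_triples")  # hypothetical field names

def filter_invalid_docs(docs: list) -> list:
    """Collect the idx of every doc missing required fields, then drop them after user confirmation."""
    missing_idxs = []
    for doc in docs:
        missing = [field for field in REQUIRED_FIELDS if field not in doc]
        if missing:
            missing_idxs.append(doc.get("idx", "<no idx>"))
            print(f"invalid doc {doc.get('idx')}: missing {', '.join(missing)}")
    if not missing_idxs:
        return docs
    choice = input(f"Drop {len(missing_idxs)} invalid docs and continue? (y/n): ").strip().lower()
    if choice != "y":
        sys.exit(1)
    return [doc for doc in docs if doc.get("idx") not in missing_idxs]

Collecting idxs first and filtering in a single pass afterwards keeps the error report complete before the destructive step runs.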

View File

@@ -5,12 +5,21 @@ import sys  # new: system module import
 import datetime  # new import
 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
-from src.common.logger import get_module_logger
+from src.common.logger_manager import get_logger
+from src.plugins.knowledge.src.lpmmconfig import global_config

-logger = get_module_logger("LPMM数据库-原始数据处理")
+logger = get_logger("lpmm")

 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
-IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
+# New: make sure RAW_DATA_PATH exists
+if not os.path.exists(RAW_DATA_PATH):
+    os.makedirs(RAW_DATA_PATH, exist_ok=True)
+    logger.info(f"已创建目录: {RAW_DATA_PATH}")
+if global_config.get("persistence", {}).get("raw_data_path") is not None:
+    IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
+else:
+    IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
 # Add the project root to sys.path
@@ -54,7 +63,7 @@ def main():
     print("请确保原始数据已放置在正确的目录中。")
     confirm = input("确认继续执行?(y/n): ").strip().lower()
     if confirm != "y":
-        logger.error("操作已取消")
+        logger.info("操作已取消")
         sys.exit(1)
     print("\n" + "=" * 40 + "\n")
@@ -94,6 +103,6 @@ def main():
 if __name__ == "__main__":
-    print(f"Raw Data Path: {RAW_DATA_PATH}")
-    print(f"Imported Data Path: {IMPORTED_DATA_PATH}")
+    logger.info(f"原始数据路径: {RAW_DATA_PATH}")
+    logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
     main()
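The path handling above is a config-override-with-default pattern: a value from the persistence section wins when present, otherwise the hard-coded location is used. A minimal sketch, assuming global_config behaves like a nested dict as the diff suggests:

import os

ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))

def resolve_imported_data_path(global_config: dict) -> str:
    """Prefer persistence.raw_data_path from the config; fall back to the default location."""
    configured = global_config.get("persistence", {}).get("raw_data_path")
    if configured is not None:
        return os.path.join(ROOT_PATH, configured)
    return os.path.join(ROOT_PATH, "data/imported_lpmm_data")

Chaining .get("persistence", {}) before the inner .get avoids a KeyError when the whole section is absent from the config file.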

View File

@@ -6,7 +6,8 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pandas as pd
-import tqdm
+# import tqdm
 import faiss
+from rich.progress import BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
 from .llm_client import LLMClient
@@ -194,11 +195,25 @@ class EmbeddingStore:
         """Load the store from file"""
         if not os.path.exists(self.embedding_file_path):
             raise Exception(f"文件{self.embedding_file_path}不存在")
         logger.info(f"正在从文件{self.embedding_file_path}中加载{self.namespace}嵌入库")
         data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
-        for _, row in tqdm.tqdm(data_frame.iterrows(), total=len(data_frame)):
-            self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
+        total = len(data_frame)
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task("加载嵌入库", total=total)
+            for _, row in data_frame.iterrows():
+                self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
+                progress.update(task, advance=1)
         logger.info(f"{self.namespace}嵌入库加载成功")
         try:
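The tqdm-to-rich migration above follows a standard pattern: wrap the loop in a Progress context, register a task with a known total, and advance it once per row; plain strings between columns (like "<") render as literal separators, mirroring tqdm's elapsed<remaining readout. A self-contained sketch of the same pattern (the DataFrame columns are stand-ins for the store's real schema):

import pandas as pd
from rich.progress import (
    BarColumn, MofNCompleteColumn, Progress, SpinnerColumn,
    TaskProgressColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn,
)

def load_rows(data_frame: pd.DataFrame) -> dict:
    """Load DataFrame rows into a dict while rendering a rich progress bar."""
    store = {}
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        "<",
        TimeRemainingColumn(),
    ) as progress:
        task = progress.add_task("loading", total=len(data_frame))
        for _, row in data_frame.iterrows():
            store[row["hash"]] = (row["embedding"], row["str"])
            progress.update(task, advance=1)
    return store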

View File

@@ -39,7 +39,7 @@ provider = "siliconflow"  # service provider
 model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"  # model name

 [info_extraction]
-workers = 10
+workers = 3  # number of concurrent entity-extraction threads; do not set above 5 with non-Pro models

 [qa.params]
 # QA parameter configuration
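The workers cap matters because each worker issues a concurrent request to the provider, and non-Pro models are subject to tighter rate limits. A sketch of how such a cap is typically applied (the function and parameter names are illustrative, not the project's actual code):

from concurrent.futures import ThreadPoolExecutor

WORKERS = 3  # keep at or below 5 for non-Pro models to avoid rate-limit errors

def extract_all(passages, extract_entities):
    """Run entity extraction over all passages with bounded concurrency."""
    with ThreadPoolExecutor(max_workers=WORKERS) as pool:
        return list(pool.map(extract_entities, passages))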