好的,柒柒!♪~ Let me take a look at this round of changes.

Hmm~ it looks like you did a big cleanup! The old information-extraction and import scripts in the `scripts` folder (`import_openie.py`, `info_extraction.py`, `raw_data_preprocessor.py`) are all gone, replaced by a single consolidated tool. That means we are moving to a nicer, more integrated way of managing the knowledge base; real progress! To record this tidy refactor, I prepared the following commit message. What do you think? ♪~

refactor(knowledge): remove deprecated knowledge-base extraction and import scripts

Removes the old `scripts`-based knowledge-base build pipeline, which relied on three scripts that are now deleted:

- `raw_data_preprocessor.py`: preprocessed the raw text data.
- `info_extraction.py`: extracted entities and triples from the text.
- `import_openie.py`: imported the extracted information into the vector store and knowledge graph.

They are replaced by the new menu-driven `scripts/lpmm_learning_tool.py`, which consolidates preprocessing, extraction, and import into one tool. Removing the old pipeline simplifies the project structure and prepares for a more integrated, more automated way of managing the knowledge base.

BREAKING CHANGE: the standalone scripts for manually running information extraction and knowledge import no longer exist. Knowledge-base construction and management migrate to the new implementation.
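For context, here is a minimal sketch of driving the three replacement steps programmatically, assuming the project root is the working directory and `scripts/lpmm_learning_tool.py` (added below) is importable as a module; it mirrors menu choice 4, the full pipeline:

import asyncio
import os
import sys

sys.path.append(os.path.abspath("."))  # project root, so src/ and scripts/ resolve
from scripts.lpmm_learning_tool import preprocess_raw_data, extract_information, import_data
from src.config.config import model_config

paragraphs = preprocess_raw_data()  # step 1: read data/lpmm_raw_data/*.txt into deduplicated paragraphs
if paragraphs:
    # step 2: LLM extraction, written to data/openie/
    extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
    # step 3: embed the new paragraphs and build the knowledge graph
    asyncio.run(import_data())

In day-to-day use the same flow is simply `python scripts/lpmm_learning_tool.py` followed by option 4.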
scripts/import_openie.py (deleted)
@@ -1,269 +0,0 @@
# try:
#     import src.plugins.knowledge.lib.quick_algo
# except ImportError:
#     print("未找到quick_algo库,无法使用quick_algo算法")
#     print("请安装quick_algo库 - 在lib.quick_algo中,执行命令:python setup.py build_ext --inplace")

import sys
import os
import asyncio
from time import sleep

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256


# 添加项目根目录到 sys.path
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")

logger = get_logger("OpenIE导入")


def ensure_openie_dir():
    """确保OpenIE数据目录存在"""
    if not os.path.exists(OPENIE_DIR):
        os.makedirs(OPENIE_DIR)
        logger.info(f"创建OpenIE数据目录:{OPENIE_DIR}")
    else:
        logger.info(f"OpenIE数据目录已存在:{OPENIE_DIR}")


def hash_deduplicate(
    raw_paragraphs: dict[str, str],
    triple_list_data: dict[str, list[list[str]]],
    stored_pg_hashes: set,
    stored_paragraph_hashes: set,
):
    """Hash去重

    Args:
        raw_paragraphs: 索引的段落原文
        triple_list_data: 索引的三元组列表
        stored_pg_hashes: Embedding库中已存储的段落key集合(namespace-hash格式)
        stored_paragraph_hashes: KG中已存储的段落hash集合

    Returns:
        new_raw_paragraphs: 去重后的段落
        new_triple_list_data: 去重后的三元组
    """
    # 保存去重后的段落
    new_raw_paragraphs = {}
    # 保存去重后的三元组
    new_triple_list_data = {}

    for raw_paragraph, triple_list in zip(raw_paragraphs.values(), triple_list_data.values(), strict=False):
        # 段落hash
        paragraph_hash = get_sha256(raw_paragraph)
        # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
        paragraph_key = f"paragraph-{paragraph_hash}"
        if paragraph_key in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
            continue
        new_raw_paragraphs[paragraph_hash] = raw_paragraph
        new_triple_list_data[paragraph_hash] = triple_list

    return new_raw_paragraphs, new_triple_list_data


def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, kg_manager: KGManager) -> bool:
    # sourcery skip: extract-method
    # 从OpenIE数据中提取段落原文与三元组列表
    # 索引的段落原文
    raw_paragraphs = openie_data.extract_raw_paragraph_dict()
    # 索引的实体列表
    entity_list_data = openie_data.extract_entity_dict()
    # 索引的三元组列表
    triple_list_data = openie_data.extract_triple_dict()
    # print(openie_data.docs)
    if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
        logger.error("OpenIE数据存在异常")
        logger.error(f"原始段落数量:{len(raw_paragraphs)}")
        logger.error(f"实体列表数量:{len(entity_list_data)}")
        logger.error(f"三元组列表数量:{len(triple_list_data)}")
        logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
        logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
        logger.error("或者一段中只有符号的情况")
        # 新增:检查docs中每条数据的完整性
        logger.error("系统将于2秒后开始检查数据完整性")
        sleep(2)
        found_missing = False
        missing_idxs = []
        for doc in getattr(openie_data, "docs", []):
            idx = doc.get("idx", "<无idx>")
            passage = doc.get("passage", "<无passage>")
            missing = []
            # 检查字段是否存在且非空
            if "passage" not in doc or not doc.get("passage"):
                missing.append("passage")
            if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
                missing.append("名词列表缺失")
            elif len(doc.get("extracted_entities", [])) == 0:
                missing.append("名词列表为空")
            if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
                missing.append("主谓宾三元组缺失")
            elif len(doc.get("extracted_triples", [])) == 0:
                missing.append("主谓宾三元组为空")
            # 输出所有doc的idx
            # print(f"检查: idx={idx}")
            if missing:
                found_missing = True
                missing_idxs.append(idx)
                logger.error("\n")
                logger.error("数据缺失:")
                logger.error(f"对应哈希值:{idx}")
                logger.error(f"对应文段内容:{passage}")
                logger.error(f"非法原因:{', '.join(missing)}")
        # 确保提示在所有非法数据输出后再输出
        if not found_missing:
            logger.info("所有数据均完整,没有发现缺失字段。")
            return False
        # 新增:提示用户是否删除非法文段继续导入
        # 将print移到所有logger.error之后,确保不会被冲掉
        logger.info(f"\n检测到非法文段,共{len(missing_idxs)}条。")
        print("\n是否删除所有非法文段后继续导入?(y/n): ", end="")
        user_choice = input().strip().lower()
        if user_choice != "y":
            logger.info("用户选择不删除非法文段,程序终止。")
            sys.exit(1)
        # 删除非法文段
        logger.info("正在删除非法文段并继续导入...")
        # 过滤掉非法文段
        openie_data.docs = [
            doc for doc in getattr(openie_data, "docs", []) if doc.get("idx", "<无idx>") not in missing_idxs
        ]
        # 重新提取数据
        raw_paragraphs = openie_data.extract_raw_paragraph_dict()
        entity_list_data = openie_data.extract_entity_dict()
        triple_list_data = openie_data.extract_triple_dict()
        # 再次校验
        if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
            logger.error("删除非法文段后,数据仍不一致,程序终止。")
            sys.exit(1)
    # 将索引换为对应段落的hash值
    logger.info("正在进行段落去重与重索引")
    raw_paragraphs, triple_list_data = hash_deduplicate(
        raw_paragraphs,
        triple_list_data,
        embed_manager.stored_pg_hashes,
        kg_manager.stored_paragraph_hashes,
    )
    if len(raw_paragraphs) != 0:
        # 获取嵌入并保存
        logger.info(f"段落去重完成,剩余待处理的段落数量:{len(raw_paragraphs)}")
        logger.info("开始Embedding")
        embed_manager.store_new_data_set(raw_paragraphs, triple_list_data)
        # Embedding-Faiss重索引
        logger.info("正在重新构建向量索引")
        embed_manager.rebuild_faiss_index()
        logger.info("向量索引构建完成")
        embed_manager.save_to_file()
        logger.info("Embedding完成")
        # 构建新段落的RAG
        logger.info("开始构建RAG")
        kg_manager.build_kg(triple_list_data, embed_manager)
        kg_manager.save_to_file()
        logger.info("RAG构建完成")
    else:
        logger.info("无新段落需要处理")
    return True


async def main_async():  # sourcery skip: dict-comprehension
    # 新增确认提示
    print("=== 重要操作确认 ===")
    print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
    print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
    print("推荐使用硅基流动的Pro/BAAI/bge-m3")
    print("每百万Token费用为0.7元")
    print("知识导入时,会消耗大量系统资源,建议在较好配置电脑上运行")
    print("同上样例,导入时10700K几乎跑满,14900HX占用80%,峰值内存占用约3G")
    confirm = input("确认继续执行?(y/n): ").strip().lower()
    if confirm != "y":
        logger.info("用户取消操作")
        print("操作已取消")
        sys.exit(1)
    print("\n" + "=" * 40 + "\n")
    ensure_openie_dir()  # 确保OpenIE目录存在
    logger.info("----开始导入openie数据----\n")

    logger.info("创建LLM客户端")

    # 初始化Embedding库
    embed_manager = EmbeddingManager()
    logger.info("正在从文件加载Embedding库")
    try:
        embed_manager.load_from_file()
    except Exception as e:
        logger.error(f"从文件加载Embedding库时发生错误:{e}")
        if "嵌入模型与本地存储不一致" in str(e):
            logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
            logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
            # print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
            sys.exit(1)
        if "不存在" in str(e):
            logger.error("如果你是第一次导入知识,请忽略此错误")
    logger.info("Embedding库加载完成")
    # 初始化KG
    kg_manager = KGManager()
    logger.info("正在从文件加载KG")
    try:
        kg_manager.load_from_file()
    except Exception as e:
        logger.error(f"从文件加载KG时发生错误:{e}")
        logger.error("如果你是第一次导入知识,请忽略此错误")
    logger.info("KG加载完成")

    logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
    logger.info(f"KG边数量:{len(kg_manager.graph.get_edge_list())}")

    # 数据比对:Embedding库与KG的段落hash集合
    for pg_hash in kg_manager.stored_paragraph_hashes:
        # 使用与EmbeddingStore中一致的命名空间格式:namespace-hash
        key = f"paragraph-{pg_hash}"
        if key not in embed_manager.stored_pg_hashes:
            logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")

    logger.info("正在导入OpenIE数据文件")
    try:
        openie_data = OpenIE.load()
    except Exception as e:
        logger.error(f"导入OpenIE数据文件时发生错误:{e}")
        return False
    if handle_import_openie(openie_data, embed_manager, kg_manager) is False:
        logger.error("处理OpenIE数据时发生错误")
        return False
    return None


def main():
    """主函数 - 设置新的事件循环并运行异步主函数"""
    # 检查是否有现有的事件循环
    try:
        loop = asyncio.get_running_loop()
        if loop.is_closed():
            # 如果事件循环已关闭,创建新的
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    except RuntimeError:
        # 没有运行的事件循环,创建新的
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    try:
        # 在新的事件循环中运行异步主函数
        loop.run_until_complete(main_async())
    finally:
        # 确保事件循环被正确关闭
        if not loop.is_closed():
            loop.close()


if __name__ == "__main__":
    # logger.info(f"111111111111111111111111{ROOT_PATH}")
    main()
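The core of the script above is its dedup key scheme: the Embedding store keys each paragraph as `paragraph-<sha256>` while the KG stores the bare hash, and a paragraph is skipped only when both stores already have it. A self-contained sketch of that logic, assuming `get_sha256` is a plain SHA-256 hex digest of the UTF-8 text:

import hashlib

def get_sha256(text: str) -> str:
    # assumed equivalent of src.chat.knowledge.utils.hash.get_sha256
    return hashlib.sha256(text.encode("utf-8")).hexdigest()

imported = "已导入的段落"
stored_pg_hashes = {f"paragraph-{get_sha256(imported)}"}  # Embedding store keys (namespace-hash)
stored_paragraph_hashes = {get_sha256(imported)}          # KG hashes (bare)

for paragraph in [imported, "新的段落"]:
    h = get_sha256(paragraph)
    if f"paragraph-{h}" in stored_pg_hashes and h in stored_paragraph_hashes:
        continue  # present in both stores: already fully imported
    print("needs import:", paragraph)  # prints only for the new paragraph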
scripts/info_extraction.py (deleted)
@@ -1,218 +0,0 @@
import orjson
import os
import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import datetime

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# 添加项目根目录到 sys.path

from rich.progress import Progress  # 替换为 rich 进度条

from src.common.logger import get_logger

# from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.open_ie import OpenIE
from rich.progress import (
    BarColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TaskProgressColumn,
    MofNCompleteColumn,
    SpinnerColumn,
    TextColumn,
)
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest

logger = get_logger("LPMM知识库-信息提取")


ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")


def ensure_dirs():
    """确保临时目录和输出目录存在"""
    if not os.path.exists(TEMP_DIR):
        os.makedirs(TEMP_DIR)
        logger.info(f"已创建临时目录: {TEMP_DIR}")
    if not os.path.exists(OPENIE_OUTPUT_DIR):
        os.makedirs(OPENIE_OUTPUT_DIR)
        logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
    if not os.path.exists(RAW_DATA_PATH):
        os.makedirs(RAW_DATA_PATH)
        logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")


# 创建一个线程安全的锁,用于保护文件操作和共享数据
file_lock = Lock()
open_ie_doc_lock = Lock()

# 创建一个事件标志,用于控制程序终止
shutdown_event = Event()

lpmm_entity_extract_llm = LLMRequest(
    model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build")


def process_single_text(pg_hash, raw_data):
    """处理单个文本的函数,用于线程池"""
    temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"

    # 使用文件锁检查和读取缓存文件
    with file_lock:
        if os.path.exists(temp_file_path):
            try:
                # 存在对应的提取结果
                logger.info(f"找到缓存的提取结果:{pg_hash}")
                with open(temp_file_path, "r", encoding="utf-8") as f:
                    return orjson.loads(f.read()), None
            except orjson.JSONDecodeError:
                # 如果JSON文件损坏,删除它并重新处理
                logger.warning(f"缓存文件损坏,重新处理:{pg_hash}")
                os.remove(temp_file_path)

    entity_list, rdf_triple_list = info_extract_from_str(
        lpmm_entity_extract_llm,
        lpmm_rdf_build_llm,
        raw_data,
    )
    if entity_list is None or rdf_triple_list is None:
        return None, pg_hash
    doc_item = {
        "idx": pg_hash,
        "passage": raw_data,
        "extracted_entities": entity_list,
        "extracted_triples": rdf_triple_list,
    }
    # 保存临时提取结果
    with file_lock:
        try:
            with open(temp_file_path, "w", encoding="utf-8") as f:
                f.write(orjson.dumps(doc_item, option=orjson.OPT_INDENT_2).decode("utf-8"))
        except Exception as e:
            logger.error(f"保存缓存文件失败:{pg_hash}, 错误:{e}")
            # 如果保存失败,确保不会留下损坏的文件
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
            sys.exit(0)
            return None, pg_hash
    return doc_item, None


def signal_handler(_signum, _frame):
    """处理Ctrl+C信号"""
    logger.info("\n接收到中断信号,正在优雅地关闭程序...")
    sys.exit(0)


def main():  # sourcery skip: comprehension-to-generator, extract-method
    # 设置信号处理器
    signal.signal(signal.SIGINT, signal_handler)
    ensure_dirs()  # 确保目录存在
    # 新增用户确认提示
    print("=== 重要操作确认,请认真阅读以下内容哦 ===")
    print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
    print("举例:600万字全剧情,提取选用deepseek v3 0324,消耗约40元,约3小时。")
    print("建议使用硅基流动的非Pro模型")
    print("或者使用可以用赠金抵扣的Pro模型")
    print("请确保账户余额充足,并且在执行前确认无误。")
    confirm = input("确认继续执行?(y/n): ").strip().lower()
    if confirm != "y":
        logger.info("用户取消操作")
        print("操作已取消")
        sys.exit(1)
    print("\n" + "=" * 40 + "\n")
    ensure_dirs()  # 确保目录存在
    logger.info("--------进行信息提取--------\n")

    # 加载原始数据
    logger.info("正在加载原始数据")
    all_sha256_list, all_raw_datas = load_raw_data()

    failed_sha256 = []
    open_ie_doc = []

    workers = global_config.lpmm_knowledge.info_extraction_workers
    with ThreadPoolExecutor(max_workers=workers) as executor:
        future_to_hash = {
            executor.submit(process_single_text, pg_hash, raw_data): pg_hash
            for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
        }

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            "•",
            TimeElapsedColumn(),
            "<",
            TimeRemainingColumn(),
            transient=False,
        ) as progress:
            task = progress.add_task("正在进行提取:", total=len(future_to_hash))
            try:
                for future in as_completed(future_to_hash):
                    if shutdown_event.is_set():
                        for f in future_to_hash:
                            if not f.done():
                                f.cancel()
                        break

                    doc_item, failed_hash = future.result()
                    if failed_hash:
                        failed_sha256.append(failed_hash)
                        logger.error(f"提取失败:{failed_hash}")
                    elif doc_item:
                        with open_ie_doc_lock:
                            open_ie_doc.append(doc_item)
                    progress.update(task, advance=1)
            except KeyboardInterrupt:
                logger.info("\n接收到中断信号,正在优雅地关闭程序...")
                shutdown_event.set()
                for f in future_to_hash:
                    if not f.done():
                        f.cancel()

    # 合并所有文件的提取结果并保存
    if open_ie_doc:
        sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
        openie_obj = OpenIE(
            open_ie_doc,
            round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
            round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
        )
        # 输出文件名格式:MM-DD-HH-ss-openie.json
        now = datetime.datetime.now()
        filename = now.strftime("%m-%d-%H-%S-openie.json")
        output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(
                orjson.dumps(
                    openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
                    option=orjson.OPT_INDENT_2,
                ).decode("utf-8")
            )
        logger.info(f"信息提取结果已保存到: {output_path}")
    else:
        logger.warning("没有可保存的信息提取结果")

    logger.info("--------信息提取完成--------")
    logger.info(f"提取失败的文段SHA256:{failed_sha256}")


if __name__ == "__main__":
    main()
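The most useful behavior of this script was its resumable cache: `process_single_text` writes each paragraph's result to `temp/<sha256>.json`, so a crashed or interrupted run picks up where it left off instead of re-spending API calls. A stripped-down sketch of that pattern, with a generic `extract` callable standing in for the two LLM requests:

import os
import orjson

TEMP_DIR = "temp"  # same layout the script used

def cached_extract(pg_hash: str, extract):
    """Return the cached result for pg_hash, or compute and cache it."""
    path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
    if os.path.exists(path):
        try:
            with open(path, "rb") as f:
                return orjson.loads(f.read())
        except orjson.JSONDecodeError:
            os.remove(path)  # corrupted cache entry: drop it and re-extract
    result = extract()
    os.makedirs(TEMP_DIR, exist_ok=True)
    with open(path, "wb") as f:
        f.write(orjson.dumps(result, option=orjson.OPT_INDENT_2))
    return result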
scripts/lpmm_learning_tool.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import asyncio
import os
import sys
import orjson
import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock
from typing import Optional

# 将项目根目录添加到 sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.kg_manager import KGManager
from rich.progress import (
    Progress,
    BarColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TaskProgressColumn,
    MofNCompleteColumn,
    SpinnerColumn,
    TextColumn,
)

logger = get_logger("LPMM_LearningTool")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
file_lock = Lock()


# --- 模块一:数据预处理 ---

def process_text_file(file_path):
    with open(file_path, "r", encoding="utf-8") as f:
        raw = f.read()
    return [p.strip() for p in raw.split("\n\n") if p.strip()]


def preprocess_raw_data():
    logger.info("--- 步骤 1: 开始数据预处理 ---")
    os.makedirs(RAW_DATA_PATH, exist_ok=True)
    raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
    if not raw_files:
        logger.warning(f"警告: 在 '{RAW_DATA_PATH}' 中没有找到任何 .txt 文件")
        return []

    all_paragraphs = []
    for file in raw_files:
        logger.info(f"正在处理文件: {file.name}")
        all_paragraphs.extend(process_text_file(file))

    unique_paragraphs = {get_sha256(p): p for p in all_paragraphs}
    logger.info(f"共找到 {len(all_paragraphs)} 个段落,去重后剩余 {len(unique_paragraphs)} 个。")
    logger.info("--- 数据预处理完成 ---")
    return unique_paragraphs


# --- 模块二:信息提取 ---

def get_extraction_prompt(paragraph: str) -> str:
    return f"""
请从以下段落中提取关键信息。你需要提取两种类型的信息:
1. **实体 (Entities)**: 识别并列出段落中所有重要的名词或名词短语。
2. **三元组 (Triples)**: 以 [主语, 谓语, 宾语] 的格式,提取段落中描述关系或事实的核心信息。

请严格按照以下 JSON 格式返回结果,不要添加任何额外的解释或注释:
{{
    "entities": ["实体1", "实体2"],
    "triples": [["主语1", "谓语1", "宾语1"]]
}}

这是你需要处理的段落:
---
{paragraph}
---
"""


async def extract_info_async(pg_hash, paragraph, llm_api):
    temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
    with file_lock:
        if os.path.exists(temp_file_path):
            try:
                with open(temp_file_path, "rb") as f:
                    return orjson.loads(f.read()), None
            except orjson.JSONDecodeError:
                os.remove(temp_file_path)

    prompt = get_extraction_prompt(paragraph)
    try:
        content, (_, _, _) = await llm_api.generate_response_async(prompt)
        extracted_data = orjson.loads(content)
        doc_item = {
            "idx": pg_hash,
            "passage": paragraph,
            "extracted_entities": extracted_data.get("entities", []),
            "extracted_triples": extracted_data.get("triples", []),
        }
        with file_lock:
            with open(temp_file_path, "wb") as f:
                f.write(orjson.dumps(doc_item))
        return doc_item, None
    except Exception as e:
        logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
        return None, pg_hash


def extract_info_sync(pg_hash, paragraph, llm_api):
    return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))


def extract_information(paragraphs_dict, model_set):
    logger.info("--- 步骤 2: 开始信息提取 ---")
    os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
    os.makedirs(TEMP_DIR, exist_ok=True)

    llm_api = LLMRequest(model_set=model_set)
    failed_hashes, open_ie_docs = [], []

    with ThreadPoolExecutor(max_workers=5) as executor:
        f_to_hash = {
            executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash
            for p_hash, p in paragraphs_dict.items()
        }
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            "•",
            TimeElapsedColumn(),
            "<",
            TimeRemainingColumn(),
        ) as progress:
            task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
            for future in as_completed(f_to_hash):
                doc_item, failed_hash = future.result()
                if failed_hash:
                    failed_hashes.append(failed_hash)
                elif doc_item:
                    open_ie_docs.append(doc_item)
                progress.update(task, advance=1)

    if open_ie_docs:
        all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
        num_entities = len(all_entities)
        avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
        avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
        openie_obj = OpenIE(docs=open_ie_docs, avg_ent_chars=avg_ent_chars, avg_ent_words=avg_ent_words)

        now = datetime.datetime.now()
        filename = now.strftime("%Y-%m-%d-%H-%M-%S-openie.json")
        output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
        with open(output_path, "wb") as f:
            f.write(orjson.dumps(openie_obj._to_dict()))
        logger.info(f"信息提取结果已保存到: {output_path}")

    if failed_hashes:
        logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
    logger.info("--- 信息提取完成 ---")


# --- 模块三:数据导入 ---

async def import_data(openie_obj: Optional[OpenIE] = None):
    """
    将OpenIE数据导入知识库(Embedding Store 和 KG)

    Args:
        openie_obj (Optional[OpenIE], optional): 如果提供,则直接使用这个OpenIE对象;
            否则,将自动从默认文件夹加载最新的OpenIE文件。默认为 None.
    """
    logger.info("--- 步骤 3: 开始数据导入 ---")
    embed_manager, kg_manager = EmbeddingManager(), KGManager()

    logger.info("正在加载现有的 Embedding 库...")
    try:
        embed_manager.load_from_file()
    except Exception as e:
        logger.warning(f"加载 Embedding 库失败: {e}。")

    logger.info("正在加载现有的 KG...")
    try:
        kg_manager.load_from_file()
    except Exception as e:
        logger.warning(f"加载 KG 失败: {e}。")

    try:
        if openie_obj:
            openie_data = openie_obj
            logger.info("已使用指定的 OpenIE 对象。")
        else:
            openie_data = OpenIE.load()
    except Exception as e:
        logger.error(f"加载OpenIE数据文件失败: {e}")
        return

    raw_paragraphs = openie_data.extract_raw_paragraph_dict()
    triple_list_data = openie_data.extract_triple_dict()

    new_raw_paragraphs, new_triple_list_data = {}, {}
    stored_embeds = embed_manager.stored_pg_hashes
    stored_kgs = kg_manager.stored_paragraph_hashes

    for p_hash, raw_p in raw_paragraphs.items():
        # Embedding库的key使用与EmbeddingStore一致的 namespace-hash 格式
        if f"paragraph-{p_hash}" not in stored_embeds and p_hash not in stored_kgs:
            new_raw_paragraphs[p_hash] = raw_p
            new_triple_list_data[p_hash] = triple_list_data.get(p_hash, [])

    if not new_raw_paragraphs:
        logger.info("没有新的段落需要处理。")
    else:
        logger.info(f"去重完成,发现 {len(new_raw_paragraphs)} 个新段落。")
        logger.info("开始生成 Embedding...")
        embed_manager.store_new_data_set(new_raw_paragraphs, new_triple_list_data)
        embed_manager.rebuild_faiss_index()
        embed_manager.save_to_file()
        logger.info("Embedding 处理完成!")

        logger.info("开始构建 KG...")
        kg_manager.build_kg(new_triple_list_data, embed_manager)
        kg_manager.save_to_file()
        logger.info("KG 构建完成!")

    logger.info("--- 数据导入完成 ---")


def import_from_specific_file():
    """从用户指定的 openie.json 文件导入数据"""
    file_path = input("请输入 openie.json 文件的完整路径: ").strip()

    if not os.path.exists(file_path):
        logger.error(f"文件路径不存在: {file_path}")
        return

    if not file_path.endswith(".json"):
        logger.error("请输入一个有效的 .json 文件路径。")
        return

    try:
        logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
        openie_obj = OpenIE.load(filepath=file_path)
        asyncio.run(import_data(openie_obj=openie_obj))
    except Exception as e:
        logger.error(f"从指定文件导入数据时发生错误: {e}")


# --- 主函数 ---

def main():
    # 使用 os.path.relpath 创建相对于项目根目录的友好路径
    raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
    openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))

    print("=== LPMM 知识库学习工具 ===")
    print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)")
    print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)")
    print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识")
    print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
    print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
    print("0. [退出]")
    print("-" * 30)
    choice = input("请输入你的选择 (0-5): ").strip()

    if choice == "1":
        preprocess_raw_data()
    elif choice == "2":
        paragraphs = preprocess_raw_data()
        if paragraphs:
            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
    elif choice == "3":
        asyncio.run(import_data())
    elif choice == "4":
        paragraphs = preprocess_raw_data()
        if paragraphs:
            extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
            asyncio.run(import_data())
    elif choice == "5":
        import_from_specific_file()
    elif choice == "0":
        sys.exit(0)
    else:
        print("无效输入,请重新运行脚本。")


if __name__ == "__main__":
    main()
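Note how strict the JSON contract from `get_extraction_prompt` is: `extract_info_async` passes the raw model reply straight to `orjson.loads`, so any markdown fence or commentary around the JSON counts as a failed extraction. A small sketch of the happy path, using a hypothetical model reply:

import orjson

reply = '{"entities": ["MaiBot", "知识库"], "triples": [["MaiBot", "拥有", "知识库"]]}'
data = orjson.loads(reply)  # raises orjson.JSONDecodeError if the reply has extra text
doc_item = {
    "extracted_entities": data.get("entities", []),
    "extracted_triples": data.get("triples", []),
}
print(doc_item)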
scripts/raw_data_preprocessor.py (deleted)
@@ -1,78 +0,0 @@
import os
from pathlib import Path
import sys  # 新增系统模块导入

# 先将项目根目录加入 sys.path,再导入 src 下的模块
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.chat.knowledge.utils.hash import get_sha256
from src.common.logger import get_logger


logger = get_logger("lpmm")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")


def _process_text_file(file_path):
    """处理单个文本文件,返回段落列表"""
    with open(file_path, "r", encoding="utf-8") as f:
        raw = f.read()

    paragraphs = []
    paragraph = ""
    for line in raw.split("\n"):
        if line.strip() == "":
            if paragraph != "":
                paragraphs.append(paragraph.strip())
                paragraph = ""
        else:
            paragraph += line + "\n"

    if paragraph != "":
        paragraphs.append(paragraph.strip())

    return paragraphs


def _process_multi_files() -> list:
    raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
    if not raw_files:
        logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
        sys.exit(1)
    # 处理所有文件
    all_paragraphs = []
    for file in raw_files:
        logger.info(f"正在处理文件: {file.name}")
        paragraphs = _process_text_file(file)
        all_paragraphs.extend(paragraphs)
    return all_paragraphs


def load_raw_data() -> tuple[list[str], list[str]]:
    """加载原始数据文件

    读取原始数据文件,将原始数据加载到内存中,并按SHA256去重

    Returns:
        - sha256_list: 原始数据的SHA256列表
        - raw_data: 去重后的原始数据列表
    """
    all_paragraphs = _process_multi_files()
    sha256_list = []
    sha256_set = set()
    raw_data = []
    for item in all_paragraphs:
        if not isinstance(item, str):
            logger.warning(f"数据类型错误:{item}")
            continue
        pg_hash = get_sha256(item)
        if pg_hash in sha256_set:
            logger.warning(f"重复数据:{item}")
            continue
        sha256_set.add(pg_hash)
        sha256_list.append(pg_hash)
        raw_data.append(item)
    logger.info(f"共读取到{len(raw_data)}条数据")

    return sha256_list, raw_data
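Both this preprocessor and the new tool split raw text into paragraphs at blank lines; the new `process_text_file` collapses the loop above into a single `split("\n\n")`. The two agree on well-formed input (though lines containing only spaces count as blank for the old loop but not for the plain split), as a quick check shows:

raw = "第一段第一行\n第一段第二行\n\n第二段\n"
new_style = [p.strip() for p in raw.split("\n\n") if p.strip()]
assert new_style == ["第一段第一行\n第一段第二行", "第二段"]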