总之就是知识库

This commit is contained in:
tt-P607
2025-10-11 14:18:54 +08:00
parent 94e34c9370
commit 0383a999fb
2 changed files with 114 additions and 153 deletions

View File

@@ -192,11 +192,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
return None, pg_hash
def extract_info_sync(pg_hash, paragraph, llm_api):
return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
def extract_information(paragraphs_dict, model_set):
async def extract_information(paragraphs_dict, model_set):
logger.info("--- 步骤 2: 开始信息提取 ---")
os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)
@@ -204,32 +200,35 @@ def extract_information(paragraphs_dict, model_set):
llm_api = LLMRequest(model_set=model_set)
failed_hashes, open_ie_docs = [], []
with ThreadPoolExecutor(max_workers=5) as executor:
f_to_hash = {
executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
}
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
) as progress:
task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
for future in as_completed(f_to_hash):
doc_item, failed_hash = future.result()
if failed_hash:
failed_hashes.append(failed_hash)
elif doc_item:
open_ie_docs.append(doc_item)
progress.update(task, advance=1)
tasks = [
extract_info_async(p_hash, p, llm_api)
for p_hash, p in paragraphs_dict.items()
]
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
) as progress:
prog_task = progress.add_task("[cyan]正在提取信息...", total=len(tasks))
for future in asyncio.as_completed(tasks):
doc_item, failed_hash = await future
if failed_hash:
failed_hashes.append(failed_hash)
elif doc_item:
open_ie_docs.append(doc_item)
progress.update(prog_task, advance=1)
if open_ie_docs:
all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
all_entities = [
e for doc in open_ie_docs for e in doc["extracted_entities"]
]
num_entities = len(all_entities)
avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
@@ -314,7 +313,7 @@ async def import_data(openie_obj: OpenIE | None = None):
logger.info("--- 数据导入完成 ---")
def import_from_specific_file():
async def import_from_specific_file():
"""从用户指定的 openie.json 文件导入数据"""
file_path = input("请输入 openie.json 文件的完整路径: ").strip()
@@ -329,7 +328,7 @@ def import_from_specific_file():
try:
logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
openie_obj = OpenIE.load()
asyncio.run(import_data(openie_obj=openie_obj))
await import_data(openie_obj=openie_obj)
except Exception as e:
logger.error(f"从指定文件导入数据时发生错误: {e}")
@@ -337,14 +336,20 @@ def import_from_specific_file():
# --- 主函数 ---
def main():
async def async_main():
# 使用 os.path.relpath 创建相对于项目根目录的友好路径
raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))
raw_data_relpath = os.path.relpath(
RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")
)
openie_output_relpath = os.path.relpath(
OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")
)
print("=== LPMM 知识库学习工具 ===")
print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)")
print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)")
print(
f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)"
)
print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识")
print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
@@ -358,16 +363,20 @@ def main():
elif choice == "2":
paragraphs = preprocess_raw_data()
if paragraphs:
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
await extract_information(
paragraphs, model_config.model_task_config.lpmm_qa
)
elif choice == "3":
asyncio.run(import_data())
await import_data()
elif choice == "4":
paragraphs = preprocess_raw_data()
if paragraphs:
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
asyncio.run(import_data())
await extract_information(
paragraphs, model_config.model_task_config.lpmm_qa
)
await import_data()
elif choice == "5":
import_from_specific_file()
await import_from_specific_file()
elif choice == "6":
clear_cache()
elif choice == "0":
@@ -377,4 +386,4 @@ def main():
if __name__ == "__main__":
main()
asyncio.run(async_main())