总之就是知识库
This commit is contained in:
@@ -192,11 +192,7 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
||||
return None, pg_hash
|
||||
|
||||
|
||||
def extract_info_sync(pg_hash, paragraph, llm_api):
|
||||
return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
|
||||
|
||||
|
||||
def extract_information(paragraphs_dict, model_set):
|
||||
async def extract_information(paragraphs_dict, model_set):
|
||||
logger.info("--- 步骤 2: 开始信息提取 ---")
|
||||
os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
@@ -204,32 +200,35 @@ def extract_information(paragraphs_dict, model_set):
|
||||
llm_api = LLMRequest(model_set=model_set)
|
||||
failed_hashes, open_ie_docs = [], []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
f_to_hash = {
|
||||
executor.submit(extract_info_sync, p_hash, p, llm_api): p_hash for p_hash, p in paragraphs_dict.items()
|
||||
}
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
"•",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
) as progress:
|
||||
task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
|
||||
for future in as_completed(f_to_hash):
|
||||
doc_item, failed_hash = future.result()
|
||||
if failed_hash:
|
||||
failed_hashes.append(failed_hash)
|
||||
elif doc_item:
|
||||
open_ie_docs.append(doc_item)
|
||||
progress.update(task, advance=1)
|
||||
tasks = [
|
||||
extract_info_async(p_hash, p, llm_api)
|
||||
for p_hash, p in paragraphs_dict.items()
|
||||
]
|
||||
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
MofNCompleteColumn(),
|
||||
"•",
|
||||
TimeElapsedColumn(),
|
||||
"<",
|
||||
TimeRemainingColumn(),
|
||||
) as progress:
|
||||
prog_task = progress.add_task("[cyan]正在提取信息...", total=len(tasks))
|
||||
for future in asyncio.as_completed(tasks):
|
||||
doc_item, failed_hash = await future
|
||||
if failed_hash:
|
||||
failed_hashes.append(failed_hash)
|
||||
elif doc_item:
|
||||
open_ie_docs.append(doc_item)
|
||||
progress.update(prog_task, advance=1)
|
||||
|
||||
if open_ie_docs:
|
||||
all_entities = [e for doc in open_ie_docs for e in doc["extracted_entities"]]
|
||||
all_entities = [
|
||||
e for doc in open_ie_docs for e in doc["extracted_entities"]
|
||||
]
|
||||
num_entities = len(all_entities)
|
||||
avg_ent_chars = round(sum(len(e) for e in all_entities) / num_entities, 4) if num_entities else 0
|
||||
avg_ent_words = round(sum(len(e.split()) for e in all_entities) / num_entities, 4) if num_entities else 0
|
||||
@@ -314,7 +313,7 @@ async def import_data(openie_obj: OpenIE | None = None):
|
||||
logger.info("--- 数据导入完成 ---")
|
||||
|
||||
|
||||
def import_from_specific_file():
|
||||
async def import_from_specific_file():
|
||||
"""从用户指定的 openie.json 文件导入数据"""
|
||||
file_path = input("请输入 openie.json 文件的完整路径: ").strip()
|
||||
|
||||
@@ -329,7 +328,7 @@ def import_from_specific_file():
|
||||
try:
|
||||
logger.info(f"正在从 {file_path} 加载 OpenIE 数据...")
|
||||
openie_obj = OpenIE.load()
|
||||
asyncio.run(import_data(openie_obj=openie_obj))
|
||||
await import_data(openie_obj=openie_obj)
|
||||
except Exception as e:
|
||||
logger.error(f"从指定文件导入数据时发生错误: {e}")
|
||||
|
||||
@@ -337,14 +336,20 @@ def import_from_specific_file():
|
||||
# --- 主函数 ---
|
||||
|
||||
|
||||
def main():
|
||||
async def async_main():
|
||||
# 使用 os.path.relpath 创建相对于项目根目录的友好路径
|
||||
raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
|
||||
openie_output_relpath = os.path.relpath(OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, ".."))
|
||||
raw_data_relpath = os.path.relpath(
|
||||
RAW_DATA_PATH, os.path.join(ROOT_PATH, "..")
|
||||
)
|
||||
openie_output_relpath = os.path.relpath(
|
||||
OPENIE_OUTPUT_DIR, os.path.join(ROOT_PATH, "..")
|
||||
)
|
||||
|
||||
print("=== LPMM 知识库学习工具 ===")
|
||||
print(f"1. [数据预处理] -> 读取 .txt 文件 (来源: ./{raw_data_relpath}/)")
|
||||
print(f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)")
|
||||
print(
|
||||
f"2. [信息提取] -> 提取信息并存为 .json (输出至: ./{openie_output_relpath}/)"
|
||||
)
|
||||
print("3. [数据导入] -> 从 openie 文件夹自动导入最新知识")
|
||||
print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
|
||||
print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
|
||||
@@ -358,16 +363,20 @@ def main():
|
||||
elif choice == "2":
|
||||
paragraphs = preprocess_raw_data()
|
||||
if paragraphs:
|
||||
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
|
||||
await extract_information(
|
||||
paragraphs, model_config.model_task_config.lpmm_qa
|
||||
)
|
||||
elif choice == "3":
|
||||
asyncio.run(import_data())
|
||||
await import_data()
|
||||
elif choice == "4":
|
||||
paragraphs = preprocess_raw_data()
|
||||
if paragraphs:
|
||||
extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
|
||||
asyncio.run(import_data())
|
||||
await extract_information(
|
||||
paragraphs, model_config.model_task_config.lpmm_qa
|
||||
)
|
||||
await import_data()
|
||||
elif choice == "5":
|
||||
import_from_specific_file()
|
||||
await import_from_specific_file()
|
||||
elif choice == "6":
|
||||
clear_cache()
|
||||
elif choice == "0":
|
||||
@@ -377,4 +386,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
asyncio.run(async_main())
|
||||
|
||||
Reference in New Issue
Block a user