Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
@@ -6,6 +6,7 @@
import sys
import os
from time import sleep

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

@@ -19,9 +20,14 @@ from src.plugins.knowledge.src.utils.hash import get_sha256

# Add the project root directory to sys.path
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = (
    global_config["persistence"]["openie_data_path"]
    if global_config["persistence"]["openie_data_path"]
    else os.path.join(ROOT_PATH, "data/openie")
)

logger = get_module_logger("LPMM知识库-OpenIE导入")
logger = get_module_logger("OpenIE导入")


def hash_deduplicate(
@@ -66,8 +72,45 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
    entity_list_data = openie_data.extract_entity_dict()
    # Indexed triple list
    triple_list_data = openie_data.extract_triple_dict()
    # print(openie_data.docs)
    if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
        logger.error("OpenIE数据存在异常")
        logger.error(f"原始段落数量:{len(raw_paragraphs)}")
        logger.error(f"实体列表数量:{len(entity_list_data)}")
        logger.error(f"三元组列表数量:{len(triple_list_data)}")
        logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
        logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
        logger.error("或者一段中只有符号的情况")
        # New: check the integrity of every entry in docs
        logger.error("系统将于2秒后开始检查数据完整性")
        sleep(2)
        found_missing = False
        for doc in getattr(openie_data, "docs", []):
            idx = doc.get("idx", "<无idx>")
            passage = doc.get("passage", "<无passage>")
            missing = []
            # Check that each required field exists and is non-empty
            if "passage" not in doc or not doc.get("passage"):
                missing.append("passage")
            if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
                missing.append("名词列表缺失")
            elif len(doc.get("extracted_entities", [])) == 0:
                missing.append("名词列表为空")
            if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
                missing.append("主谓宾三元组缺失")
            elif len(doc.get("extracted_triples", [])) == 0:
                missing.append("主谓宾三元组为空")
            # Print the idx of every doc
            # print(f"检查: idx={idx}")
            if missing:
                found_missing = True
                logger.error("\n")
                logger.error("数据缺失:")
                logger.error(f"对应哈希值:{idx}")
                logger.error(f"对应文段内容:{passage}")
                logger.error(f"非法原因:{', '.join(missing)}")
        if not found_missing:
            print("所有数据均完整,没有发现缺失字段。")
        return False
    # Replace each index with the hash of the corresponding paragraph
    logger.info("正在进行段落去重与重索引")
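The integrity check added in the hunk above reduces to a per-document rule: every entry in docs must carry a non-empty passage, a non-empty extracted_entities list, and a non-empty extracted_triples list. A minimal standalone sketch of that rule — field names come from the diff, while the helper name and return shape are illustrative, not part of the project API:

def find_missing_fields(doc: dict) -> list[str]:
    """Return the names of required OpenIE fields that are absent or empty."""
    missing = []
    if not doc.get("passage"):
        missing.append("passage")
    for field in ("extracted_entities", "extracted_triples"):
        value = doc.get(field)
        if not isinstance(value, list) or len(value) == 0:
            missing.append(field)
    return missing


# Example: a doc with entities but no triples is reported as incomplete.
print(find_missing_fields({"passage": "some text", "extracted_entities": ["MaiBot"], "extracted_triples": []}))
# -> ['extracted_triples']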
@@ -131,6 +174,7 @@ def main():
        embed_manager.load_from_file()
    except Exception as e:
        logger.error("从文件加载Embedding库时发生错误:{}".format(e))
        logger.error("如果你是第一次导入知识,请忽略此错误")
    logger.info("Embedding库加载完成")
    # Initialize the KG
    kg_manager = KGManager()
@@ -139,6 +183,7 @@ def main():
        kg_manager.load_from_file()
    except Exception as e:
        logger.error("从文件加载KG时发生错误:{}".format(e))
        logger.error("如果你是第一次导入知识,请忽略此错误")
    logger.info("KG加载完成")

    logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
@@ -163,4 +208,5 @@ def main():


if __name__ == "__main__":
    # logger.info(f"111111111111111111111111{ROOT_PATH}")
    main()
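After the integrity check, handle_import_openie re-keys each paragraph by its hash (the script imports get_sha256 for this). A sketch of the idea, assuming get_sha256 is a thin wrapper around hashlib — the project's actual helper may differ:

import hashlib


def get_sha256(text: str) -> str:
    # Assumed behavior: hex digest of the UTF-8 encoded paragraph.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def reindex_by_hash(paragraphs: list[str]) -> dict[str, str]:
    # Using the digest as the key both re-indexes and deduplicates:
    # identical paragraphs collapse onto the same entry.
    return {get_sha256(p): p for p in paragraphs}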
@@ -4,11 +4,13 @@ import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import glob
import datetime

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# Add the project root directory to sys.path

import tqdm
from rich.progress import Progress  # replaced with the rich progress bar

from src.common.logger import get_module_logger
from src.plugins.knowledge.src.lpmmconfig import global_config
@@ -16,10 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str
from src.plugins.knowledge.src.llm_client import LLMClient
from src.plugins.knowledge.src.open_ie import OpenIE
from src.plugins.knowledge.src.raw_processing import load_raw_data
from rich.progress import (
    BarColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TaskProgressColumn,
    MofNCompleteColumn,
    SpinnerColumn,
    TextColumn,
)

logger = get_module_logger("LPMM知识库-信息提取")

TEMP_DIR = "./temp"

ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
IMPORTED_DATA_PATH = (
    global_config["persistence"]["raw_data_path"]
    if global_config["persistence"]["raw_data_path"]
    else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
)
OPENIE_OUTPUT_DIR = (
    global_config["persistence"]["openie_data_path"]
    if global_config["persistence"]["openie_data_path"]
    else os.path.join(ROOT_PATH, "data/openie")
)

# Create a thread-safe lock to protect file operations and shared data
file_lock = Lock()
@@ -70,8 +93,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
        # If saving fails, make sure no corrupted file is left behind
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)
        # Set shutdown_event to terminate the program
        shutdown_event.set()
        sys.exit(0)
        return None, pg_hash
    return doc_item, None

@@ -79,7 +101,7 @@
def signal_handler(_signum, _frame):
    """Handle the Ctrl+C signal"""
    logger.info("\n接收到中断信号,正在优雅地关闭程序...")
    shutdown_event.set()
    sys.exit(0)


def main():
@@ -110,33 +132,61 @@ def main():
            global_config["llm_providers"][key]["api_key"],
        )

    logger.info("正在加载原始数据")
    sha256_list, raw_datas = load_raw_data()
    logger.info("原始数据加载完成\n")
    # Check the openie output directory
    if not os.path.exists(OPENIE_OUTPUT_DIR):
        os.makedirs(OPENIE_OUTPUT_DIR)
        logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")

    # Create the temporary directory
    if not os.path.exists(f"{TEMP_DIR}"):
        os.makedirs(f"{TEMP_DIR}")
    # Make sure the TEMP_DIR directory exists
    if not os.path.exists(TEMP_DIR):
        os.makedirs(TEMP_DIR)
        logger.info(f"已创建缓存目录: {TEMP_DIR}")

    # Iterate over all json files under IMPORTED_DATA_PATH
    imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
    if not imported_files:
        logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
        sys.exit(1)

    all_sha256_list = []
    all_raw_datas = []

    for imported_file in imported_files:
        logger.info(f"正在处理文件: {imported_file}")
        try:
            sha256_list, raw_datas = load_raw_data(imported_file)
        except Exception as e:
            logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
            continue
        all_sha256_list.extend(sha256_list)
        all_raw_datas.extend(raw_datas)

    failed_sha256 = []
    open_ie_doc = []

    # Create the thread pool; the number of workers comes from the config
    workers = global_config["info_extraction"]["workers"]
    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Submit all tasks to the thread pool
        future_to_hash = {
            executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
            for pg_hash, raw_data in zip(sha256_list, raw_datas)
            for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
        }

        # Show progress with tqdm
        with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar:
            # Process tasks as they complete
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            "•",
            TimeElapsedColumn(),
            "<",
            TimeRemainingColumn(),
            transient=False,
        ) as progress:
            task = progress.add_task("正在进行提取:", total=len(future_to_hash))
            try:
                for future in as_completed(future_to_hash):
                    if shutdown_event.is_set():
                        # Cancel all unfinished tasks
                        for f in future_to_hash:
                            if not f.done():
                                f.cancel()
@@ -149,26 +199,38 @@ def main():
                    elif doc_item:
                        with open_ie_doc_lock:
                            open_ie_doc.append(doc_item)
                    pbar.update(1)
                    progress.update(task, advance=1)
            except KeyboardInterrupt:
                # If KeyboardInterrupt is caught here, signal_handler may not have worked properly
                logger.info("\n接收到中断信号,正在优雅地关闭程序...")
                shutdown_event.set()
                # Cancel all unfinished tasks
                for f in future_to_hash:
                    if not f.done():
                        f.cancel()

    # Save the information-extraction results
    sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
    sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
    num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
    openie_obj = OpenIE(
        open_ie_doc,
        round(sum_phrase_chars / num_phrases, 4),
        round(sum_phrase_words / num_phrases, 4),
    )
    OpenIE.save(openie_obj)
    # Merge the extraction results from all files and save them
    if open_ie_doc:
        sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
        openie_obj = OpenIE(
            open_ie_doc,
            round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
            round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
        )
        # Output filename format: MM-DD-HH-ss-openie.json
        now = datetime.datetime.now()
        filename = now.strftime("%m-%d-%H-%S-openie.json")
        output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(
                openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
                f,
                ensure_ascii=False,
                indent=4,
            )
        logger.info(f"信息提取结果已保存到: {output_path}")
    else:
        logger.warning("没有可保存的信息提取结果")

    logger.info("--------信息提取完成--------")
    logger.info(f"提取失败的文段SHA256:{failed_sha256}")
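The extraction loop above pairs a ThreadPoolExecutor with a rich progress bar and a shared shutdown event. A self-contained sketch of that pattern with dummy work items — the worker function, worker count, and timings are placeholders, not the project's extraction logic:

from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Event
import time

from rich.progress import Progress, SpinnerColumn, BarColumn, MofNCompleteColumn

shutdown_event = Event()


def work(item):
    time.sleep(0.1)  # stand-in for one LLM extraction call
    return item * 2


items = list(range(20))
results = []
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(work, i): i for i in items}
    with Progress(SpinnerColumn(), BarColumn(), MofNCompleteColumn()) as progress:
        task = progress.add_task("extracting", total=len(futures))
        try:
            for future in as_completed(futures):
                if shutdown_event.is_set():
                    break
                results.append(future.result())
                progress.update(task, advance=1)
        except KeyboardInterrupt:
            shutdown_event.set()  # mirror the script's graceful-shutdown flag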
@@ -2,18 +2,22 @@ import json
import os
from pathlib import Path
import sys  # new: system module import
import datetime  # newly added import

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_module_logger

logger = get_module_logger("LPMM数据库-原始数据处理")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")

# Add the project root directory to sys.path


def check_and_create_dirs():
    """Check for and create the required directories"""
    required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
    required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]

    for dir_path in required_dirs:
        if not os.path.exists(dir_path):
@@ -58,17 +62,17 @@ def main():
    # Check for and create the required directories
    check_and_create_dirs()

    # Check whether the output file already exists
    if os.path.exists("data/import.json"):
        logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
        sys.exit(1)
    # # Check whether the output file already exists
    # if os.path.exists(RAW_DATA_PATH):
    #     logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
    #     sys.exit(1)

    if os.path.exists("data/openie.json"):
        logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
        sys.exit(1)
    # if os.path.exists(RAW_DATA_PATH):
    #     logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
    #     sys.exit(1)

    # Collect all raw text files
    raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
    raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
    if not raw_files:
        logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
        sys.exit(1)
@@ -80,8 +84,10 @@ def main():
        paragraphs = process_text_file(file)
        all_paragraphs.extend(paragraphs)

    # Save the merged result
    output_path = "data/import.json"
    # Save the merged result to IMPORTED_DATA_PATH, filename format MM-DD-HH-ss-imported-data.json
    now = datetime.datetime.now()
    filename = now.strftime("%m-%d-%H-%S-imported-data.json")
    output_path = os.path.join(IMPORTED_DATA_PATH, filename)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)

@@ -89,4 +95,6 @@ def main():


if __name__ == "__main__":
    print(f"Raw Data Path: {RAW_DATA_PATH}")
    print(f"Imported Data Path: {IMPORTED_DATA_PATH}")
    main()
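Taken together, the two scripts hand data off through the filesystem: raw-data processing writes a timestamped JSON into IMPORTED_DATA_PATH, and the extraction script later picks up every *.json it finds there. A small sketch of that handoff — the directory name and sample paragraphs are illustrative, the real path comes from the config:

import datetime
import glob
import json
import os

IMPORTED_DATA_PATH = "data/imported_lpmm_data"  # illustrative; resolved from config in the scripts
os.makedirs(IMPORTED_DATA_PATH, exist_ok=True)

# Writer side: one timestamped file per run, format MM-DD-HH-SS-imported-data.json
filename = datetime.datetime.now().strftime("%m-%d-%H-%S-imported-data.json")
with open(os.path.join(IMPORTED_DATA_PATH, filename), "w", encoding="utf-8") as f:
    json.dump(["paragraph one", "paragraph two"], f, ensure_ascii=False, indent=4)

# Reader side: the extraction script processes every JSON found in the directory
for path in sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json"))):
    with open(path, encoding="utf-8") as f:
        print(path, len(json.load(f)), "paragraphs")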