From f15e074ccac26de971e8f7d67bbe4ad71b56ca9a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com>
Date: Tue, 15 Jul 2025 16:54:25 +0800
Subject: [PATCH] feat: refactor the info extraction module: drop the
 LLMClient dependency in favor of LLMRequest, and streamline data loading
 and processing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/info_extraction.py            | 74 +++++++-------------
 scripts/raw_data_preprocessor.py      | 97 +++++++-------------------
 src/chat/knowledge/embedding_store.py |  4 +-
 src/chat/knowledge/ie_process.py      | 61 ++++++++--------
 src/chat/knowledge/knowledge_lib.py   |  5 +-
 src/chat/knowledge/open_ie.py         | 13 +---
 src/chat/knowledge/qa_manager.py      | 36 +++++-----
 src/config/official_configs.py        |  9 +++
 8 files changed, 121 insertions(+), 178 deletions(-)

diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py
index b7e2b5592..4a77fd5b3 100644
--- a/scripts/info_extraction.py
+++ b/scripts/info_extraction.py
@@ -13,11 +13,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
 from rich.progress import Progress  # 替换为 rich 进度条
 from src.common.logger import get_logger
-from src.chat.knowledge.lpmmconfig import global_config
-from src.chat.knowledge.ie_process import info_extract_from_str
+# from src.chat.knowledge.lpmmconfig import global_config
+from src.chat.knowledge.ie_process import info_extract_from_str
 from src.chat.knowledge.llm_client import LLMClient
 from src.chat.knowledge.open_ie import OpenIE
-from src.chat.knowledge.raw_processing import load_raw_data
 from rich.progress import (
     BarColumn,
     TimeElapsedColumn,
@@ -27,16 +26,17 @@ from rich.progress import (
     SpinnerColumn,
     TextColumn,
 )
+from raw_data_preprocessor import load_raw_data
+from src.config.config import global_config
+from src.llm_models.utils_model import LLMRequest
 
 logger = get_logger("LPMM知识库-信息提取")
 
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 TEMP_DIR = os.path.join(ROOT_PATH, "temp")
-IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
-    ROOT_PATH, "data", "imported_lpmm_data"
-)
-OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
+# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
+OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 
 # 创建一个线程安全的锁,用于保护文件操作和共享数据
 file_lock = Lock()
@@ -45,6 +45,14 @@ open_ie_doc_lock = Lock()
 # 创建一个事件标志,用于控制程序终止
 shutdown_event = Event()
 
+lpmm_entity_extract_llm = LLMRequest(
+    model=global_config.model.lpmm_entity_extract,
+    request_type="lpmm.entity_extract"
+)
+lpmm_rdf_build_llm = LLMRequest(
+    model=global_config.model.lpmm_rdf_build,
+    request_type="lpmm.rdf_build"
+)
 
 def ensure_dirs():
     """确保临时目录和输出目录存在"""
@@ -54,12 +62,9 @@ def ensure_dirs():
     if not os.path.exists(OPENIE_OUTPUT_DIR):
         os.makedirs(OPENIE_OUTPUT_DIR)
         logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
-    if not os.path.exists(IMPORTED_DATA_PATH):
-        os.makedirs(IMPORTED_DATA_PATH)
-        logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
 
 
-def process_single_text(pg_hash, raw_data, llm_client_list):
+def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池""" temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" @@ -77,8 +82,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list): os.remove(temp_file_path) entity_list, rdf_triple_list = info_extract_from_str( - llm_client_list[global_config["entity_extract"]["llm"]["provider"]], - llm_client_list[global_config["rdf_build"]["llm"]["provider"]], + lpmm_entity_extract_llm, + lpmm_rdf_build_llm, raw_data, ) if entity_list is None or rdf_triple_list is None: @@ -130,50 +135,17 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method ensure_dirs() # 确保目录存在 logger.info("--------进行信息提取--------\n") - logger.info("创建LLM客户端") - llm_client_list = { - key: LLMClient( - global_config["llm_providers"][key]["base_url"], - global_config["llm_providers"][key]["api_key"], - ) - for key in global_config["llm_providers"] - } - # 检查 openie 输出目录 - if not os.path.exists(OPENIE_OUTPUT_DIR): - os.makedirs(OPENIE_OUTPUT_DIR) - logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}") - - # 确保 TEMP_DIR 目录存在 - if not os.path.exists(TEMP_DIR): - os.makedirs(TEMP_DIR) - logger.info(f"已创建缓存目录: {TEMP_DIR}") - - # 遍历IMPORTED_DATA_PATH下所有json文件 - imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json"))) - if not imported_files: - logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件") - sys.exit(1) - - all_sha256_list = [] - all_raw_datas = [] - - for imported_file in imported_files: - logger.info(f"正在处理文件: {imported_file}") - try: - sha256_list, raw_datas = load_raw_data(imported_file) - except Exception as e: - logger.error(f"读取文件失败: {imported_file}, 错误: {e}") - continue - all_sha256_list.extend(sha256_list) - all_raw_datas.extend(raw_datas) + # 加载原始数据 + logger.info("正在加载原始数据") + all_sha256_list, all_raw_datas = load_raw_data() failed_sha256 = [] open_ie_doc = [] - workers = global_config["info_extraction"]["workers"] + workers = global_config.lpmm_knowledge.info_extraction_workers with ThreadPoolExecutor(max_workers=workers) as executor: future_to_hash = { - executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash + executor.submit(process_single_text, pg_hash, raw_data): pg_hash for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False) } diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index ee8960f66..42a99133f 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -1,40 +1,16 @@ -import json import os from pathlib import Path import sys # 新增系统模块导入 -import datetime # 新增导入 - +from src.chat.knowledge.utils.hash import get_sha256 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import get_logger -from src.chat.knowledge.lpmmconfig import global_config logger = get_logger("lpmm") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") -# 新增:确保 RAW_DATA_PATH 存在 -if not os.path.exists(RAW_DATA_PATH): - os.makedirs(RAW_DATA_PATH, exist_ok=True) - logger.info(f"已创建目录: {RAW_DATA_PATH}") +# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") -if global_config.get("persistence", {}).get("raw_data_path") is not None: - IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"]) -else: - IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") - -# 添加项目根目录到 sys.path - - -def check_and_create_dirs(): - """检查并创建必要的目录""" - required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH] - - for dir_path 
in required_dirs: - if not os.path.exists(dir_path): - os.makedirs(dir_path) - logger.info(f"已创建目录: {dir_path}") - - -def process_text_file(file_path): +def _process_text_file(file_path): """处理单个文本文件,返回段落列表""" with open(file_path, "r", encoding="utf-8") as f: raw = f.read() @@ -55,54 +31,45 @@ def process_text_file(file_path): return paragraphs -def main(): - # 新增用户确认提示 - print("=== 数据预处理脚本 ===") - print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。") - print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。") - print("请确保原始数据已放置在正确的目录中。") - confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != "y": - logger.info("操作已取消") - sys.exit(1) - print("\n" + "=" * 40 + "\n") - - # 检查并创建必要的目录 - check_and_create_dirs() - - # # 检查输出文件是否存在 - # if os.path.exists(RAW_DATA_PATH): - # logger.error("错误: data/import.json 已存在,请先处理或删除该文件") - # sys.exit(1) - - # if os.path.exists(RAW_DATA_PATH): - # logger.error("错误: data/openie.json 已存在,请先处理或删除该文件") - # sys.exit(1) - - # 获取所有原始文本文件 +def _process_multi_files() -> list: raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) if not raw_files: logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") sys.exit(1) - # 处理所有文件 all_paragraphs = [] for file in raw_files: logger.info(f"正在处理文件: {file.name}") - paragraphs = process_text_file(file) + paragraphs = _process_text_file(file) all_paragraphs.extend(paragraphs) + return all_paragraphs - # 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json - now = datetime.datetime.now() - filename = now.strftime("%m-%d-%H-%S-imported-data.json") - output_path = os.path.join(IMPORTED_DATA_PATH, filename) - with open(output_path, "w", encoding="utf-8") as f: - json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) +def load_raw_data() -> tuple[list[str], list[str]]: + """加载原始数据文件 - logger.info(f"处理完成,结果已保存到: {output_path}") + 读取原始数据文件,将原始数据加载到内存中 + Args: + path: 可选,指定要读取的json文件绝对路径 -if __name__ == "__main__": - logger.info(f"原始数据路径: {RAW_DATA_PATH}") - logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}") - main() + Returns: + - raw_data: 原始数据列表 + - sha256_list: 原始数据的SHA256集合 + """ + raw_data = _process_multi_files() + sha256_list = [] + sha256_set = set() + for item in raw_data: + if not isinstance(item, str): + logger.warning(f"数据类型错误:{item}") + continue + pg_hash = get_sha256(item) + if pg_hash in sha256_set: + logger.warning(f"重复数据:{item}") + continue + sha256_set.add(pg_hash) + sha256_list.append(pg_hash) + raw_data.append(item) + logger.info(f"共读取到{len(raw_data)}条数据") + + return sha256_list, raw_data \ No newline at end of file diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 80b1c02c2..b827f4b40 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -10,7 +10,7 @@ import pandas as pd # import tqdm import faiss -from .llm_client import LLMClient +# from .llm_client import LLMClient from .lpmmconfig import global_config from .utils.hash import get_sha256 from .global_logger import logger @@ -295,7 +295,7 @@ class EmbeddingStore: class EmbeddingManager: - def __init__(self, llm_client: LLMClient): + def __init__(self): self.paragraphs_embedding_store = EmbeddingStore( local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index f68a848d2..4314ca5e6 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -7,25 +7,34 @@ from . 
import prompt_template
 from .lpmmconfig import global_config, INVALID_ENTITY
 from .llm_client import LLMClient
 from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
+import asyncio  # generate_response_async按命名应为协程,需在工作线程中用asyncio.run驱动
+from src.llm_models.utils_model import LLMRequest
+from json_repair import repair_json
 
+def _extract_json_from_text(text: str) -> Union[list, dict, None]:
+    """从文本中提取JSON数据的高容错方法,失败时返回None"""
+    try:
+        fixed_json = repair_json(text)
+        if isinstance(fixed_json, str):
+            parsed_json = json.loads(fixed_json)
+        else:
+            parsed_json = fixed_json
+
+        # 实体提取期望list[str],RDF提取期望list[list];原样返回,由调用方校验结构
+        if isinstance(parsed_json, (list, dict)):
+            return parsed_json
+    except Exception as e:
+        logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
+    return None
+
-def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
+def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
     """对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
     entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
-    _, request_result = llm_client.send_chat_request(
-        global_config["entity_extract"]["llm"]["model"], entity_extract_context
-    )
-
-    # 去除‘{’前的内容(结果中可能有多个‘{’)
-    if "[" in request_result:
-        request_result = request_result[request_result.index("[") :]
-
-    # 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
-    if "]" in request_result:
-        request_result = request_result[: request_result.rindex("]") + 1]
-
-    entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
+    response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
+    entity_extract_result = _extract_json_from_text(response)
+    if not isinstance(entity_extract_result, list):  # 解析失败或结构不符时中止本次尝试
+        raise Exception("实体提取结果格式错误")
     entity_extract_result = [
         entity
         for entity in entity_extract_result
@@ -38,23 +48,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
     return entity_extract_result
 
 
-def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
+def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
-    """对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
+    """对段落进行RDF三元组提取,返回提取出的三元组列表(JSON格式)"""
-    entity_extract_context = prompt_template.build_rdf_triple_extract_context(
+    rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
         paragraph, entities=json.dumps(entities, ensure_ascii=False)
     )
-    _, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
-
-    # 去除‘{’前的内容(结果中可能有多个‘{’)
-    if "[" in request_result:
-        request_result = request_result[request_result.index("[") :]
-
-    # 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
-    if "]" in request_result:
-        request_result = request_result[: request_result.rindex("]") + 1]
-
-    entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
+    response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
+    entity_extract_result = _extract_json_from_text(response)
+    if not isinstance(entity_extract_result, list):  # 解析失败时直接视为本次提取失败
+        raise Exception("RDF提取结果格式错误")
     for triple in entity_extract_result:
         if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
             raise Exception("RDF提取结果格式错误")
     return entity_extract_result
@@ -63,7 +66,7 @@
 
 
 def info_extract_from_str(
-    llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
+    llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
 ) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
     try_count = 0
     while True:
diff --git
a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index f9ea7e731..09a1a08e3 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -31,6 +31,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash" ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +DATA_PATH = os.path.join(ROOT_PATH, "data") def _initialize_knowledge_local_storage(): """ @@ -42,10 +43,6 @@ def _initialize_knowledge_local_storage(): # 路径配置 'root_path': ROOT_PATH, 'data_path': f"{ROOT_PATH}/data", - 'lpmm_raw_data_path': f"{ROOT_PATH}/data/raw_data", - 'lpmm_openie_data_path': f"{ROOT_PATH}/data/openie", - 'lpmm_embedding_data_dir': f"{ROOT_PATH}/data/embedding", - 'lpmm_rag_data_dir': f"{ROOT_PATH}/data/rag", # 实体和命名空间配置 'lpmm_invalid_entity': INVALID_ENTITY, diff --git a/src/chat/knowledge/open_ie.py b/src/chat/knowledge/open_ie.py index 7bb96d131..90977fb88 100644 --- a/src/chat/knowledge/open_ie.py +++ b/src/chat/knowledge/open_ie.py @@ -4,9 +4,8 @@ import glob from typing import Any, Dict, List -from .lpmmconfig import INVALID_ENTITY, global_config - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH +# from src.manager.local_store_manager import local_storage def _filter_invalid_entities(entities: List[str]) -> List[str]: @@ -107,7 +106,7 @@ class OpenIE: @staticmethod def load() -> "OpenIE": """从OPENIE_DIR下所有json文件合并加载OpenIE数据""" - openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"]) + openie_dir = os.path.join(DATA_PATH, "openie") if not os.path.exists(openie_dir): raise Exception(f"OpenIE数据目录不存在: {openie_dir}") json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json"))) @@ -122,12 +121,6 @@ class OpenIE: openie_data = OpenIE._from_dict(data_list) return openie_data - @staticmethod - def save(openie_data: "OpenIE"): - """保存OpenIE数据到文件""" - with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f: - f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4)) - def extract_entity_dict(self): """提取实体列表""" ner_output_dict = dict( diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 01a3e82d3..8940dbb55 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -5,11 +5,13 @@ from .global_logger import logger # from . 
import prompt_template from .embedding_store import EmbeddingManager -from .llm_client import LLMClient +# from .llm_client import LLMClient from .kg_manager import KGManager -from .lpmmconfig import global_config +# from .lpmmconfig import global_config from .utils.dyn_topk import dyn_select_top_k - +from src.llm_models.utils_model import LLMRequest +from src.chat.utils.utils import get_embedding +from src.config.config import global_config MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 @@ -19,26 +21,25 @@ class QAManager: self, embed_manager: EmbeddingManager, kg_manager: KGManager, - llm_client_embedding: LLMClient, - llm_client_filter: LLMClient, - llm_client_qa: LLMClient, + ): self.embed_manager = embed_manager self.kg_manager = kg_manager - self.llm_client_list = { - "embedding": llm_client_embedding, - "message_filter": llm_client_filter, - "qa": llm_client_qa, - } + # TODO: API-Adapter修改标记 + self.qa_model = LLMRequest( + model=global_config.model.lpmm_qa, + request_type="lpmm.qa" + ) def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: """处理查询""" # 生成问题的Embedding part_start_time = time.perf_counter() - question_embedding = self.llm_client_list["embedding"].send_embedding_request( - global_config["embedding"]["model"], question - ) + question_embedding = get_embedding(question) + if question_embedding is None: + logger.error("生成问题Embedding失败") + return None part_end_time = time.perf_counter() logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s") @@ -46,14 +47,15 @@ class QAManager: part_start_time = time.perf_counter() relation_search_res = self.embed_manager.relation_embedding_store.search_top_k( question_embedding, - global_config["qa"]["params"]["relation_search_top_k"], + global_config.lpmm_knowledge.qa_relation_search_top_k, ) if relation_search_res is not None: # 过滤阈值 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]: + if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: # 未找到相关关系 + logger.debug("未找到相关关系,跳过关系检索") relation_search_res = [] part_end_time = time.perf_counter() @@ -71,7 +73,7 @@ class QAManager: part_start_time = time.perf_counter() paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( question_embedding, - global_config["qa"]["params"]["paragraph_search_top_k"], + global_config.lpmm_knowledge.qa_paragraph_search_top_k, ) part_end_time = time.perf_counter() logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 4433bae44..9dc962ef7 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -627,3 +627,12 @@ class ModelConfig(ConfigBase): embedding: dict[str, Any] = field(default_factory=lambda: {}) """嵌入模型配置""" + + lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM实体提取模型配置""" + + lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM RDF构建模型配置""" + + lpmm_qa: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM问答模型配置"""
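--
A quick way to exercise the new extraction path end to end (reviewer aid, not
part of the patch; git am stops at the final hunk). It assumes the new
ModelConfig fields above (lpmm_entity_extract / lpmm_rdf_build) are filled in
through the usual bot config, and that LLMRequest.generate_response_async is a
coroutine, as its name suggests; the sample sentence and variable names are
illustrative only.

    # run from the repo root so that src/ is importable
    import os
    import sys

    sys.path.append(os.path.abspath("."))  # mirror the sys.path setup used by the scripts

    from src.config.config import global_config
    from src.llm_models.utils_model import LLMRequest
    from src.chat.knowledge.ie_process import info_extract_from_str

    # Same wiring as scripts/info_extraction.py after this patch
    entity_llm = LLMRequest(
        model=global_config.model.lpmm_entity_extract,
        request_type="lpmm.entity_extract",
    )
    rdf_llm = LLMRequest(
        model=global_config.model.lpmm_rdf_build,
        request_type="lpmm.rdf_build",
    )

    entities, triples = info_extract_from_str(entity_llm, rdf_llm, "高斯1777年出生于不伦瑞克。")
    print(entities)  # e.g. ["高斯", "不伦瑞克"]
    print(triples)   # e.g. [["高斯", "出生于", "不伦瑞克"]]

Note that info_extract_from_str returns (None, None) once its retry budget is
exhausted, so callers (like process_single_text above) must treat that as a
failed paragraph rather than as valid output.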