diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 1a36fd240..791c64672 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -9,19 +9,17 @@ import os from time import sleep sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.open_ie import OpenIE from src.chat.knowledge.kg_manager import KGManager from src.common.logger import get_logger from src.chat.knowledge.utils.hash import get_sha256 +from src.manager.local_store_manager import local_storage # 添加项目根目录到 sys.path ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) -OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie") +OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") logger = get_logger("OpenIE导入") @@ -63,7 +61,7 @@ def hash_deduplicate( ): # 段落hash paragraph_hash = get_sha256(raw_paragraph) - if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: + if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: continue new_raw_paragraphs[paragraph_hash] = raw_paragraph new_triple_list_data[paragraph_hash] = triple_list @@ -193,15 +191,9 @@ def main(): # sourcery skip: dict-comprehension logger.info("----开始导入openie数据----\n") logger.info("创建LLM客户端") - llm_client_list = {} - for key in global_config["llm_providers"]: - llm_client_list[key] = LLMClient( - global_config["llm_providers"][key]["base_url"], - global_config["llm_providers"][key]["api_key"], - ) # 初始化Embedding库 - embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) + embed_manager = EmbeddingManager() logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() @@ -230,7 +222,7 @@ def main(): # sourcery skip: dict-comprehension # 数据比对:Embedding库与KG的段落hash集合 for pg_hash in kg_manager.stored_paragraph_hashes: - key = f"{PG_NAMESPACE}-{pg_hash}" + key = f"{local_storage['pg_namespace']}-{pg_hash}" if key not in embed_manager.stored_pg_hashes: logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index b7e2b5592..7370d98d8 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -4,7 +4,6 @@ import signal from concurrent.futures import ThreadPoolExecutor, as_completed from threading import Lock, Event import sys -import glob import datetime sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -13,11 +12,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from rich.progress import Progress # 替换为 rich 进度条 from src.common.logger import get_logger -from src.chat.knowledge.lpmmconfig import global_config +# from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.ie_process import info_extract_from_str -from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.open_ie import OpenIE -from src.chat.knowledge.raw_processing import load_raw_data from rich.progress import ( BarColumn, TimeElapsedColumn, @@ -27,24 +24,17 @@ from rich.progress import ( SpinnerColumn, TextColumn, ) +from raw_data_preprocessor import RAW_DATA_PATH, process_multi_files, load_raw_data +from src.config.config import global_config +from src.llm_models.utils_model import LLMRequest logger = get_logger("LPMM知识库-信息提取") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") -IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join( - ROOT_PATH, "data", "imported_lpmm_data" -) -OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie") - -# 创建一个线程安全的锁,用于保护文件操作和共享数据 -file_lock = Lock() -open_ie_doc_lock = Lock() - -# 创建一个事件标志,用于控制程序终止 -shutdown_event = Event() - +# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") +OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") def ensure_dirs(): """确保临时目录和输出目录存在""" @@ -54,12 +44,26 @@ def ensure_dirs(): if not os.path.exists(OPENIE_OUTPUT_DIR): os.makedirs(OPENIE_OUTPUT_DIR) logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}") - if not os.path.exists(IMPORTED_DATA_PATH): - os.makedirs(IMPORTED_DATA_PATH) - logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}") + if not os.path.exists(RAW_DATA_PATH): + os.makedirs(RAW_DATA_PATH) + logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}") +# 创建一个线程安全的锁,用于保护文件操作和共享数据 +file_lock = Lock() +open_ie_doc_lock = Lock() -def process_single_text(pg_hash, raw_data, llm_client_list): +# 创建一个事件标志,用于控制程序终止 +shutdown_event = Event() + +lpmm_entity_extract_llm = LLMRequest( + model=global_config.model.lpmm_entity_extract, + request_type="lpmm.entity_extract" +) +lpmm_rdf_build_llm = LLMRequest( + model=global_config.model.lpmm_rdf_build, + request_type="lpmm.rdf_build" +) +def process_single_text(pg_hash, raw_data): """处理单个文本的函数,用于线程池""" temp_file_path = f"{TEMP_DIR}/{pg_hash}.json" @@ -77,8 +81,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list): os.remove(temp_file_path) entity_list, rdf_triple_list = info_extract_from_str( - llm_client_list[global_config["entity_extract"]["llm"]["provider"]], - llm_client_list[global_config["rdf_build"]["llm"]["provider"]], + lpmm_entity_extract_llm, + lpmm_rdf_build_llm, raw_data, ) if entity_list is None or rdf_triple_list is None: @@ -113,7 +117,7 @@ def signal_handler(_signum, _frame): def main(): # sourcery skip: comprehension-to-generator, extract-method # 设置信号处理器 signal.signal(signal.SIGINT, signal_handler) - + ensure_dirs() # 确保目录存在 # 新增用户确认提示 print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") @@ -130,50 +134,17 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method ensure_dirs() # 确保目录存在 logger.info("--------进行信息提取--------\n") - logger.info("创建LLM客户端") - llm_client_list = { - key: LLMClient( - global_config["llm_providers"][key]["base_url"], - global_config["llm_providers"][key]["api_key"], - ) - for key in global_config["llm_providers"] - } - # 检查 openie 输出目录 - if not os.path.exists(OPENIE_OUTPUT_DIR): - os.makedirs(OPENIE_OUTPUT_DIR) - logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}") - - # 确保 TEMP_DIR 目录存在 - if not os.path.exists(TEMP_DIR): - os.makedirs(TEMP_DIR) - logger.info(f"已创建缓存目录: {TEMP_DIR}") - - # 遍历IMPORTED_DATA_PATH下所有json文件 - imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json"))) - if not imported_files: - logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件") - sys.exit(1) - - all_sha256_list = [] - all_raw_datas = [] - - for imported_file in imported_files: - logger.info(f"正在处理文件: {imported_file}") - try: - sha256_list, raw_datas = load_raw_data(imported_file) - except Exception as e: - logger.error(f"读取文件失败: {imported_file}, 错误: {e}") - continue - all_sha256_list.extend(sha256_list) - all_raw_datas.extend(raw_datas) + # 加载原始数据 + logger.info("正在加载原始数据") + all_sha256_list, all_raw_datas = load_raw_data() failed_sha256 = [] open_ie_doc = [] - workers = global_config["info_extraction"]["workers"] + workers = global_config.lpmm_knowledge.info_extraction_workers with ThreadPoolExecutor(max_workers=workers) as executor: future_to_hash = { - executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash + executor.submit(process_single_text, pg_hash, raw_data): pg_hash for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False) } diff --git a/scripts/raw_data_preprocessor.py b/scripts/raw_data_preprocessor.py index ee8960f66..42a99133f 100644 --- a/scripts/raw_data_preprocessor.py +++ b/scripts/raw_data_preprocessor.py @@ -1,40 +1,16 @@ -import json import os from pathlib import Path import sys # 新增系统模块导入 -import datetime # 新增导入 - +from src.chat.knowledge.utils.hash import get_sha256 sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from src.common.logger import get_logger -from src.chat.knowledge.lpmmconfig import global_config logger = get_logger("lpmm") ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data") -# 新增:确保 RAW_DATA_PATH 存在 -if not os.path.exists(RAW_DATA_PATH): - os.makedirs(RAW_DATA_PATH, exist_ok=True) - logger.info(f"已创建目录: {RAW_DATA_PATH}") +# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") -if global_config.get("persistence", {}).get("raw_data_path") is not None: - IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"]) -else: - IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data") - -# 添加项目根目录到 sys.path - - -def check_and_create_dirs(): - """检查并创建必要的目录""" - required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH] - - for dir_path in required_dirs: - if not os.path.exists(dir_path): - os.makedirs(dir_path) - logger.info(f"已创建目录: {dir_path}") - - -def process_text_file(file_path): +def _process_text_file(file_path): """处理单个文本文件,返回段落列表""" with open(file_path, "r", encoding="utf-8") as f: raw = f.read() @@ -55,54 +31,45 @@ def process_text_file(file_path): return paragraphs -def main(): - # 新增用户确认提示 - print("=== 数据预处理脚本 ===") - print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。") - print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。") - print("请确保原始数据已放置在正确的目录中。") - confirm = input("确认继续执行?(y/n): ").strip().lower() - if confirm != "y": - logger.info("操作已取消") - sys.exit(1) - print("\n" + "=" * 40 + "\n") - - # 检查并创建必要的目录 - check_and_create_dirs() - - # # 检查输出文件是否存在 - # if os.path.exists(RAW_DATA_PATH): - # logger.error("错误: data/import.json 已存在,请先处理或删除该文件") - # sys.exit(1) - - # if os.path.exists(RAW_DATA_PATH): - # logger.error("错误: data/openie.json 已存在,请先处理或删除该文件") - # sys.exit(1) - - # 获取所有原始文本文件 +def _process_multi_files() -> list: raw_files = list(Path(RAW_DATA_PATH).glob("*.txt")) if not raw_files: logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件") sys.exit(1) - # 处理所有文件 all_paragraphs = [] for file in raw_files: logger.info(f"正在处理文件: {file.name}") - paragraphs = process_text_file(file) + paragraphs = _process_text_file(file) all_paragraphs.extend(paragraphs) + return all_paragraphs - # 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json - now = datetime.datetime.now() - filename = now.strftime("%m-%d-%H-%S-imported-data.json") - output_path = os.path.join(IMPORTED_DATA_PATH, filename) - with open(output_path, "w", encoding="utf-8") as f: - json.dump(all_paragraphs, f, ensure_ascii=False, indent=4) +def load_raw_data() -> tuple[list[str], list[str]]: + """加载原始数据文件 - logger.info(f"处理完成,结果已保存到: {output_path}") + 读取原始数据文件,将原始数据加载到内存中 + Args: + path: 可选,指定要读取的json文件绝对路径 -if __name__ == "__main__": - logger.info(f"原始数据路径: {RAW_DATA_PATH}") - logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}") - main() + Returns: + - raw_data: 原始数据列表 + - sha256_list: 原始数据的SHA256集合 + """ + raw_data = _process_multi_files() + sha256_list = [] + sha256_set = set() + for item in raw_data: + if not isinstance(item, str): + logger.warning(f"数据类型错误:{item}") + continue + pg_hash = get_sha256(item) + if pg_hash in sha256_set: + logger.warning(f"重复数据:{item}") + continue + sha256_set.add(pg_hash) + sha256_list.append(pg_hash) + raw_data.append(item) + logger.info(f"共读取到{len(raw_data)}条数据") + + return sha256_list, raw_data \ No newline at end of file diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 1d887e1fe..b827f4b40 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -10,8 +10,8 @@ import pandas as pd # import tqdm import faiss -from .llm_client import LLMClient -from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config +# from .llm_client import LLMClient +from .lpmmconfig import global_config from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install @@ -25,6 +25,9 @@ from rich.progress import ( SpinnerColumn, TextColumn, ) +from src.manager.local_store_manager import local_storage +from src.chat.utils.utils import get_embedding + install(extra_lines=3) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) @@ -86,21 +89,20 @@ class EmbeddingStoreItem: class EmbeddingStore: - def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str): + def __init__(self, namespace: str, dir_path: str): self.namespace = namespace - self.llm_client = llm_client self.dir = dir_path - self.embedding_file_path = dir_path + "/" + namespace + ".parquet" - self.index_file_path = dir_path + "/" + namespace + ".index" + self.embedding_file_path = f"{dir_path}/{namespace}.parquet" + self.index_file_path = f"{dir_path}/{namespace}.index" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" - self.store = dict() + self.store = {} self.faiss_index = None self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) + return get_embedding(s) def get_test_file_path(self): return EMBEDDING_TEST_FILE @@ -293,20 +295,17 @@ class EmbeddingStore: class EmbeddingManager: - def __init__(self, llm_client: LLMClient): + def __init__(self): self.paragraphs_embedding_store = EmbeddingStore( - llm_client, - PG_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( - llm_client, - ENT_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( - llm_client, - REL_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.stored_pg_hashes = set() diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index f68a848d2..bd0e17684 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -4,28 +4,35 @@ from typing import List, Union from .global_logger import logger from . import prompt_template -from .lpmmconfig import global_config, INVALID_ENTITY -from .llm_client import LLMClient -from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json +from .knowledge_lib import INVALID_ENTITY +from src.llm_models.utils_model import LLMRequest +from json_repair import repair_json +def _extract_json_from_text(text: str) -> dict: + """从文本中提取JSON数据的高容错方法""" + try: + fixed_json = repair_json(text) + if isinstance(fixed_json, str): + parsed_json = json.loads(fixed_json) + else: + parsed_json = fixed_json + if isinstance(parsed_json, list) and parsed_json: + parsed_json = parsed_json[0] -def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: + if isinstance(parsed_json, dict): + return parsed_json + + except Exception as e: + logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...") + +def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_entity_extract_context(paragraph) - _, request_result = llm_client.send_chat_request( - global_config["entity_extract"]["llm"]["model"], entity_extract_context - ) - - # 去除‘{’前的内容(结果中可能有多个‘{’) - if "[" in request_result: - request_result = request_result[request_result.index("[") :] - - # 去除最后一个‘}’后的内容(结果中可能有多个‘}’) - if "]" in request_result: - request_result = request_result[: request_result.rindex("]") + 1] - - entity_extract_result = json.loads(new_fix_broken_generated_json(request_result)) + response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context) + entity_extract_result = _extract_json_from_text(response) + # 尝试load JSON数据 + json.loads(entity_extract_result) entity_extract_result = [ entity for entity in entity_extract_result @@ -38,23 +45,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]: return entity_extract_result -def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]: +def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" - entity_extract_context = prompt_template.build_rdf_triple_extract_context( + rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) - _, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context) - - # 去除‘{’前的内容(结果中可能有多个‘{’) - if "[" in request_result: - request_result = request_result[request_result.index("[") :] - - # 去除最后一个‘}’后的内容(结果中可能有多个‘}’) - if "]" in request_result: - request_result = request_result[: request_result.rindex("]") + 1] - - entity_extract_result = json.loads(new_fix_broken_generated_json(request_result)) + response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context) + entity_extract_result = _extract_json_from_text(response) + # 尝试load JSON数据 + json.loads(entity_extract_result) for triple in entity_extract_result: if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: raise Exception("RDF提取结果格式错误") @@ -63,7 +63,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) - def info_extract_from_str( - llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str + llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str ) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]: try_count = 0 while True: diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index 1ff651b5e..38f883d0e 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -20,24 +20,37 @@ from quick_algo import di_graph, pagerank from .utils.hash import get_sha256 from .embedding_store import EmbeddingManager, EmbeddingStoreItem -from .lpmmconfig import ( - ENT_NAMESPACE, - PG_NAMESPACE, - RAG_ENT_CNT_NAMESPACE, - RAG_GRAPH_NAMESPACE, - RAG_PG_HASH_NAMESPACE, - global_config, -) +from .lpmmconfig import global_config +from src.manager.local_store_manager import local_storage from .global_logger import logger -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -KG_DIR = ( - os.path.join(ROOT_PATH, "data/rag") - if global_config["persistence"]["rag_data_dir"] is None - else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"]) -) -KG_DIR_STR = str(KG_DIR).replace("\\", "/") + +def _get_kg_dir(): + """ + 安全地获取KG数据目录路径 + """ + root_path = local_storage['root_path'] + if root_path is None: + # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用 + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) + logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}") + + # 获取RAG数据目录 + rag_data_dir = global_config["persistence"]["rag_data_dir"] + if rag_data_dir is None: + kg_dir = os.path.join(root_path, "data/rag") + else: + kg_dir = os.path.join(root_path, rag_data_dir) + + return str(kg_dir).replace("\\", "/") + + +# 延迟初始化,避免在模块加载时就访问可能未初始化的 local_storage +def get_kg_dir_str(): + """获取KG目录字符串""" + return _get_kg_dir() class KGManager: @@ -46,15 +59,15 @@ class KGManager: # 存储段落的hash值,用于去重 self.stored_paragraph_hashes = set() # 实体出现次数 - self.ent_appear_cnt = dict() + self.ent_appear_cnt = {} # KG self.graph = di_graph.DiGraph() - # 持久化相关 - self.dir_path = KG_DIR_STR - self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml" - self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" - self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" + # 持久化相关 - 使用延迟初始化的路径 + self.dir_path = get_kg_dir_str() + self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml" + self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet" + self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json" def save_to_file(self): """将KG数据保存到文件""" @@ -109,8 +122,8 @@ class KGManager: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) - hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) - hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) + hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) + hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) @@ -128,8 +141,8 @@ class KGManager: """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: - ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) - pg_hash_key = PG_NAMESPACE + "-" + str(idx) + ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) + pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod @@ -144,8 +157,8 @@ class KGManager: ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: - ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0])) - ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2])) + ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0])) + ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() @@ -250,7 +263,7 @@ class KGManager: for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: - if node_hash.startswith(ENT_NAMESPACE): + if node_hash.startswith(local_storage['ent_namespace']): # 新增实体节点 node = embedding_manager.entities_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) @@ -259,7 +272,7 @@ class KGManager: node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) - elif node_hash.startswith(PG_NAMESPACE): + elif node_hash.startswith(local_storage['pg_namespace']): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) @@ -340,7 +353,7 @@ class KGManager: # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: - ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent) + ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] @@ -418,7 +431,7 @@ class KGManager: # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ - (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE) + (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace']) ] del ppr_res diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 5540d95e2..87a373a5c 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -1,4 +1,4 @@ -from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config +from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.mem_active_manager import MemoryActiveManager @@ -6,10 +6,80 @@ from src.chat.knowledge.qa_manager import QAManager from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.global_logger import logger from src.config.config import global_config as bot_global_config -# try: -# import quick_algo -# except ImportError: -# print("quick_algo not found, please install it first") +from src.manager.local_store_manager import local_storage +import os + +INVALID_ENTITY = [ + "", + "你", + "他", + "她", + "它", + "我们", + "你们", + "他们", + "她们", + "它们", +] +PG_NAMESPACE = "paragraph" +ENT_NAMESPACE = "entity" +REL_NAMESPACE = "relation" + +RAG_GRAPH_NAMESPACE = "rag-graph" +RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" +RAG_PG_HASH_NAMESPACE = "rag-pg-hash" + + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +DATA_PATH = os.path.join(ROOT_PATH, "data") + +def _initialize_knowledge_local_storage(): + """ + 初始化知识库相关的本地存储配置 + 使用字典批量设置,避免重复的if判断 + """ + # 定义所有需要初始化的配置项 + default_configs = { + # 路径配置 + 'root_path': ROOT_PATH, + 'data_path': f"{ROOT_PATH}/data", + + # 实体和命名空间配置 + 'lpmm_invalid_entity': INVALID_ENTITY, + 'pg_namespace': PG_NAMESPACE, + 'ent_namespace': ENT_NAMESPACE, + 'rel_namespace': REL_NAMESPACE, + + # RAG相关命名空间配置 + 'rag_graph_namespace': RAG_GRAPH_NAMESPACE, + 'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE, + 'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE + } + + # 日志级别映射:重要配置用info,其他用debug + important_configs = {'root_path', 'data_path'} + + # 批量设置配置项 + initialized_count = 0 + for key, default_value in default_configs.items(): + if local_storage[key] is None: + local_storage[key] = default_value + + # 根据重要性选择日志级别 + if key in important_configs: + logger.info(f"设置{key}: {default_value}") + else: + logger.debug(f"设置{key}: {default_value}") + + initialized_count += 1 + + if initialized_count > 0: + logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") + else: + logger.debug("知识库本地存储配置已存在,跳过初始化") + +# 初始化本地存储路径 +_initialize_knowledge_local_storage() # 检查LPMM知识库是否启用 if bot_global_config.lpmm_knowledge.enable: @@ -23,7 +93,7 @@ if bot_global_config.lpmm_knowledge.enable: ) # 初始化Embedding库 - embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) + embed_manager = EmbeddingManager() logger.info("正在从文件加载Embedding库") try: embed_manager.load_from_file() @@ -54,9 +124,6 @@ if bot_global_config.lpmm_knowledge.enable: qa_manager = QAManager( embed_manager, kg_manager, - llm_client_list[global_config["embedding"]["provider"]], - llm_client_list[global_config["qa"]["llm"]["provider"]], - llm_client_list[global_config["qa"]["llm"]["provider"]], ) # 记忆激活(用于记忆库) diff --git a/src/chat/knowledge/open_ie.py b/src/chat/knowledge/open_ie.py index 7bb96d131..90977fb88 100644 --- a/src/chat/knowledge/open_ie.py +++ b/src/chat/knowledge/open_ie.py @@ -4,9 +4,8 @@ import glob from typing import Any, Dict, List -from .lpmmconfig import INVALID_ENTITY, global_config - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) +from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH +# from src.manager.local_store_manager import local_storage def _filter_invalid_entities(entities: List[str]) -> List[str]: @@ -107,7 +106,7 @@ class OpenIE: @staticmethod def load() -> "OpenIE": """从OPENIE_DIR下所有json文件合并加载OpenIE数据""" - openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"]) + openie_dir = os.path.join(DATA_PATH, "openie") if not os.path.exists(openie_dir): raise Exception(f"OpenIE数据目录不存在: {openie_dir}") json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json"))) @@ -122,12 +121,6 @@ class OpenIE: openie_data = OpenIE._from_dict(data_list) return openie_data - @staticmethod - def save(openie_data: "OpenIE"): - """保存OpenIE数据到文件""" - with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f: - f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4)) - def extract_entity_dict(self): """提取实体列表""" ner_output_dict = dict( diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 01a3e82d3..8940dbb55 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -5,11 +5,13 @@ from .global_logger import logger # from . import prompt_template from .embedding_store import EmbeddingManager -from .llm_client import LLMClient +# from .llm_client import LLMClient from .kg_manager import KGManager -from .lpmmconfig import global_config +# from .lpmmconfig import global_config from .utils.dyn_topk import dyn_select_top_k - +from src.llm_models.utils_model import LLMRequest +from src.chat.utils.utils import get_embedding +from src.config.config import global_config MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 @@ -19,26 +21,25 @@ class QAManager: self, embed_manager: EmbeddingManager, kg_manager: KGManager, - llm_client_embedding: LLMClient, - llm_client_filter: LLMClient, - llm_client_qa: LLMClient, + ): self.embed_manager = embed_manager self.kg_manager = kg_manager - self.llm_client_list = { - "embedding": llm_client_embedding, - "message_filter": llm_client_filter, - "qa": llm_client_qa, - } + # TODO: API-Adapter修改标记 + self.qa_model = LLMRequest( + model=global_config.model.lpmm_qa, + request_type="lpmm.qa" + ) def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: """处理查询""" # 生成问题的Embedding part_start_time = time.perf_counter() - question_embedding = self.llm_client_list["embedding"].send_embedding_request( - global_config["embedding"]["model"], question - ) + question_embedding = get_embedding(question) + if question_embedding is None: + logger.error("生成问题Embedding失败") + return None part_end_time = time.perf_counter() logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s") @@ -46,14 +47,15 @@ class QAManager: part_start_time = time.perf_counter() relation_search_res = self.embed_manager.relation_embedding_store.search_top_k( question_embedding, - global_config["qa"]["params"]["relation_search_top_k"], + global_config.lpmm_knowledge.qa_relation_search_top_k, ) if relation_search_res is not None: # 过滤阈值 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]: + if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: # 未找到相关关系 + logger.debug("未找到相关关系,跳过关系检索") relation_search_res = [] part_end_time = time.perf_counter() @@ -71,7 +73,7 @@ class QAManager: part_start_time = time.perf_counter() paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( question_embedding, - global_config["qa"]["params"]["paragraph_search_top_k"], + global_config.lpmm_knowledge.qa_paragraph_search_top_k, ) part_end_time = time.perf_counter() logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index bb40687b6..4462daba7 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -625,3 +625,12 @@ class ModelConfig(ConfigBase): embedding: dict[str, Any] = field(default_factory=lambda: {}) """嵌入模型配置""" + + lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM实体提取模型配置""" + + lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM RDF构建模型配置""" + + lpmm_qa: dict[str, Any] = field(default_factory=lambda: {}) + """LPMM问答模型配置"""