diff --git a/scripts/import_openie.py b/scripts/import_openie.py index fc677877f..94b6ef48f 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -10,13 +10,14 @@ from time import sleep sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config +from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.open_ie import OpenIE from src.chat.knowledge.kg_manager import KGManager from src.common.logger import get_logger from src.chat.knowledge.utils.hash import get_sha256 +from src.manager.local_store_manager import local_storage # 添加项目根目录到 sys.path @@ -61,7 +62,7 @@ def hash_deduplicate( for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())): # 段落hash paragraph_hash = get_sha256(raw_paragraph) - if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: + if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: continue new_raw_paragraphs[paragraph_hash] = raw_paragraph new_triple_list_data[paragraph_hash] = triple_list @@ -228,7 +229,7 @@ def main(): # sourcery skip: dict-comprehension # 数据比对:Embedding库与KG的段落hash集合 for pg_hash in kg_manager.stored_paragraph_hashes: - key = f"{PG_NAMESPACE}-{pg_hash}" + key = f"{local_storage['pg_namespace']}-{pg_hash}" if key not in embed_manager.stored_pg_hashes: logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 1214611ec..c38dc40c1 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -11,7 +11,7 @@ import pandas as pd import faiss from .llm_client import LLMClient -from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config +from .lpmmconfig import global_config from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install @@ -25,6 +25,9 @@ from rich.progress import ( SpinnerColumn, TextColumn, ) +from src.manager.local_store_manager import local_storage +from src.llm_models.utils_model import LLMRequest + install(extra_lines=3) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) @@ -90,11 +93,11 @@ class EmbeddingStore: self.namespace = namespace self.llm_client = llm_client self.dir = dir_path - self.embedding_file_path = dir_path + "/" + namespace + ".parquet" - self.index_file_path = dir_path + "/" + namespace + ".index" + self.embedding_file_path = f"{dir_path}/{namespace}.parquet" + self.index_file_path = f"{dir_path}/{namespace}.index" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" - self.store = dict() + self.store = {} self.faiss_index = None self.idx2hash = None @@ -296,17 +299,17 @@ class EmbeddingManager: def __init__(self, llm_client: LLMClient): self.paragraphs_embedding_store = EmbeddingStore( llm_client, - PG_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( llm_client, - ENT_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( llm_client, - REL_NAMESPACE, + local_storage['pg_namespace'], EMBEDDING_DATA_DIR_STR, ) self.stored_pg_hashes = set() diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index 1ff651b5e..f3dc4e0cd 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -20,22 +20,16 @@ from quick_algo import di_graph, pagerank from .utils.hash import get_sha256 from .embedding_store import EmbeddingManager, EmbeddingStoreItem -from .lpmmconfig import ( - ENT_NAMESPACE, - PG_NAMESPACE, - RAG_ENT_CNT_NAMESPACE, - RAG_GRAPH_NAMESPACE, - RAG_PG_HASH_NAMESPACE, - global_config, -) +from .lpmmconfig import global_config +from src.manager.local_store_manager import local_storage from .global_logger import logger -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + KG_DIR = ( - os.path.join(ROOT_PATH, "data/rag") + os.path.join(local_storage['root_path'], "data/rag") if global_config["persistence"]["rag_data_dir"] is None - else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"]) + else os.path.join(local_storage['root_path'], global_config["persistence"]["rag_data_dir"]) ) KG_DIR_STR = str(KG_DIR).replace("\\", "/") @@ -46,15 +40,15 @@ class KGManager: # 存储段落的hash值,用于去重 self.stored_paragraph_hashes = set() # 实体出现次数 - self.ent_appear_cnt = dict() + self.ent_appear_cnt = {} # KG self.graph = di_graph.DiGraph() # 持久化相关 self.dir_path = KG_DIR_STR - self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml" - self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" - self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" + self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml" + self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet" + self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json" def save_to_file(self): """将KG数据保存到文件""" @@ -109,8 +103,8 @@ class KGManager: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) - hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) - hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) + hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) + hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) @@ -128,8 +122,8 @@ class KGManager: """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: - ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) - pg_hash_key = PG_NAMESPACE + "-" + str(idx) + ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) + pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod @@ -144,8 +138,8 @@ class KGManager: ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: - ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0])) - ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2])) + ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0])) + ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() @@ -250,7 +244,7 @@ class KGManager: for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: - if node_hash.startswith(ENT_NAMESPACE): + if node_hash.startswith(local_storage['ent_namespace']): # 新增实体节点 node = embedding_manager.entities_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) @@ -259,7 +253,7 @@ class KGManager: node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) - elif node_hash.startswith(PG_NAMESPACE): + elif node_hash.startswith(local_storage['pg_namespace']): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store[node_hash] assert isinstance(node, EmbeddingStoreItem) @@ -340,7 +334,7 @@ class KGManager: # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: - ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent) + ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] @@ -418,7 +412,7 @@ class KGManager: # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ - (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE) + (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace']) ] del ppr_res diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 5540d95e2..8780b93ce 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -1,4 +1,4 @@ -from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config +from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.mem_active_manager import MemoryActiveManager @@ -6,10 +6,83 @@ from src.chat.knowledge.qa_manager import QAManager from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.global_logger import logger from src.config.config import global_config as bot_global_config -# try: -# import quick_algo -# except ImportError: -# print("quick_algo not found, please install it first") +from src.manager.local_store_manager import local_storage +import os + +INVALID_ENTITY = [ + "", + "你", + "他", + "她", + "它", + "我们", + "你们", + "他们", + "她们", + "它们", +] +PG_NAMESPACE = "paragraph" +ENT_NAMESPACE = "entity" +REL_NAMESPACE = "relation" + +RAG_GRAPH_NAMESPACE = "rag-graph" +RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" +RAG_PG_HASH_NAMESPACE = "rag-pg-hash" + + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + +def _initialize_knowledge_local_storage(): + """ + 初始化知识库相关的本地存储配置 + 使用字典批量设置,避免重复的if判断 + """ + # 定义所有需要初始化的配置项 + default_configs = { + # 路径配置 + 'root_path': ROOT_PATH, + 'data_path': f"{ROOT_PATH}/data", + 'lpmm_raw_data_path': f"{ROOT_PATH}/data/raw_data", + 'lpmm_openie_data_path': f"{ROOT_PATH}/data/openie", + 'lpmm_embedding_data_dir': f"{ROOT_PATH}/data/embedding", + 'lpmm_rag_data_dir': f"{ROOT_PATH}/data/rag", + + # 实体和命名空间配置 + 'lpmm_invalid_entity': INVALID_ENTITY, + 'pg_namespace': PG_NAMESPACE, + 'ent_namespace': ENT_NAMESPACE, + 'rel_namespace': REL_NAMESPACE, + + # RAG相关命名空间配置 + 'rag_graph_namespace': RAG_GRAPH_NAMESPACE, + 'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE, + 'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE + } + + # 日志级别映射:重要配置用info,其他用debug + important_configs = {'root_path', 'data_path'} + + # 批量设置配置项 + initialized_count = 0 + for key, default_value in default_configs.items(): + if local_storage.get(key) is None: + local_storage.set(key, default_value) + + # 根据重要性选择日志级别 + if key in important_configs: + logger.info(f"设置{key}: {default_value}") + else: + logger.debug(f"设置{key}: {default_value}") + + initialized_count += 1 + + if initialized_count > 0: + logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") + else: + logger.debug("知识库本地存储配置已存在,跳过初始化") + +# 初始化本地存储路径 +_initialize_knowledge_local_storage() # 检查LPMM知识库是否启用 if bot_global_config.lpmm_knowledge.enable: