diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index c2172312f..de81ef8c5 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -21,7 +21,6 @@ from quick_algo import di_graph, pagerank from .utils.hash import get_sha256 from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .lpmmconfig import global_config -from src.manager.local_store_manager import local_storage from .global_logger import logger @@ -30,19 +29,9 @@ def _get_kg_dir(): """ 安全地获取KG数据目录路径 """ - root_path: str = local_storage["root_path"] - if root_path is None: - # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用 - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) - logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}") - - # 获取RAG数据目录 - rag_data_dir: str = global_config["persistence"]["rag_data_dir"] - if rag_data_dir is None: - kg_dir = os.path.join(root_path, "data/rag") - else: - kg_dir = os.path.join(root_path, rag_data_dir) + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) + kg_dir = os.path.join(root_path, "data/rag") return str(kg_dir).replace("\\", "/") @@ -65,9 +54,9 @@ class KGManager: # 持久化相关 - 使用延迟初始化的路径 self.dir_path = get_kg_dir_str() - self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml" - self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet" - self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json" + self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml" + self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet" + self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json" def save_to_file(self): """将KG数据保存到文件""" @@ -122,8 +111,8 @@ class KGManager: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) - hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) - hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2]) + hash_key1 = "entity" + "-" + get_sha256(triple[0]) + hash_key2 = "entity" + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) @@ -141,8 +130,8 @@ class KGManager: """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: - ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) - pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx) + ent_hash_key = "entity" + "-" + get_sha256(triple[0]) + pg_hash_key = "paragraph" + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod @@ -157,8 +146,8 @@ class KGManager: ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: - ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0])) - ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2])) + ent_hash_list.add("entity" + "-" + get_sha256(triple[0])) + ent_hash_list.add("entity" + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() @@ -263,7 +252,7 @@ class KGManager: for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: - if node_hash.startswith(local_storage["ent_namespace"]): + if node_hash.startswith("entity"): # 新增实体节点 node = embedding_manager.entities_embedding_store.store.get(node_hash) if node is None: @@ -275,7 +264,7 @@ class KGManager: node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) - elif node_hash.startswith(local_storage["pg_namespace"]): + elif node_hash.startswith("paragraph"): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) if node is None: @@ -359,7 +348,7 @@ class KGManager: # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: - ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent) + ent_hash = "entity" + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] @@ -439,7 +428,7 @@ class KGManager: passage_node_res = [ (node_key, score) for node_key, score in ppr_res.items() - if node_key.startswith(local_storage["pg_namespace"]) + if node_key.startswith("paragraph") ] del ppr_res diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index 678aa4190..587775755 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -50,7 +50,7 @@ class QAManager: # 过滤阈值 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: + if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: # 未找到相关关系 logger.debug("未找到相关关系,跳过关系检索") relation_search_res = [] @@ -106,6 +106,11 @@ class QAManager: processed_result = await self.process_query(question) if processed_result is not None: query_res = processed_result[0] + # 检查查询结果是否为空 + if not query_res: + logger.debug("知识库查询结果为空,可能是知识库中没有相关内容") + return None + knowledge = [ ( self.embed_manager.paragraphs_embedding_store.store[res[0]].str, diff --git a/src/chat/knowledge/utils/dyn_topk.py b/src/chat/knowledge/utils/dyn_topk.py index eb40ef3a8..5304934f0 100644 --- a/src/chat/knowledge/utils/dyn_topk.py +++ b/src/chat/knowledge/utils/dyn_topk.py @@ -5,6 +5,10 @@ def dyn_select_top_k( score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float ) -> List[Tuple[Any, float, float]]: """动态TopK选择""" + # 检查输入列表是否为空 + if not score: + return [] + # 按照分数排序(降序) sorted_score = sorted(score, key=lambda x: x[1], reverse=True)