import os import time import numpy as np import orjson import pandas as pd from quick_algo import di_graph, pagerank from rich.progress import ( BarColumn, MofNCompleteColumn, Progress, SpinnerColumn, TaskProgressColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn, ) from src.config.config import global_config from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .global_logger import logger from .utils.hash import get_sha256 def _get_kg_dir(): """ 安全地获取KG数据目录路径 """ current_dir = os.path.dirname(os.path.abspath(__file__)) root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) kg_dir = os.path.join(root_path, "data/rag") return str(kg_dir).replace("\\", "/") # 延迟初始化,避免在模块加载时就访问可能未初始化的 local_storage def get_kg_dir_str(): """获取KG目录字符串""" return _get_kg_dir() class KGManager: def __init__(self): # 会被保存的字段 # 存储段落的hash值,用于去重 self.stored_paragraph_hashes = set() # 实体出现次数 self.ent_appear_cnt = {} # KG self.graph = di_graph.DiGraph() # 持久化相关 - 使用延迟初始化的路径 self.dir_path = get_kg_dir_str() self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml" self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet" self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json" def save_to_file(self): """将KG数据保存到文件""" # 确保目录存在 if not os.path.exists(self.dir_path): os.makedirs(self.dir_path, exist_ok=True) # 保存KG di_graph.save_to_file(self.graph, self.graph_data_path) # 保存实体计数到文件 ent_cnt_df = pd.DataFrame([{"hash_key": k, "appear_cnt": v} for k, v in self.ent_appear_cnt.items()]) ent_cnt_df.to_parquet(self.ent_cnt_data_path, engine="pyarrow", index=False) # 保存段落hash到文件 with open(self.pg_hash_file_path, "w", encoding="utf-8") as f: data = {"stored_paragraph_hashes": list(self.stored_paragraph_hashes)} f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2).decode("utf-8")) def load_from_file(self): """从文件加载KG数据""" # 确保文件存在 if not os.path.exists(self.pg_hash_file_path): raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在") if not os.path.exists(self.ent_cnt_data_path): raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在") if not os.path.exists(self.graph_data_path): raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在") # 加载段落hash with open(self.pg_hash_file_path, encoding="utf-8") as f: data = orjson.loads(f.read()) self.stored_paragraph_hashes = set(data["stored_paragraph_hashes"]) # 加载实体计数 ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow") self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}) # 加载KG self.graph = di_graph.load_from_file(self.graph_data_path) def _build_edges_between_ent( self, node_to_node: dict[tuple[str, str], float], triple_list_data: dict[str, list[list[str]]], ): """构建实体节点之间的关系,同时统计实体出现次数""" for triple_list in triple_list_data.values(): entity_set = set() for triple in triple_list: if triple[0] == triple[2]: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) hash_key1 = "entity" + "-" + get_sha256(triple[0]) hash_key2 = "entity" + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) entity_set.add(hash_key2) # 实体出现次数统计 for hash_key in entity_set: self.ent_appear_cnt[hash_key] = self.ent_appear_cnt.get(hash_key, 0) + 1.0 @staticmethod def _build_edges_between_ent_pg( node_to_node: dict[tuple[str, str], float], triple_list_data: dict[str, list[list[str]]], ): """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: ent_hash_key = "entity" + "-" + get_sha256(triple[0]) pg_hash_key = "paragraph" + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod def _synonym_connect( node_to_node: dict[tuple[str, str], float], triple_list_data: dict[str, list[list[str]]], embedding_manager: EmbeddingManager, ) -> int: """同义词连接""" new_edge_cnt = 0 # 获取所有实体节点的hash值 ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: ent_hash_list.add("entity" + "-" + get_sha256(triple[0])) ent_hash_list.add("entity" + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() synonym_result = {} # rich 进度条 total = len(ent_hash_list) with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), BarColumn(), TaskProgressColumn(), MofNCompleteColumn(), "•", TimeElapsedColumn(), "<", TimeRemainingColumn(), transient=False, ) as progress: task = progress.add_task("同义词连接", total=total) for ent_hash in ent_hash_list: if ent_hash in synonym_hash_set: progress.update(task, advance=1) continue ent = embedding_manager.entities_embedding_store.store.get(ent_hash) if ent is None: progress.update(task, advance=1) continue assert isinstance(ent, EmbeddingStoreItem) # 查询相似实体 similar_ents = embedding_manager.entities_embedding_store.search_top_k( ent.embedding, global_config.lpmm_knowledge.rag_synonym_search_top_k ) res_ent = [] # Debug for res_ent_hash, similarity in similar_ents: if res_ent_hash == ent_hash: # 避免自连接 continue if similarity < global_config.lpmm_knowledge.rag_synonym_threshold: # 相似度阈值 continue node_to_node[(res_ent_hash, ent_hash)] = similarity node_to_node[(ent_hash, res_ent_hash)] = similarity synonym_hash_set.add(res_ent_hash) new_edge_cnt += 1 res_ent.append( ( embedding_manager.entities_embedding_store.store[res_ent_hash].str, similarity, ) ) # Debug synonym_result[ent.str] = res_ent progress.update(task, advance=1) for k, v in synonym_result.items(): print(f'"{k}"的相似实体为:{v}') return new_edge_cnt def _update_graph( self, node_to_node: dict[tuple[str, str], float], embedding_manager: EmbeddingManager, ): """更新KG图结构 流程: 1. 更新图结构:遍历所有待添加的新边 - 若是新边,则添加到图中 - 若是已存在的边,则更新边的权重 2. 更新新节点的属性 """ existed_nodes = self.graph.get_node_list() existed_edges = [str((edge[0], edge[1])) for edge in self.graph.get_edge_list()] now_time = time.time() # 更新图结构 for src_tgt, weight in node_to_node.items(): key = str(src_tgt) # 检查边是否已存在 if key not in existed_edges: # 新边 self.graph.add_edge( di_graph.DiEdge( src_tgt[0], src_tgt[1], { "weight": weight, "create_time": now_time, "update_time": now_time, }, ) ) else: # 已存在的边 edge_item = self.graph[src_tgt[0], src_tgt[1]] edge_item["weight"] += weight edge_item["update_time"] = now_time self.graph.update_edge(edge_item) # 更新新节点属性 for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: if node_hash.startswith("entity"): # 新增实体节点 node = embedding_manager.entities_embedding_store.store.get(node_hash) if node is None: logger.warning(f"实体节点 {node_hash} 在嵌入库中不存在,跳过") continue assert isinstance(node, EmbeddingStoreItem) node_item = self.graph[node_hash] node_item["content"] = node.str node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) elif node_hash.startswith("paragraph"): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) if node is None: logger.warning(f"段落节点 {node_hash} 在嵌入库中不存在,跳过") continue assert isinstance(node, EmbeddingStoreItem) content = node.str.replace("\n", " ") node_item = self.graph[node_hash] node_item["content"] = content if len(content) < 8 else content[:8] + "..." node_item["type"] = "pg" node_item["create_time"] = now_time self.graph.update_node(node_item) def build_kg( self, triple_list_data: dict[str, list[list[str]]], embedding_manager: EmbeddingManager, ): """增量式构建KG 注意:应当在调用该方法后保存KG Args: triple_list_data: 三元组数据 embedding_manager: EmbeddingManager对象 """ # 实体之间的联系 node_to_node = dict() # 构建实体节点之间的关系,同时统计实体出现次数 logger.info("正在构建KG实体节点之间的关系,同时统计实体出现次数") # 从三元组提取实体对 self._build_edges_between_ent(node_to_node, triple_list_data) # 构建实体节点与文段节点之间的关系 logger.info("正在构建KG实体节点与文段节点之间的关系") self._build_edges_between_ent_pg(node_to_node, triple_list_data) # 近义词扩展链接 # 对每个实体节点,找到最相似的实体节点,建立扩展连接 logger.info("正在进行近义词扩展链接") self._synonym_connect(node_to_node, triple_list_data, embedding_manager) # 构建图 self._update_graph(node_to_node, embedding_manager) # 记录已处理(存储)的段落hash for idx in triple_list_data: self.stored_paragraph_hashes.add(str(idx)) def kg_search( self, relation_search_result: list[tuple[tuple[str, str, str], float]], paragraph_search_result: list[tuple[str, float]], embed_manager: EmbeddingManager, ): """RAG搜索与PageRank Args: relation_search_result: RelationEmbedding的搜索结果(relation_tripple, similarity) paragraph_search_result: ParagraphEmbedding的搜索结果(paragraph_hash, similarity) embed_manager: EmbeddingManager对象 """ # 图中存在的节点总集 existed_nodes = self.graph.get_node_list() # 准备PPR使用的数据 # 节点权重:实体 ent_weights = {} # 节点权重:文段 pg_weights = {} # 以下部分处理实体权重ent_weights # 针对每个关系,提取出其中的主宾短语作为两个实体,并记录对应的三元组的相似度作为权重依据 ent_sim_scores = {} for relation_hash, similarity, _ in relation_search_result: # 提取主宾短语 relation = embed_manager.relation_embedding_store.store.get(relation_hash).str assert relation is not None # 断言:relation不为空 # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: ent_hash = "entity" + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] ent_sim_scores[ent_hash].append(similarity) ent_mean_scores = {} # 记录实体的平均相似度 for ent_hash, scores in ent_sim_scores.items(): # 先对相似度进行累加,然后与实体计数相除获取最终权重 ent_weights[ent_hash] = float(np.sum(scores)) / self.ent_appear_cnt[ent_hash] # 记录实体的平均相似度,用于后续的top_k筛选 ent_mean_scores[ent_hash] = float(np.mean(scores)) del ent_sim_scores ent_weights_max = max(ent_weights.values()) ent_weights_min = min(ent_weights.values()) if ent_weights_max == ent_weights_min: # 只有一个相似度,则全赋值为1 for ent_hash in ent_weights.keys(): ent_weights[ent_hash] = 1.0 else: down_edge = global_config.lpmm_knowledge.qa_paragraph_node_weight # 缩放取值区间至[down_edge, 1] for ent_hash, score in ent_weights.items(): # 缩放相似度 ent_weights[ent_hash] = ( (score - ent_weights_min) * (1 - down_edge) / (ent_weights_max - ent_weights_min) ) + down_edge # 取平均相似度的top_k实体 top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k if len(ent_mean_scores) > top_k: # 从大到小排序,取后len - k个 ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)} for ent_hash, _ in ent_mean_scores.items(): # 删除被淘汰的实体节点权重设置 del ent_weights[ent_hash] del top_k, ent_mean_scores # 以下部分处理文段权重pg_weights # 将搜索结果中文段的相似度归一化作为权重 pg_sim_scores = {} pg_sim_score_max = 0.0 pg_sim_score_min = 1.0 for pg_hash, similarity in paragraph_search_result: # 查找最大和最小值 pg_sim_score_max = max(pg_sim_score_max, similarity) pg_sim_score_min = min(pg_sim_score_min, similarity) pg_sim_scores[pg_hash] = similarity # 归一化 for pg_hash, similarity in pg_sim_scores.items(): # 归一化相似度 pg_sim_scores[pg_hash] = (similarity - pg_sim_score_min) / (pg_sim_score_max - pg_sim_score_min) del pg_sim_score_max, pg_sim_score_min for pg_hash, score in pg_sim_scores.items(): pg_weights[pg_hash] = ( score * global_config.lpmm_knowledge.qa_paragraph_node_weight ) # 文段权重 = 归一化相似度 * 文段节点权重参数 del pg_sim_scores # 最终权重数据 = 实体权重 + 文段权重 ppr_node_weights = {k: v for d in [ent_weights, pg_weights] for k, v in d.items()} del ent_weights, pg_weights # PersonalizedPageRank ppr_res = pagerank.run_pagerank( self.graph, personalization=ppr_node_weights, max_iter=100, alpha=global_config.lpmm_knowledge.qa_ppr_damping, ) # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith("paragraph") ] del ppr_res # 排序:按照分数从大到小 passage_node_res = sorted(passage_node_res, key=lambda item: item[1], reverse=True) return passage_node_res, ppr_node_weights