重构 KGManager 类,移除对 local_storage 的依赖,简化 KG 目录路径获取逻辑
This commit is contained in:
@@ -21,7 +21,6 @@ from quick_algo import di_graph, pagerank
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .lpmmconfig import global_config
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
from .global_logger import logger
|
||||
|
||||
@@ -30,19 +29,9 @@ def _get_kg_dir():
|
||||
"""
|
||||
安全地获取KG数据目录路径
|
||||
"""
|
||||
root_path: str = local_storage["root_path"]
|
||||
if root_path is None:
|
||||
# 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
||||
logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
|
||||
|
||||
# 获取RAG数据目录
|
||||
rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
|
||||
if rag_data_dir is None:
|
||||
kg_dir = os.path.join(root_path, "data/rag")
|
||||
else:
|
||||
kg_dir = os.path.join(root_path, rag_data_dir)
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
||||
kg_dir = os.path.join(root_path, "data/rag")
|
||||
|
||||
return str(kg_dir).replace("\\", "/")
|
||||
|
||||
@@ -65,9 +54,9 @@ class KGManager:
|
||||
|
||||
# 持久化相关 - 使用延迟初始化的路径
|
||||
self.dir_path = get_kg_dir_str()
|
||||
self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"
|
||||
self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json"
|
||||
|
||||
def save_to_file(self):
|
||||
"""将KG数据保存到文件"""
|
||||
@@ -122,8 +111,8 @@ class KGManager:
|
||||
# 避免自连接
|
||||
continue
|
||||
# 一个triple就是一条边(同时构建双向联系)
|
||||
hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
|
||||
hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
|
||||
hash_key1 = "entity" + "-" + get_sha256(triple[0])
|
||||
hash_key2 = "entity" + "-" + get_sha256(triple[2])
|
||||
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
||||
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
||||
entity_set.add(hash_key1)
|
||||
@@ -141,8 +130,8 @@ class KGManager:
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
for triple in triple_list_data[idx]:
|
||||
ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
|
||||
ent_hash_key = "entity" + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = "paragraph" + "-" + str(idx)
|
||||
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
||||
|
||||
@staticmethod
|
||||
@@ -157,8 +146,8 @@ class KGManager:
|
||||
ent_hash_list = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
for triple in triple_list:
|
||||
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list = list(ent_hash_list)
|
||||
|
||||
synonym_hash_set = set()
|
||||
@@ -263,7 +252,7 @@ class KGManager:
|
||||
for src_tgt in node_to_node.keys():
|
||||
for node_hash in src_tgt:
|
||||
if node_hash not in existed_nodes:
|
||||
if node_hash.startswith(local_storage["ent_namespace"]):
|
||||
if node_hash.startswith("entity"):
|
||||
# 新增实体节点
|
||||
node = embedding_manager.entities_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
@@ -275,7 +264,7 @@ class KGManager:
|
||||
node_item["type"] = "ent"
|
||||
node_item["create_time"] = now_time
|
||||
self.graph.update_node(node_item)
|
||||
elif node_hash.startswith(local_storage["pg_namespace"]):
|
||||
elif node_hash.startswith("paragraph"):
|
||||
# 新增文段节点
|
||||
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
@@ -359,7 +348,7 @@ class KGManager:
|
||||
# 关系三元组
|
||||
triple = relation[2:-2].split("', '")
|
||||
for ent in [(triple[0]), (triple[2])]:
|
||||
ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
|
||||
ent_hash = "entity" + "-" + get_sha256(ent)
|
||||
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
||||
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
||||
ent_sim_scores[ent_hash] = []
|
||||
@@ -439,7 +428,7 @@ class KGManager:
|
||||
passage_node_res = [
|
||||
(node_key, score)
|
||||
for node_key, score in ppr_res.items()
|
||||
if node_key.startswith(local_storage["pg_namespace"])
|
||||
if node_key.startswith("paragraph")
|
||||
]
|
||||
del ppr_res
|
||||
|
||||
|
||||
Reference in New Issue
Block a user