feat: 移除不必要的命名空间导入,优化本地存储初始化

This commit is contained in:
墨梓柒
2025-07-08 00:18:19 +08:00
parent 3c46d996fe
commit e339f0b228
4 changed files with 111 additions and 40 deletions

View File

@@ -20,22 +20,16 @@ from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .lpmmconfig import (
ENT_NAMESPACE,
PG_NAMESPACE,
RAG_ENT_CNT_NAMESPACE,
RAG_GRAPH_NAMESPACE,
RAG_PG_HASH_NAMESPACE,
global_config,
)
from .lpmmconfig import global_config
from src.manager.local_store_manager import local_storage
from .global_logger import logger
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
KG_DIR = (
os.path.join(ROOT_PATH, "data/rag")
os.path.join(local_storage['root_path'], "data/rag")
if global_config["persistence"]["rag_data_dir"] is None
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
else os.path.join(local_storage['root_path'], global_config["persistence"]["rag_data_dir"])
)
KG_DIR_STR = str(KG_DIR).replace("\\", "/")
@@ -46,15 +40,15 @@ class KGManager:
# 存储段落的hash值用于去重
self.stored_paragraph_hashes = set()
# 实体出现次数
self.ent_appear_cnt = dict()
self.ent_appear_cnt = {}
# KG
self.graph = di_graph.DiGraph()
# 持久化相关
self.dir_path = KG_DIR_STR
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
def save_to_file(self):
"""将KG数据保存到文件"""
@@ -109,8 +103,8 @@ class KGManager:
# 避免自连接
continue
# 一个triple就是一条边同时构建双向联系
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1)
@@ -128,8 +122,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系"""
for idx in triple_list_data:
for triple in triple_list_data[idx]:
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod
@@ -144,8 +138,8 @@ class KGManager:
ent_hash_list = set()
for triple_list in triple_list_data.values():
for triple in triple_list:
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0]))
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2]))
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list)
synonym_hash_set = set()
@@ -250,7 +244,7 @@ class KGManager:
for src_tgt in node_to_node.keys():
for node_hash in src_tgt:
if node_hash not in existed_nodes:
if node_hash.startswith(ENT_NAMESPACE):
if node_hash.startswith(local_storage['ent_namespace']):
# 新增实体节点
node = embedding_manager.entities_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem)
@@ -259,7 +253,7 @@ class KGManager:
node_item["type"] = "ent"
node_item["create_time"] = now_time
self.graph.update_node(node_item)
elif node_hash.startswith(PG_NAMESPACE):
elif node_hash.startswith(local_storage['pg_namespace']):
# 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem)
@@ -340,7 +334,7 @@ class KGManager:
# 关系三元组
triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]:
ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent)
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = []
@@ -418,7 +412,7 @@ class KGManager:
# 获取最终结果
# 从搜索结果中提取文段节点的结果
passage_node_res = [
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE)
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
]
del ppr_res