feat: 移除不必要的命名空间导入,优化本地存储初始化

This commit is contained in:
墨梓柒
2025-07-08 00:18:19 +08:00
parent 3c46d996fe
commit e339f0b228
4 changed files with 111 additions and 40 deletions

View File

@@ -10,13 +10,14 @@ from time import sleep
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.open_ie import OpenIE from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256 from src.chat.knowledge.utils.hash import get_sha256
from src.manager.local_store_manager import local_storage
# 添加项目根目录到 sys.path # 添加项目根目录到 sys.path
@@ -61,7 +62,7 @@ def hash_deduplicate(
for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())): for _, (raw_paragraph, triple_list) in enumerate(zip(raw_paragraphs.values(), triple_list_data.values())):
# 段落hash # 段落hash
paragraph_hash = get_sha256(raw_paragraph) paragraph_hash = get_sha256(raw_paragraph)
if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes: if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
continue continue
new_raw_paragraphs[paragraph_hash] = raw_paragraph new_raw_paragraphs[paragraph_hash] = raw_paragraph
new_triple_list_data[paragraph_hash] = triple_list new_triple_list_data[paragraph_hash] = triple_list
@@ -228,7 +229,7 @@ def main(): # sourcery skip: dict-comprehension
# 数据比对Embedding库与KG的段落hash集合 # 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes: for pg_hash in kg_manager.stored_paragraph_hashes:
key = f"{PG_NAMESPACE}-{pg_hash}" key = f"{local_storage['pg_namespace']}-{pg_hash}"
if key not in embed_manager.stored_pg_hashes: if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}") logger.warning(f"KG中存在Embedding库中不存在的段落{key}")

View File

@@ -11,7 +11,7 @@ import pandas as pd
import faiss import faiss
from .llm_client import LLMClient from .llm_client import LLMClient
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config from .lpmmconfig import global_config
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .global_logger import logger from .global_logger import logger
from rich.traceback import install from rich.traceback import install
@@ -25,6 +25,9 @@ from rich.progress import (
SpinnerColumn, SpinnerColumn,
TextColumn, TextColumn,
) )
from src.manager.local_store_manager import local_storage
from src.llm_models.utils_model import LLMRequest
install(extra_lines=3) install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
@@ -90,11 +93,11 @@ class EmbeddingStore:
self.namespace = namespace self.namespace = namespace
self.llm_client = llm_client self.llm_client = llm_client
self.dir = dir_path self.dir = dir_path
self.embedding_file_path = dir_path + "/" + namespace + ".parquet" self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = dir_path + "/" + namespace + ".index" self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
self.store = dict() self.store = {}
self.faiss_index = None self.faiss_index = None
self.idx2hash = None self.idx2hash = None
@@ -296,17 +299,17 @@ class EmbeddingManager:
def __init__(self, llm_client: LLMClient): def __init__(self, llm_client: LLMClient):
self.paragraphs_embedding_store = EmbeddingStore( self.paragraphs_embedding_store = EmbeddingStore(
llm_client, llm_client,
PG_NAMESPACE, local_storage['pg_namespace'],
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.entities_embedding_store = EmbeddingStore( self.entities_embedding_store = EmbeddingStore(
llm_client, llm_client,
ENT_NAMESPACE, local_storage['pg_namespace'],
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.relation_embedding_store = EmbeddingStore( self.relation_embedding_store = EmbeddingStore(
llm_client, llm_client,
REL_NAMESPACE, local_storage['pg_namespace'],
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()

View File

@@ -20,22 +20,16 @@ from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .lpmmconfig import ( from .lpmmconfig import global_config
ENT_NAMESPACE, from src.manager.local_store_manager import local_storage
PG_NAMESPACE,
RAG_ENT_CNT_NAMESPACE,
RAG_GRAPH_NAMESPACE,
RAG_PG_HASH_NAMESPACE,
global_config,
)
from .global_logger import logger from .global_logger import logger
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
KG_DIR = ( KG_DIR = (
os.path.join(ROOT_PATH, "data/rag") os.path.join(local_storage['root_path'], "data/rag")
if global_config["persistence"]["rag_data_dir"] is None if global_config["persistence"]["rag_data_dir"] is None
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"]) else os.path.join(local_storage['root_path'], global_config["persistence"]["rag_data_dir"])
) )
KG_DIR_STR = str(KG_DIR).replace("\\", "/") KG_DIR_STR = str(KG_DIR).replace("\\", "/")
@@ -46,15 +40,15 @@ class KGManager:
# 存储段落的hash值用于去重 # 存储段落的hash值用于去重
self.stored_paragraph_hashes = set() self.stored_paragraph_hashes = set()
# 实体出现次数 # 实体出现次数
self.ent_appear_cnt = dict() self.ent_appear_cnt = {}
# KG # KG
self.graph = di_graph.DiGraph() self.graph = di_graph.DiGraph()
# 持久化相关 # 持久化相关
self.dir_path = KG_DIR_STR self.dir_path = KG_DIR_STR
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml" self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet" self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json" self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
def save_to_file(self): def save_to_file(self):
"""将KG数据保存到文件""" """将KG数据保存到文件"""
@@ -109,8 +103,8 @@ class KGManager:
# 避免自连接 # 避免自连接
continue continue
# 一个triple就是一条边同时构建双向联系 # 一个triple就是一条边同时构建双向联系
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0]) hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2]) hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1) entity_set.add(hash_key1)
@@ -128,8 +122,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系""" """构建实体节点与文段节点之间的关系"""
for idx in triple_list_data: for idx in triple_list_data:
for triple in triple_list_data[idx]: for triple in triple_list_data[idx]:
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0]) ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
pg_hash_key = PG_NAMESPACE + "-" + str(idx) pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod @staticmethod
@@ -144,8 +138,8 @@ class KGManager:
ent_hash_list = set() ent_hash_list = set()
for triple_list in triple_list_data.values(): for triple_list in triple_list_data.values():
for triple in triple_list: for triple in triple_list:
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0])) ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2])) ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list) ent_hash_list = list(ent_hash_list)
synonym_hash_set = set() synonym_hash_set = set()
@@ -250,7 +244,7 @@ class KGManager:
for src_tgt in node_to_node.keys(): for src_tgt in node_to_node.keys():
for node_hash in src_tgt: for node_hash in src_tgt:
if node_hash not in existed_nodes: if node_hash not in existed_nodes:
if node_hash.startswith(ENT_NAMESPACE): if node_hash.startswith(local_storage['ent_namespace']):
# 新增实体节点 # 新增实体节点
node = embedding_manager.entities_embedding_store.store[node_hash] node = embedding_manager.entities_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem) assert isinstance(node, EmbeddingStoreItem)
@@ -259,7 +253,7 @@ class KGManager:
node_item["type"] = "ent" node_item["type"] = "ent"
node_item["create_time"] = now_time node_item["create_time"] = now_time
self.graph.update_node(node_item) self.graph.update_node(node_item)
elif node_hash.startswith(PG_NAMESPACE): elif node_hash.startswith(local_storage['pg_namespace']):
# 新增文段节点 # 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store[node_hash] node = embedding_manager.paragraphs_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem) assert isinstance(node, EmbeddingStoreItem)
@@ -340,7 +334,7 @@ class KGManager:
# 关系三元组 # 关系三元组
triple = relation[2:-2].split("', '") triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]: for ent in [(triple[0]), (triple[2])]:
ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent) ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体 if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = [] ent_sim_scores[ent_hash] = []
@@ -418,7 +412,7 @@ class KGManager:
# 获取最终结果 # 获取最终结果
# 从搜索结果中提取文段节点的结果 # 从搜索结果中提取文段节点的结果
passage_node_res = [ passage_node_res = [
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE) (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
] ]
del ppr_res del ppr_res

View File

@@ -1,4 +1,4 @@
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.llm_client import LLMClient from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.mem_active_manager import MemoryActiveManager from src.chat.knowledge.mem_active_manager import MemoryActiveManager
@@ -6,10 +6,83 @@ from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger from src.chat.knowledge.global_logger import logger
from src.config.config import global_config as bot_global_config from src.config.config import global_config as bot_global_config
# try: from src.manager.local_store_manager import local_storage
# import quick_algo import os
# except ImportError:
# print("quick_algo not found, please install it first") INVALID_ENTITY = [
"",
"",
"",
"",
"",
"我们",
"你们",
"他们",
"她们",
"它们",
]
PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity"
REL_NAMESPACE = "relation"
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
def _initialize_knowledge_local_storage():
"""
初始化知识库相关的本地存储配置
使用字典批量设置避免重复的if判断
"""
# 定义所有需要初始化的配置项
default_configs = {
# 路径配置
'root_path': ROOT_PATH,
'data_path': f"{ROOT_PATH}/data",
'lpmm_raw_data_path': f"{ROOT_PATH}/data/raw_data",
'lpmm_openie_data_path': f"{ROOT_PATH}/data/openie",
'lpmm_embedding_data_dir': f"{ROOT_PATH}/data/embedding",
'lpmm_rag_data_dir': f"{ROOT_PATH}/data/rag",
# 实体和命名空间配置
'lpmm_invalid_entity': INVALID_ENTITY,
'pg_namespace': PG_NAMESPACE,
'ent_namespace': ENT_NAMESPACE,
'rel_namespace': REL_NAMESPACE,
# RAG相关命名空间配置
'rag_graph_namespace': RAG_GRAPH_NAMESPACE,
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE,
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE
}
# 日志级别映射重要配置用info其他用debug
important_configs = {'root_path', 'data_path'}
# 批量设置配置项
initialized_count = 0
for key, default_value in default_configs.items():
if local_storage.get(key) is None:
local_storage.set(key, default_value)
# 根据重要性选择日志级别
if key in important_configs:
logger.info(f"设置{key}: {default_value}")
else:
logger.debug(f"设置{key}: {default_value}")
initialized_count += 1
if initialized_count > 0:
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
else:
logger.debug("知识库本地存储配置已存在,跳过初始化")
# 初始化本地存储路径
_initialize_knowledge_local_storage()
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if bot_global_config.lpmm_knowledge.enable: if bot_global_config.lpmm_knowledge.enable: