Merge branch 'dev' of github.com:MaiM-with-u/MaiBot into dev

This commit is contained in:
UnCLAS-Prommer
2025-08-03 11:20:15 +08:00
3 changed files with 26 additions and 28 deletions

View File

@@ -21,7 +21,6 @@ from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .lpmmconfig import global_config from .lpmmconfig import global_config
from src.manager.local_store_manager import local_storage
from .global_logger import logger from .global_logger import logger
@@ -30,19 +29,9 @@ def _get_kg_dir():
""" """
安全地获取KG数据目录路径 安全地获取KG数据目录路径
""" """
root_path: str = local_storage["root_path"] current_dir = os.path.dirname(os.path.abspath(__file__))
if root_path is None: root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
# 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用 kg_dir = os.path.join(root_path, "data/rag")
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}")
# 获取RAG数据目录
rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
if rag_data_dir is None:
kg_dir = os.path.join(root_path, "data/rag")
else:
kg_dir = os.path.join(root_path, rag_data_dir)
return str(kg_dir).replace("\\", "/") return str(kg_dir).replace("\\", "/")
@@ -65,9 +54,9 @@ class KGManager:
# 持久化相关 - 使用延迟初始化的路径 # 持久化相关 - 使用延迟初始化的路径
self.dir_path = get_kg_dir_str() self.dir_path = get_kg_dir_str()
self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml" self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet" self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json" self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json"
def save_to_file(self): def save_to_file(self):
"""将KG数据保存到文件""" """将KG数据保存到文件"""
@@ -122,8 +111,8 @@ class KGManager:
# 避免自连接 # 避免自连接
continue continue
# 一个triple就是一条边同时构建双向联系 # 一个triple就是一条边同时构建双向联系
hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) hash_key1 = "entity" + "-" + get_sha256(triple[0])
hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2]) hash_key2 = "entity" + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1) entity_set.add(hash_key1)
@@ -141,8 +130,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系""" """构建实体节点与文段节点之间的关系"""
for idx in triple_list_data: for idx in triple_list_data:
for triple in triple_list_data[idx]: for triple in triple_list_data[idx]:
ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) ent_hash_key = "entity" + "-" + get_sha256(triple[0])
pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx) pg_hash_key = "paragraph" + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod @staticmethod
@@ -157,8 +146,8 @@ class KGManager:
ent_hash_list = set() ent_hash_list = set()
for triple_list in triple_list_data.values(): for triple_list in triple_list_data.values():
for triple in triple_list: for triple in triple_list:
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0])) ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2])) ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list) ent_hash_list = list(ent_hash_list)
synonym_hash_set = set() synonym_hash_set = set()
@@ -263,7 +252,7 @@ class KGManager:
for src_tgt in node_to_node.keys(): for src_tgt in node_to_node.keys():
for node_hash in src_tgt: for node_hash in src_tgt:
if node_hash not in existed_nodes: if node_hash not in existed_nodes:
if node_hash.startswith(local_storage["ent_namespace"]): if node_hash.startswith("entity"):
# 新增实体节点 # 新增实体节点
node = embedding_manager.entities_embedding_store.store.get(node_hash) node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -275,7 +264,7 @@ class KGManager:
node_item["type"] = "ent" node_item["type"] = "ent"
node_item["create_time"] = now_time node_item["create_time"] = now_time
self.graph.update_node(node_item) self.graph.update_node(node_item)
elif node_hash.startswith(local_storage["pg_namespace"]): elif node_hash.startswith("paragraph"):
# 新增文段节点 # 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -359,7 +348,7 @@ class KGManager:
# 关系三元组 # 关系三元组
triple = relation[2:-2].split("', '") triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]: for ent in [(triple[0]), (triple[2])]:
ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent) ent_hash = "entity" + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体 if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = [] ent_sim_scores[ent_hash] = []
@@ -439,7 +428,7 @@ class KGManager:
passage_node_res = [ passage_node_res = [
(node_key, score) (node_key, score)
for node_key, score in ppr_res.items() for node_key, score in ppr_res.items()
if node_key.startswith(local_storage["pg_namespace"]) if node_key.startswith("paragraph")
] ]
del ppr_res del ppr_res

View File

@@ -50,7 +50,7 @@ class QAManager:
# 过滤阈值 # 过滤阈值
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
# 未找到相关关系 # 未找到相关关系
logger.debug("未找到相关关系,跳过关系检索") logger.debug("未找到相关关系,跳过关系检索")
relation_search_res = [] relation_search_res = []
@@ -106,6 +106,11 @@ class QAManager:
processed_result = await self.process_query(question) processed_result = await self.process_query(question)
if processed_result is not None: if processed_result is not None:
query_res = processed_result[0] query_res = processed_result[0]
# 检查查询结果是否为空
if not query_res:
logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
return None
knowledge = [ knowledge = [
( (
self.embed_manager.paragraphs_embedding_store.store[res[0]].str, self.embed_manager.paragraphs_embedding_store.store[res[0]].str,

View File

@@ -5,6 +5,10 @@ def dyn_select_top_k(
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
) -> List[Tuple[Any, float, float]]: ) -> List[Tuple[Any, float, float]]:
"""动态TopK选择""" """动态TopK选择"""
# 检查输入列表是否为空
if not score:
return []
# 按照分数排序(降序) # 按照分数排序(降序)
sorted_score = sorted(score, key=lambda x: x[1], reverse=True) sorted_score = sorted(score, key=lambda x: x[1], reverse=True)