feat: update data path configuration, enhance data processing, and improve error messages
@@ -26,6 +26,7 @@ try:
     embed_manager.load_from_file()
 except Exception as e:
     logger.error("Error while loading the embedding store from file: {}".format(e))
+    logger.error("If this is your first time importing knowledge, or you have not imported any yet, please ignore this error")
 logger.info("Embedding store loaded")
 # Initialize the KG
 kg_manager = KGManager()

@@ -34,6 +35,7 @@ try:
     kg_manager.load_from_file()
 except Exception as e:
     logger.error("Error while loading the KG from file: {}".format(e))
+    logger.error("If this is your first time importing knowledge, or you have not imported any yet, please ignore this error")
 logger.info("KG loaded")

 logger.info(f"Number of KG nodes: {len(kg_manager.graph.get_node_list())}")

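Both hunks follow the same first-run-friendly pattern: attempt to load a persisted store, and if that fails, log a hint instead of aborting. Below is a minimal, self-contained sketch of that pattern; the `load_store_or_warn` helper and its arguments are illustrative, not part of this repository.

```python
import logging

logger = logging.getLogger(__name__)


def load_store_or_warn(store, name: str) -> None:
    """Try to load a persisted store; on failure, log a hint instead of crashing."""
    try:
        store.load_from_file()
    except Exception as e:
        logger.error("Error while loading %s from file: %s", name, e)
        logger.error("If you have not imported any knowledge yet, this error can be ignored")
    logger.info("%s loading finished", name)
```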
@@ -13,9 +13,11 @@ from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_confi
 from .utils.hash import get_sha256
 from .global_logger import logger
 from rich.traceback import install
+from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn

 install(extra_lines=3)

+TOTAL_EMBEDDING_TIMES = 3  # Total number of embedding passes

 @dataclass
 class EmbeddingStoreItem:

@@ -52,20 +54,35 @@ class EmbeddingStore:
     def _get_embedding(self, s: str) -> List[float]:
         return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)

-    def batch_insert_strs(self, strs: List[str]) -> None:
+    def batch_insert_strs(self, strs: List[str], times: int) -> None:
         """Insert strings into the store"""
-        # Process items one by one
-        for s in tqdm.tqdm(strs, desc="Inserting into embedding store", unit="items"):
-            # Compute hash for deduplication
-            item_hash = self.namespace + "-" + get_sha256(s)
-            if item_hash in self.store:
-                continue
+        total = len(strs)
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task(f"Inserting into embedding store ({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
+            for s in strs:
+                # Compute hash for deduplication
+                item_hash = self.namespace + "-" + get_sha256(s)
+                if item_hash in self.store:
+                    progress.update(task, advance=1)
+                    continue

-            # Get the embedding
-            embedding = self._get_embedding(s)
+                # Get the embedding
+                embedding = self._get_embedding(s)

-            # Store it
-            self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+                # Store it
+                self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+                progress.update(task, advance=1)

     def save_to_file(self) -> None:
         """Save to file"""

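For reference, the same progress-bar layout can be exercised outside the repository. The sketch below uses the rich columns added above with a stubbed embedding call; `batch_embed`, `fake_embedding`, and their parameters are illustrative assumptions, not the project's API.

```python
import hashlib

from rich.progress import (
    Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn,
    MofNCompleteColumn, TimeElapsedColumn, TimeRemainingColumn,
)


def fake_embedding(text: str) -> list[float]:
    # Stand-in for the real embedding request; returns a dummy vector.
    return [float(len(text))]


def batch_embed(strs: list[str], times: int, total_times: int = 3) -> dict[str, list[float]]:
    """Deduplicate by hash and embed each new string, with a rich progress bar."""
    store: dict[str, list[float]] = {}
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        "•",
        TimeElapsedColumn(),
        "<",
        TimeRemainingColumn(),
        transient=False,
    ) as progress:
        task = progress.add_task(f"Embedding pass ({times}/{total_times})", total=len(strs))
        for s in strs:
            key = hashlib.sha256(s.encode("utf-8")).hexdigest()
            if key in store:
                # Already embedded: only advance the bar.
                progress.update(task, advance=1)
                continue
            store[key] = fake_embedding(s)
            progress.update(task, advance=1)
    return store


if __name__ == "__main__":
    print(len(batch_embed(["a", "b", "a"], times=1)))  # -> 2 unique items
```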
@@ -191,7 +208,7 @@ class EmbeddingManager:

     def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
         """Encode paragraphs and store them in the embedding store"""
-        self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()))
+        self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)

     def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
         """Encode entities and store them in the embedding store"""

@@ -200,7 +217,7 @@ class EmbeddingManager:
             for triple in triple_list:
                 entities.add(triple[0])
                 entities.add(triple[2])
-        self.entities_embedding_store.batch_insert_strs(list(entities))
+        self.entities_embedding_store.batch_insert_strs(list(entities), times=2)

     def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
         """Encode relations and store them in the embedding store"""

@@ -208,7 +225,7 @@ class EmbeddingManager:
         for triples in triple_list_data.values():
             graph_triples.extend([tuple(t) for t in triples])
         graph_triples = list(set(graph_triples))
-        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
+        self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)

     def load_from_file(self):
         """Load from file"""

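The `times` argument only labels which of the `TOTAL_EMBEDDING_TIMES` passes is currently running (paragraphs, entities, relations). A hypothetical driver that strings the three calls together, assuming a manager shaped like the one in the hunks above:

```python
def store_all_embeddings(manager, raw_paragraphs, triple_list_data):
    # Pass 1: paragraph texts
    manager.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)

    # Pass 2: subject and object entities from every triple
    entities = set()
    for triples in triple_list_data.values():
        for triple in triples:
            entities.add(triple[0])
            entities.add(triple[2])
    manager.entities_embedding_store.batch_insert_strs(list(entities), times=2)

    # Pass 3: stringified, de-duplicated relation triples
    relations = {tuple(t) for triples in triple_list_data.values() for t in triples}
    manager.relation_embedding_store.batch_insert_strs([str(r) for r in relations], times=3)
```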
@@ -5,7 +5,7 @@ from typing import Dict, List, Tuple
 import numpy as np
 import pandas as pd
 import tqdm
+from rich.progress import Progress, BarColumn, TimeElapsedColumn, TimeRemainingColumn, TaskProgressColumn, MofNCompleteColumn, SpinnerColumn, TextColumn
 from quick_algo import di_graph, pagerank


@@ -132,41 +132,56 @@ class KGManager:
         ent_hash_list = list(ent_hash_list)

         synonym_hash_set = set()

         synonym_result = dict()

         # For each entity node, find similar entity nodes and create expansion links
-        for ent_hash in tqdm.tqdm(ent_hash_list):
-            if ent_hash in synonym_hash_set:
-                # Avoid adding duplicates within the same batch
-                continue
-            ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
-            assert isinstance(ent, EmbeddingStoreItem)
-            if ent is None:
-                continue
-            # Query similar entities
-            similar_ents = embedding_manager.entities_embedding_store.search_top_k(
-                ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
-            )
-            res_ent = []  # Debug
-            for res_ent_hash, similarity in similar_ents:
-                if res_ent_hash == ent_hash:
-                    # Avoid self-connections
-                    continue
-                if similarity < global_config["rag"]["params"]["synonym_threshold"]:
-                    # Below the similarity threshold
-                    continue
-                node_to_node[(res_ent_hash, ent_hash)] = similarity
-                node_to_node[(ent_hash, res_ent_hash)] = similarity
-                synonym_hash_set.add(res_ent_hash)
-                new_edge_cnt += 1
-                res_ent.append(
-                    (
-                        embedding_manager.entities_embedding_store.store[res_ent_hash].str,
-                        similarity,
-                    )
-                )  # Debug
-            synonym_result[ent.str] = res_ent
+        # rich progress bar
+        total = len(ent_hash_list)
+        with Progress(
+            SpinnerColumn(),
+            TextColumn("[progress.description]{task.description}"),
+            BarColumn(),
+            TaskProgressColumn(),
+            MofNCompleteColumn(),
+            "•",
+            TimeElapsedColumn(),
+            "<",
+            TimeRemainingColumn(),
+            transient=False,
+        ) as progress:
+            task = progress.add_task("Linking synonyms", total=total)
+            for ent_hash in ent_hash_list:
+                if ent_hash in synonym_hash_set:
+                    progress.update(task, advance=1)
+                    continue
+                ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
+                assert isinstance(ent, EmbeddingStoreItem)
+                if ent is None:
+                    progress.update(task, advance=1)
+                    continue
+                # Query similar entities
+                similar_ents = embedding_manager.entities_embedding_store.search_top_k(
+                    ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
+                )
+                res_ent = []  # Debug
+                for res_ent_hash, similarity in similar_ents:
+                    if res_ent_hash == ent_hash:
+                        # Avoid self-connections
+                        continue
+                    if similarity < global_config["rag"]["params"]["synonym_threshold"]:
+                        # Below the similarity threshold
+                        continue
+                    node_to_node[(res_ent_hash, ent_hash)] = similarity
+                    node_to_node[(ent_hash, res_ent_hash)] = similarity
+                    synonym_hash_set.add(res_ent_hash)
+                    new_edge_cnt += 1
+                    res_ent.append(
+                        (
+                            embedding_manager.entities_embedding_store.store[res_ent_hash].str,
+                            similarity,
+                        )
+                    )  # Debug
+                synonym_result[ent.str] = res_ent
+                progress.update(task, advance=1)

         for k, v in synonym_result.items():
             print(f'Similar entities of "{k}": {v}')

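Conceptually, the loop above is a top-k nearest-neighbour pass with a similarity threshold: it adds symmetric edges and skips self matches and entities already linked in this batch. Below is a standalone sketch of that idea using plain cosine similarity; `top_k_similar` and `link_synonyms` are illustrative stand-ins, not the store's `search_top_k`.

```python
import numpy as np


def top_k_similar(query: np.ndarray, vectors: dict[str, np.ndarray], k: int) -> list[tuple[str, float]]:
    """Cosine similarity against every stored vector; keep the k best."""
    scores = []
    for key, vec in vectors.items():
        sim = float(np.dot(query, vec) / (np.linalg.norm(query) * np.linalg.norm(vec)))
        scores.append((key, sim))
    return sorted(scores, key=lambda kv: kv[1], reverse=True)[:k]


def link_synonyms(vectors: dict[str, np.ndarray], top_k: int = 3, threshold: float = 0.8) -> dict[tuple[str, str], float]:
    """Build symmetric similarity edges, skipping self matches and already-linked nodes."""
    edges: dict[tuple[str, str], float] = {}
    linked: set[str] = set()
    for key, vec in vectors.items():
        if key in linked:
            continue
        for other, sim in top_k_similar(vec, vectors, top_k):
            if other == key or sim < threshold:
                continue
            edges[(key, other)] = sim
            edges[(other, key)] = sim
            linked.add(other)
    return edges
```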
@@ -1,9 +1,13 @@
 import json
 import os
+import glob
 from typing import Any, Dict, List


 from .lpmmconfig import INVALID_ENTITY, global_config

+ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
+
+
 def _filter_invalid_entities(entities: List[str]) -> List[str]:
     """Filter out invalid entities"""

@@ -74,12 +78,22 @@ class OpenIE:
             doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])

     @staticmethod
-    def _from_dict(data):
-        """Build an OpenIE object from a dict"""
+    def _from_dict(data_list):
+        """Merge multiple dicts into one OpenIE object"""
+        # data_list: List[dict]
+        all_docs = []
+        for data in data_list:
+            all_docs.extend(data.get("docs", []))
+        # Recompute statistics
+        sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]])
+        sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]])
+        num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs])
+        avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0
+        avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0
         return OpenIE(
-            docs=data["docs"],
-            avg_ent_chars=data["avg_ent_chars"],
-            avg_ent_words=data["avg_ent_words"],
+            docs=all_docs,
+            avg_ent_chars=avg_ent_chars,
+            avg_ent_words=avg_ent_words,
         )

     def _to_dict(self):

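Merging recomputes the entity statistics over the union of documents instead of trusting any single file's stored averages. A small worked example with hypothetical data:

```python
# Two tiny "OpenIE" dicts merged the way _from_dict now does it (illustrative data only).
doc_a = {"docs": [{"extracted_entities": ["Beijing", "China"], "extracted_triples": []}]}
doc_b = {"docs": [{"extracted_entities": ["Paris"], "extracted_triples": []}]}

all_docs = []
for data in (doc_a, doc_b):
    all_docs.extend(data.get("docs", []))

num_phrases = sum(len(chunk["extracted_entities"]) for chunk in all_docs)            # 3
sum_chars = sum(len(e) for chunk in all_docs for e in chunk["extracted_entities"])   # 7 + 5 + 5 = 17
avg_ent_chars = round(sum_chars / num_phrases, 4) if num_phrases else 0              # 5.6667
print(num_phrases, avg_ent_chars)
```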
@@ -92,12 +106,20 @@ class OpenIE:

     @staticmethod
     def load() -> "OpenIE":
-        """Load OpenIE data from file"""
-        with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
-            data = json.loads(f.read())
-
-        openie_data = OpenIE._from_dict(data)
-
+        """Load OpenIE data by merging all json files under OPENIE_DIR"""
+        openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
+        if not os.path.exists(openie_dir):
+            raise Exception(f"OpenIE data directory does not exist: {openie_dir}")
+        json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
+        data_list = []
+        for file in json_files:
+            with open(file, "r", encoding="utf-8") as f:
+                data = json.load(f)
+                data_list.append(data)
+        if not data_list:
+            # print(f"111111111111111111111Root Path : \n{ROOT_PATH}")
+            raise Exception(f"No OpenIE json files found in {openie_dir}")
+        openie_data = OpenIE._from_dict(data_list)
         return openie_data

     @staticmethod

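The new `load()` treats the configured path as a directory of JSON shards rather than a single file, and `sorted()` keeps the merge order deterministic across runs. A self-contained sketch of just the directory-loading step; `load_json_dir` is a hypothetical helper, not the project's API.

```python
import glob
import json
import os


def load_json_dir(openie_dir: str) -> list[dict]:
    """Load every *.json under a directory, in a stable (sorted) order."""
    if not os.path.isdir(openie_dir):
        raise FileNotFoundError(f"OpenIE data directory does not exist: {openie_dir}")
    data_list = []
    for file in sorted(glob.glob(os.path.join(openie_dir, "*.json"))):
        with open(file, "r", encoding="utf-8") as f:
            data_list.append(json.load(f))
    if not data_list:
        raise FileNotFoundError(f"No json files found in {openie_dir}")
    return data_list
```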
@@ -132,3 +154,7 @@ class OpenIE:
         """Extract the raw paragraphs"""
         raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
         return raw_paragraph_dict
+
+if __name__ == "__main__":
+    # Test code
+    print(ROOT_PATH)

@@ -6,21 +6,25 @@ from .lpmmconfig import global_config
 from .utils.hash import get_sha256


-def load_raw_data() -> tuple[list[str], list[str]]:
+def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
     """Load the raw data file

     Reads the raw data file and loads the raw data into memory

+    Args:
+        path: optional, absolute path of the json file to read
+
     Returns:
-        - raw_data: raw data dict
-        - md5_set: SHA256 set of the raw data
+        - raw_data: list of raw data entries
+        - sha256_list: SHA256 hashes of the raw data
     """
-    # Read the import.json file
-    if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
-        with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
+    # Read the json file at the given path, or at the default path
+    json_path = path if path else global_config["persistence"]["raw_data_path"]
+    if os.path.exists(json_path):
+        with open(json_path, "r", encoding="utf-8") as f:
             import_json = json.loads(f.read())
     else:
-        raise Exception("Failed to read the raw data file")
+        raise Exception(f"Failed to read the raw data file: {json_path}")
    # Example of import_json content:
    # import_json = [
    #     "The capital of China is Beijing. The capital of France is Paris.",
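With the optional `path` argument, callers can point the importer at any JSON file while the default still comes from `global_config["persistence"]["raw_data_path"]`. A self-contained sketch of the same behaviour; `load_raw_data_sketch`, `default_path`, and the inline hashing are illustrative, since the project uses its own `get_sha256` helper and config.

```python
import hashlib
import json
import os
from typing import Optional


def load_raw_data_sketch(path: Optional[str] = None, default_path: str = "data/import.json") -> tuple[list[str], list[str]]:
    """Read a list of raw text entries from `path` (or the default) and hash each entry."""
    json_path = path if path else default_path
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"Failed to read the raw data file: {json_path}")
    with open(json_path, "r", encoding="utf-8") as f:
        raw_data = json.load(f)
    sha256_list = [hashlib.sha256(item.encode("utf-8")).hexdigest() for item in raw_data]
    return raw_data, sha256_list
```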