ruff fix, but with --unsafe-fixes specified
committed by Windpicker-owo
parent 04feb585b4
commit 2a89efe47a
@@ -495,7 +495,7 @@ class EmbeddingStore:
         """Rebuild the Faiss index, using cosine similarity as the metric"""
         # Collect all embeddings
         array = []
-        self.idx2hash = dict()
+        self.idx2hash = {}
         for key in self.store:
             array.append(self.store[key].embedding)
             self.idx2hash[str(len(array) - 1)] = key
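This looks like ruff rule C408 (unnecessary-collection-call, from the flake8-comprehensions set): an empty dict() call is replaced with the {} literal, which skips a global name lookup and a call (the node_to_node hunk below gets the same fix). A minimal sketch of the difference; the timing harness is illustrative and not part of this repo:

import timeit

# In CPython, {} compiles to a single BUILD_MAP instruction, while dict()
# has to resolve the global name "dict" and make a call, so the literal is
# marginally faster and allocates nothing extra.
call_time = timeit.timeit("dict()", number=1_000_000)
literal_time = timeit.timeit("{}", number=1_000_000)
print(f"dict(): {call_time:.3f}s   {{}}: {literal_time:.3f}s")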
@@ -33,7 +33,7 @@ def _extract_json_from_text(text: str):
     if isinstance(parsed_json, dict):
         # If the dict has only one key and its value is a list, return that list
         if len(parsed_json) == 1:
-            value = list(parsed_json.values())[0]
+            value = next(iter(parsed_json.values()))
            if isinstance(value, list):
                 return value
         return parsed_json
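This looks like RUF015 (unnecessary-iterable-allocation-for-first-element): the old form copies every value into a fresh list just to read the first element, while next(iter(...)) pulls a single item. The exception type on an empty dict changes, which is presumably why this rewrite required --unsafe-fixes. A sketch with placeholder data:

parsed_json = {"data": [1, 2, 3]}

# Old form: copies every value into a brand-new list, then indexes it.
first_old = list(parsed_json.values())[0]

# New form: takes one element from the values view; no intermediate list.
first_new = next(iter(parsed_json.values()))
assert first_old == first_new == [1, 2, 3]

# Behavioral caveat (why the fix is "unsafe"): an empty dict now raises
# StopIteration instead of IndexError, unless a default is supplied.
assert next(iter({}.values()), None) is None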
@@ -91,7 +91,7 @@ class KGManager:
 
         # Load entity appearance counts
         ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
-        self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()})
+        self.ent_appear_cnt = {row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}
 
         # Load the KG
         self.graph = di_graph.load_from_file(self.graph_data_path)
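This looks like C418 (unnecessary-literal-within-dict-call): a dict comprehension already yields a dict, so wrapping it in dict(...) only makes a second copy. The three OpenIE methods further down get the same fix. A sketch with made-up rows standing in for ent_cnt_df:

# Hypothetical rows standing in for ent_cnt_df.iterrows() output.
rows = [{"hash_key": "a1", "appear_cnt": 3}, {"hash_key": "b2", "appear_cnt": 1}]

# Old: the comprehension already yields a dict; dict(...) copies it again.
ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for row in rows})

# New: one allocation, same result.
ent_appear_cnt = {row["hash_key"]: row["appear_cnt"] for row in rows}
assert ent_appear_cnt == {"a1": 3, "b2": 1}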
@@ -290,7 +290,7 @@ class KGManager:
         embedding_manager: an EmbeddingManager instance
         """
         # Relations between entities
-        node_to_node = dict()
+        node_to_node = {}
 
         # Build relations between entity nodes while counting entity occurrences
         logger.info("正在构建KG实体节点之间的关系,同时统计实体出现次数")
@@ -379,8 +379,8 @@ class KGManager:
         top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k
         if len(ent_mean_scores) > top_k:
             # Sort in descending order; take the trailing len - k entries
-            ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
-            for ent_hash, _ in ent_mean_scores.items():
+            ent_mean_scores = dict(sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True))
+            for ent_hash in ent_mean_scores.keys():
                 # Delete the weight entries of the eliminated entity nodes
                 del ent_weights[ent_hash]
         del top_k, ent_mean_scores
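Two rules appear to fire here: C416 (unnecessary-comprehension), which collapses a comprehension that merely re-packs the (k, v) pairs it iterates into a direct dict() call, and PERF102 (incorrect-dict-iterator), which drops the unused value from the .items() unpacking. A sketch with placeholder scores:

# Placeholder scores standing in for ent_mean_scores.
scores = {"ent_a": 0.9, "ent_b": 0.2, "ent_c": 0.7}

# C416: {k: v for k, v in pairs} is equivalent to dict(pairs), so the
# sorted (k, v) pairs can be passed to dict() directly.
ranked = dict(sorted(scores.items(), key=lambda item: item[1], reverse=True))

# PERF102: unpacking .items() while discarding the value builds a tuple per
# entry for nothing; iterating the keys view (or the dict itself) avoids it.
for ent_hash in ranked.keys():
    print(ent_hash)  # ent_a, ent_c, ent_b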
@@ -124,29 +124,25 @@ class OpenIE:
 
     def extract_entity_dict(self):
         """Extract the entity list"""
-        ner_output_dict = dict(
-            {
-                doc_item["idx"]: doc_item["extracted_entities"]
-                for doc_item in self.docs
-                if len(doc_item["extracted_entities"]) > 0
-            }
-        )
+        ner_output_dict = {
+            doc_item["idx"]: doc_item["extracted_entities"]
+            for doc_item in self.docs
+            if len(doc_item["extracted_entities"]) > 0
+        }
         return ner_output_dict
 
     def extract_triple_dict(self):
         """Extract the triple list"""
-        triple_output_dict = dict(
-            {
-                doc_item["idx"]: doc_item["extracted_triples"]
-                for doc_item in self.docs
-                if len(doc_item["extracted_triples"]) > 0
-            }
-        )
+        triple_output_dict = {
+            doc_item["idx"]: doc_item["extracted_triples"]
+            for doc_item in self.docs
+            if len(doc_item["extracted_triples"]) > 0
+        }
         return triple_output_dict
 
     def extract_raw_paragraph_dict(self):
         """Extract the raw paragraphs"""
-        raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
+        raw_paragraph_dict = {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}
         return raw_paragraph_dict
 
 
@@ -18,13 +18,11 @@ def dyn_select_top_k(
     normalized_score = []
     for score_item in sorted_score:
         normalized_score.append(
-            tuple(
-                [
-                    score_item[0],
-                    score_item[1],
-                    (score_item[1] - min_score) / (max_score - min_score),
-                ]
-            )
+            (
+                score_item[0],
+                score_item[1],
+                (score_item[1] - min_score) / (max_score - min_score),
+            )
         )
 
     # Find the jump point: the position where the score changes the most
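This looks like C409 (unnecessary-literal-within-tuple-call): tuple([...]) builds a throwaway list and then converts it, whereas a plain tuple literal builds the result in one step. A sketch with placeholder inputs standing in for dyn_select_top_k's locals:

# Placeholder inputs standing in for dyn_select_top_k's locals.
sorted_score = [("ent_a", 0.9), ("ent_b", 0.4)]
min_score, max_score = 0.1, 0.9

normalized_score = []
for score_item in sorted_score:
    # tuple([...]) allocates a list only to convert it; the tuple literal
    # (...) produces the same triple in a single step.
    normalized_score.append(
        (
            score_item[0],
            score_item[1],
            (score_item[1] - min_score) / (max_score - min_score),
        )
    )
print(normalized_score)  # [('ent_a', 0.9, 1.0), ('ent_b', 0.4, 0.375)]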