ruff fix但指定了--unsafe-fixes

This commit is contained in:
minecraft1024a
2025-10-05 21:48:32 +08:00
committed by Windpicker-owo
parent 04feb585b4
commit 2a89efe47a
76 changed files with 301 additions and 316 deletions

View File

@@ -495,7 +495,7 @@ class EmbeddingStore:
"""重新构建Faiss索引以余弦相似度为度量"""
# 获取所有的embedding
array = []
self.idx2hash = dict()
self.idx2hash = {}
for key in self.store:
array.append(self.store[key].embedding)
self.idx2hash[str(len(array) - 1)] = key

View File

@@ -33,7 +33,7 @@ def _extract_json_from_text(text: str):
if isinstance(parsed_json, dict):
# 如果字典只有一个键,并且值是列表,返回那个列表
if len(parsed_json) == 1:
value = list(parsed_json.values())[0]
value = next(iter(parsed_json.values()))
if isinstance(value, list):
return value
return parsed_json

View File

@@ -91,7 +91,7 @@ class KGManager:
# 加载实体计数
ent_cnt_df = pd.read_parquet(self.ent_cnt_data_path, engine="pyarrow")
self.ent_appear_cnt = dict({row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()})
self.ent_appear_cnt = {row["hash_key"]: row["appear_cnt"] for _, row in ent_cnt_df.iterrows()}
# 加载KG
self.graph = di_graph.load_from_file(self.graph_data_path)
@@ -290,7 +290,7 @@ class KGManager:
embedding_manager: EmbeddingManager对象
"""
# 实体之间的联系
node_to_node = dict()
node_to_node = {}
# 构建实体节点之间的关系,同时统计实体出现次数
logger.info("正在构建KG实体节点之间的关系同时统计实体出现次数")
@@ -379,8 +379,8 @@ class KGManager:
top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k
if len(ent_mean_scores) > top_k:
# 从大到小排序取后len - k个
ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
for ent_hash, _ in ent_mean_scores.items():
ent_mean_scores = dict(sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True))
for ent_hash in ent_mean_scores.keys():
# 删除被淘汰的实体节点权重设置
del ent_weights[ent_hash]
del top_k, ent_mean_scores

View File

@@ -124,29 +124,25 @@ class OpenIE:
def extract_entity_dict(self):
"""提取实体列表"""
ner_output_dict = dict(
{
ner_output_dict = {
doc_item["idx"]: doc_item["extracted_entities"]
for doc_item in self.docs
if len(doc_item["extracted_entities"]) > 0
}
)
return ner_output_dict
def extract_triple_dict(self):
"""提取三元组列表"""
triple_output_dict = dict(
{
triple_output_dict = {
doc_item["idx"]: doc_item["extracted_triples"]
for doc_item in self.docs
if len(doc_item["extracted_triples"]) > 0
}
)
return triple_output_dict
def extract_raw_paragraph_dict(self):
"""提取原始段落"""
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
raw_paragraph_dict = {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}
return raw_paragraph_dict

View File

@@ -18,13 +18,11 @@ def dyn_select_top_k(
normalized_score = []
for score_item in sorted_score:
normalized_score.append(
tuple(
[
(
score_item[0],
score_item[1],
(score_item[1] - min_score) / (max_score - min_score),
]
)
)
)
# 寻找跳变点score变化最大的位置