Files
Mofox-Core/src/chat/knowledge/open_ie.py
sunbiz1024 8f4f7d19af ruff
2025-10-06 09:38:16 +08:00

152 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import glob
import os
from typing import Any
import orjson
from .knowledge_lib import DATA_PATH, INVALID_ENTITY, ROOT_PATH
# from src.manager.local_store_manager import local_storage
def _filter_invalid_entities(entities: list[str]) -> list[str]:
"""过滤无效的实体"""
valid_entities = set()
for entity in entities:
if not isinstance(entity, str) or entity.strip() == "" or entity in INVALID_ENTITY or entity in valid_entities:
# 非字符串/空字符串/在无效实体列表中/重复
continue
valid_entities.add(entity)
return list(valid_entities)
def _filter_invalid_triples(triples: list[list[str]]) -> list[list[str]]:
"""过滤无效的三元组"""
unique_triples = set()
valid_triples = []
for triple in triples:
if len(triple) != 3 or (
(not isinstance(triple[0], str) or triple[0].strip() == "")
or (not isinstance(triple[1], str) or triple[1].strip() == "")
or (not isinstance(triple[2], str) or triple[2].strip() == "")
):
# 三元组长度不为3或其中存在空值
continue
valid_triple = [str(item) for item in triple]
if tuple(valid_triple) not in unique_triples:
unique_triples.add(tuple(valid_triple))
valid_triples.append(valid_triple)
return valid_triples
class OpenIE:
    """In-memory view of OpenIE extraction results.

    The data format handled by this class is::

        {
            "docs": [
                {
                    "idx": "unique document id, usually the SHA256 hash of the text",
                    "passage": "original text of the document",
                    "extracted_entities": ["entity1", "entity2", ...],
                    "extracted_triples": [["subject", "predicate", "object"], ...]
                },
                ...
            ],
            "avg_ent_chars": "average number of characters per entity",
            "avg_ent_words": "average number of words per entity"
        }
    """

    def __init__(
        self,
        docs: list[dict[str, Any]],
        avg_ent_chars,
        avg_ent_words,
    ):
        """Store the documents and statistics, then sanitise the documents.

        Args:
            docs: Document dicts in the format shown in the class docstring.
                Each dict is modified in place: its ``extracted_entities``
                and ``extracted_triples`` values are replaced with filtered
                copies.
            avg_ent_chars: Average number of characters per entity.
            avg_ent_words: Average number of words per entity.

        Raises:
            KeyError: If a document lacks ``extracted_entities`` or
                ``extracted_triples``.
        """
        self.docs = docs
        self.avg_ent_chars = avg_ent_chars
        self.avg_ent_words = avg_ent_words
        for doc in self.docs:
            # Drop invalid / duplicate entities.
            doc["extracted_entities"] = _filter_invalid_entities(doc["extracted_entities"])
            # Drop malformed / duplicate triples.
            doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])

    @staticmethod
    def _from_dict(data_list):
        """Merge several OpenIE dicts into a single ``OpenIE`` object.

        Args:
            data_list: Iterable of dicts in the class-level format. A dict
                without a ``"docs"`` key contributes no documents.

        Returns:
            An ``OpenIE`` holding every document from every input dict, with
            the entity statistics recomputed. Both averages are rounded to
            four decimals, or ``0`` when there are no entities.
        """
        all_docs = []
        for data in data_list:
            all_docs.extend(data.get("docs", []))

        # Construct first so the statistics are computed from the entities
        # that are actually kept. The constructor filters the docs, which
        # means junk values (e.g. non-strings) can no longer crash ``len`` /
        # ``split`` here or skew the averages.
        openie = OpenIE(docs=all_docs, avg_ent_chars=0, avg_ent_words=0)
        entities = [entity for doc in openie.docs for entity in doc["extracted_entities"]]
        num_phrases = len(entities)
        if num_phrases:
            total_chars = sum(len(entity) for entity in entities)
            total_words = sum(len(entity.split()) for entity in entities)
            openie.avg_ent_chars = round(total_chars / num_phrases, 4)
            openie.avg_ent_words = round(total_words / num_phrases, 4)
        return openie

    def _to_dict(self):
        """Return the data as a dict in the class-level format."""
        return {
            "docs": self.docs,
            "avg_ent_chars": self.avg_ent_chars,
            "avg_ent_words": self.avg_ent_words,
        }

    @staticmethod
    def load() -> "OpenIE":
        """Load and merge every ``*.json`` file under ``DATA_PATH/openie``.

        Files are read in sorted path order.

        Returns:
            The merged ``OpenIE`` object.

        Raises:
            FileNotFoundError: If the directory does not exist or contains
                no ``*.json`` files.
        """
        openie_dir = os.path.join(DATA_PATH, "openie")
        if not os.path.exists(openie_dir):
            raise FileNotFoundError(f"OpenIE数据目录不存在: {openie_dir}")

        data_list = []
        for file in sorted(glob.glob(os.path.join(openie_dir, "*.json"))):
            with open(file, encoding="utf-8") as f:
                data_list.append(orjson.loads(f.read()))

        if not data_list:
            raise FileNotFoundError(f"未在 {openie_dir} 找到任何OpenIE json文件")
        return OpenIE._from_dict(data_list)

    def extract_entity_dict(self):
        """Map document ``idx`` to its entity list, skipping docs with none."""
        return {
            doc_item["idx"]: doc_item["extracted_entities"]
            for doc_item in self.docs
            if doc_item["extracted_entities"]
        }

    def extract_triple_dict(self):
        """Map document ``idx`` to its triple list, skipping docs with none."""
        return {
            doc_item["idx"]: doc_item["extracted_triples"]
            for doc_item in self.docs
            if doc_item["extracted_triples"]
        }

    def extract_raw_paragraph_dict(self):
        """Map every document ``idx`` to its original ``passage`` text."""
        return {doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}
if __name__ == "__main__":
    # Test code: print the ROOT_PATH imported from knowledge_lib.
    print(ROOT_PATH)