feat: 添加嵌入模型一致性校验功能,优化错误处理
This commit is contained in:
@@ -30,6 +30,7 @@ OPENIE_DIR = (
|
|||||||
logger = get_module_logger("OpenIE导入")
|
logger = get_module_logger("OpenIE导入")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def hash_deduplicate(
|
def hash_deduplicate(
|
||||||
raw_paragraphs: dict[str, str],
|
raw_paragraphs: dict[str, str],
|
||||||
triple_list_data: dict[str, list[list[str]]],
|
triple_list_data: dict[str, list[list[str]]],
|
||||||
@@ -167,6 +168,7 @@ def main():
|
|||||||
global_config["llm_providers"][key]["api_key"],
|
global_config["llm_providers"][key]["api_key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# 初始化Embedding库
|
# 初始化Embedding库
|
||||||
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
||||||
logger.info("正在从文件加载Embedding库")
|
logger.info("正在从文件加载Embedding库")
|
||||||
@@ -174,6 +176,11 @@ def main():
|
|||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
|
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
|
||||||
|
if "嵌入模型与本地存储不一致" in str(e):
|
||||||
|
logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||||
|
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
|
||||||
|
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
|
||||||
|
sys.exit(1)
|
||||||
logger.error("如果你是第一次导入知识,请忽略此错误")
|
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||||
logger.info("Embedding库加载完成")
|
logger.info("Embedding库加载完成")
|
||||||
# 初始化KG
|
# 初始化KG
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import math
|
||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -25,9 +26,39 @@ from rich.progress import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||||
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
||||||
|
|
||||||
|
# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
|
||||||
|
# 这些字符串的嵌入结果应该是固定的,不能随时间变化
|
||||||
|
EMBEDDING_TEST_STRINGS = [
|
||||||
|
"阿卡伊真的太好玩了,神秘性感大女同等着你",
|
||||||
|
"你怎么知道我arc12.64了",
|
||||||
|
"我是蕾缪乐小姐的狗",
|
||||||
|
"关注Oct谢谢喵",
|
||||||
|
"不是w6我不草",
|
||||||
|
"关注千石可乐谢谢喵",
|
||||||
|
"来玩CLANNAD,AIR,樱之诗,樱之刻谢谢喵",
|
||||||
|
"关注墨梓柒谢谢喵",
|
||||||
|
"Ciallo~",
|
||||||
|
"来玩巧克甜恋谢谢喵",
|
||||||
|
"水印",
|
||||||
|
"我也在纠结晚饭,铁锅炒鸡听着就香!",
|
||||||
|
"test你妈喵"
|
||||||
|
]
|
||||||
|
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
|
||||||
|
EMBEDDING_SIM_THRESHOLD = 0.99
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
# 计算余弦相似度
|
||||||
|
dot = sum(x * y for x, y in zip(a, b))
|
||||||
|
norm_a = math.sqrt(sum(x * x for x in a))
|
||||||
|
norm_b = math.sqrt(sum(x * x for x in b))
|
||||||
|
if norm_a == 0 or norm_b == 0:
|
||||||
|
return 0.0
|
||||||
|
return dot / (norm_a * norm_b)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingStoreItem:
|
class EmbeddingStoreItem:
|
||||||
@@ -64,6 +95,46 @@ class EmbeddingStore:
|
|||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
||||||
|
|
||||||
|
def get_test_file_path(self):
|
||||||
|
return EMBEDDING_TEST_FILE
|
||||||
|
|
||||||
|
def save_embedding_test_vectors(self):
|
||||||
|
"""保存测试字符串的嵌入到本地"""
|
||||||
|
test_vectors = {}
|
||||||
|
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
||||||
|
test_vectors[str(idx)] = self._get_embedding(s)
|
||||||
|
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||||
|
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
def load_embedding_test_vectors(self):
|
||||||
|
"""加载本地保存的测试字符串嵌入"""
|
||||||
|
path = self.get_test_file_path()
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return None
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def check_embedding_model_consistency(self):
|
||||||
|
"""校验当前模型与本地嵌入模型是否一致"""
|
||||||
|
local_vectors = self.load_embedding_test_vectors()
|
||||||
|
if local_vectors is None:
|
||||||
|
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
|
||||||
|
self.save_embedding_test_vectors()
|
||||||
|
return True
|
||||||
|
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
|
||||||
|
local_emb = local_vectors.get(str(idx))
|
||||||
|
if local_emb is None:
|
||||||
|
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
|
||||||
|
self.save_embedding_test_vectors()
|
||||||
|
return True
|
||||||
|
new_emb = self._get_embedding(s)
|
||||||
|
sim = cosine_similarity(local_emb, new_emb)
|
||||||
|
if sim < EMBEDDING_SIM_THRESHOLD:
|
||||||
|
logger.error("嵌入模型一致性校验失败")
|
||||||
|
return False
|
||||||
|
logger.info("嵌入模型一致性校验通过。")
|
||||||
|
return True
|
||||||
|
|
||||||
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||||
"""向库中存入字符串"""
|
"""向库中存入字符串"""
|
||||||
total = len(strs)
|
total = len(strs)
|
||||||
@@ -216,6 +287,17 @@ class EmbeddingManager:
|
|||||||
)
|
)
|
||||||
self.stored_pg_hashes = set()
|
self.stored_pg_hashes = set()
|
||||||
|
|
||||||
|
def check_all_embedding_model_consistency(self):
|
||||||
|
"""对所有嵌入库做模型一致性校验"""
|
||||||
|
for store in [
|
||||||
|
self.paragraphs_embedding_store,
|
||||||
|
self.entities_embedding_store,
|
||||||
|
self.relation_embedding_store,
|
||||||
|
]:
|
||||||
|
if not store.check_embedding_model_consistency():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||||
"""将段落编码存入Embedding库"""
|
"""将段落编码存入Embedding库"""
|
||||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||||
@@ -239,6 +321,8 @@ class EmbeddingManager:
|
|||||||
|
|
||||||
def load_from_file(self):
|
def load_from_file(self):
|
||||||
"""从文件加载"""
|
"""从文件加载"""
|
||||||
|
if not self.check_all_embedding_model_consistency():
|
||||||
|
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||||
self.paragraphs_embedding_store.load_from_file()
|
self.paragraphs_embedding_store.load_from_file()
|
||||||
self.entities_embedding_store.load_from_file()
|
self.entities_embedding_store.load_from_file()
|
||||||
self.relation_embedding_store.load_from_file()
|
self.relation_embedding_store.load_from_file()
|
||||||
@@ -250,6 +334,8 @@ class EmbeddingManager:
|
|||||||
raw_paragraphs: Dict[str, str],
|
raw_paragraphs: Dict[str, str],
|
||||||
triple_list_data: Dict[str, List[List[str]]],
|
triple_list_data: Dict[str, List[List[str]]],
|
||||||
):
|
):
|
||||||
|
if not self.check_all_embedding_model_consistency():
|
||||||
|
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
|
||||||
"""存储新的数据集"""
|
"""存储新的数据集"""
|
||||||
self._store_pg_into_embedding(raw_paragraphs)
|
self._store_pg_into_embedding(raw_paragraphs)
|
||||||
self._store_ent_into_embedding(triple_list_data)
|
self._store_ent_into_embedding(triple_list_data)
|
||||||
|
|||||||
Reference in New Issue
Block a user