feat: 添加嵌入模型一致性校验功能,优化错误处理

This commit is contained in:
墨梓柒
2025-05-04 00:32:10 +08:00
parent fe9a2315a5
commit b8d14add91
2 changed files with 94 additions and 1 deletions

View File

@@ -30,6 +30,7 @@ OPENIE_DIR = (
logger = get_module_logger("OpenIE导入") logger = get_module_logger("OpenIE导入")
def hash_deduplicate( def hash_deduplicate(
raw_paragraphs: dict[str, str], raw_paragraphs: dict[str, str],
triple_list_data: dict[str, list[list[str]]], triple_list_data: dict[str, list[list[str]]],
@@ -167,6 +168,7 @@ def main():
global_config["llm_providers"][key]["api_key"], global_config["llm_providers"][key]["api_key"],
) )
# 初始化Embedding库 # 初始化Embedding库
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]]) embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
logger.info("正在从文件加载Embedding库") logger.info("正在从文件加载Embedding库")
@@ -174,6 +176,11 @@ def main():
embed_manager.load_from_file() embed_manager.load_from_file()
except Exception as e: except Exception as e:
logger.error("从文件加载Embedding库时发生错误{}".format(e)) logger.error("从文件加载Embedding库时发生错误{}".format(e))
if "嵌入模型与本地存储不一致" in str(e):
logger.error("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
logger.error("请保证你的嵌入模型从未更改,并且在导入时使用相同的模型")
# print("检测到嵌入模型与本地存储不一致,已终止导入。请检查模型设置或清空嵌入库后重试。")
sys.exit(1)
logger.error("如果你是第一次导入知识,请忽略此错误") logger.error("如果你是第一次导入知识,请忽略此错误")
logger.info("Embedding库加载完成") logger.info("Embedding库加载完成")
# 初始化KG # 初始化KG

View File

@@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
import json import json
import os import os
import math
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
@@ -25,9 +26,39 @@ from rich.progress import (
) )
install(extra_lines=3) install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数 TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
# Test strings used to check embedding-model consistency; they were taken from
# the developers' group chat logs.
# The embeddings of these strings are expected to be fixed and must not change
# over time, so do not edit, reorder or remove entries: the saved reference
# vectors are keyed by each string's index in this list.
EMBEDDING_TEST_STRINGS = [
"阿卡伊真的太好玩了,神秘性感大女同等着你",
"你怎么知道我arc12.64了",
"我是蕾缪乐小姐的狗",
"关注Oct谢谢喵",
"不是w6我不草",
"关注千石可乐谢谢喵",
"来玩CLANNADAIR樱之诗樱之刻谢谢喵",
"关注墨梓柒谢谢喵",
"Ciallo~",
"来玩巧克甜恋谢谢喵",
"水印",
"我也在纠结晚饭,铁锅炒鸡听着就香!",
"test你妈喵"
]
# JSON file (under ROOT_PATH/data) where the reference embeddings of the test
# strings are saved and later read back for comparison.
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
# The consistency check fails as soon as the cosine similarity between a saved
# vector and a freshly computed one drops below this value.
EMBEDDING_SIM_THRESHOLD = 0.99
def cosine_similarity(a, b):
    """Return the cosine similarity of two numeric vectors.

    Args:
        a: Iterable of numbers.
        b: Iterable of numbers.

    Returns:
        ``dot(a, b) / (|a| * |b|)`` as a float, or ``0.0`` when the vectors
        have different lengths or when either of them has a zero norm.
    """
    # Materialise once so that one-shot iterables are not exhausted by the
    # first pass and so that len() is always available.
    a = list(a)
    b = list(b)
    # zip() silently truncates to the shorter input while the norms would still
    # cover every component, producing a meaningless score. Vectors of
    # different dimensionality cannot be compared, so report "no similarity".
    if len(a) != len(b):
        return 0.0
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    # Avoid dividing by zero for all-zero (or empty) vectors.
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)
@dataclass @dataclass
class EmbeddingStoreItem: class EmbeddingStoreItem:
@@ -64,6 +95,46 @@ class EmbeddingStore:
def _get_embedding(self, s: str) -> List[float]: def _get_embedding(self, s: str) -> List[float]:
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s) return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
def get_test_file_path(self):
    """Return the path of the JSON file used for the embedding test vectors."""
    test_file_path = EMBEDDING_TEST_FILE
    return test_file_path
def save_embedding_test_vectors(self):
    """Embed every test string with the current model and save the result.

    The vectors are written as JSON to ``self.get_test_file_path()``; each
    key is the string form of the entry's index in
    ``EMBEDDING_TEST_STRINGS`` and each value is whatever
    ``self._get_embedding`` returned for that entry.
    """
    test_vectors = {
        str(idx): self._get_embedding(s)
        for idx, s in enumerate(EMBEDDING_TEST_STRINGS)
    }
    path = self.get_test_file_path()
    # open(..., "w") raises FileNotFoundError when the parent directory is
    # missing, so make sure it exists before writing.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(test_vectors, f, ensure_ascii=False, indent=2)
def load_embedding_test_vectors(self):
    """Read the saved test-string embeddings from disk.

    Returns:
        The parsed JSON content of ``self.get_test_file_path()``, or ``None``
        when that path does not exist.
    """
    vectors_file = self.get_test_file_path()
    if os.path.exists(vectors_file):
        with open(vectors_file, "r", encoding="utf-8") as handle:
            stored_vectors = json.load(handle)
        return stored_vectors
    return None
def check_embedding_model_consistency(self):
    """Check that the current embedding model matches the saved test vectors.

    Each test string is embedded again and compared, by cosine similarity,
    with the vector stored on disk under the same index.

    Returns:
        ``True`` when every comparison reaches ``EMBEDDING_SIM_THRESHOLD``,
        and also when the saved file is missing or lacks an entry (the test
        vectors are then regenerated and saved). ``False`` as soon as one
        similarity falls below the threshold.
    """
    stored = self.load_embedding_test_vectors()

    # Nothing saved yet: record the current model's vectors as the baseline.
    if stored is None:
        logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
        self.save_embedding_test_vectors()
        return True

    for position, text in enumerate(EMBEDDING_TEST_STRINGS):
        reference = stored.get(str(position))

        # The saved file has no entry for this index: rebuild it.
        if reference is None:
            logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
            self.save_embedding_test_vectors()
            return True

        fresh = self._get_embedding(text)
        similarity = cosine_similarity(reference, fresh)
        if similarity < EMBEDDING_SIM_THRESHOLD:
            logger.error("嵌入模型一致性校验失败")
            return False

    logger.info("嵌入模型一致性校验通过。")
    return True
def batch_insert_strs(self, strs: List[str], times: int) -> None: def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""向库中存入字符串""" """向库中存入字符串"""
total = len(strs) total = len(strs)
@@ -216,6 +287,17 @@ class EmbeddingManager:
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self):
    """Run the embedding-model consistency check on every embedding store.

    The paragraph, entity and relation stores are checked in that order and
    checking stops at the first store that reports a failure.

    Returns:
        ``True`` if every store passes, ``False`` otherwise.
    """
    stores = (
        self.paragraphs_embedding_store,
        self.entities_embedding_store,
        self.relation_embedding_store,
    )
    # all() short-circuits, so later stores are not checked after a failure.
    return all(store.check_embedding_model_consistency() for store in stores)
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库""" """将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1) self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
@@ -239,6 +321,8 @@ class EmbeddingManager:
def load_from_file(self): def load_from_file(self):
"""从文件加载""" """从文件加载"""
if not self.check_all_embedding_model_consistency():
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
self.paragraphs_embedding_store.load_from_file() self.paragraphs_embedding_store.load_from_file()
self.entities_embedding_store.load_from_file() self.entities_embedding_store.load_from_file()
self.relation_embedding_store.load_from_file() self.relation_embedding_store.load_from_file()
@@ -250,6 +334,8 @@ class EmbeddingManager:
raw_paragraphs: Dict[str, str], raw_paragraphs: Dict[str, str],
triple_list_data: Dict[str, List[List[str]]], triple_list_data: Dict[str, List[List[str]]],
): ):
if not self.check_all_embedding_model_consistency():
raise Exception("嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。")
"""存储新的数据集""" """存储新的数据集"""
self._store_pg_into_embedding(raw_paragraphs) self._store_pg_into_embedding(raw_paragraphs)
self._store_ent_into_embedding(triple_list_data) self._store_ent_into_embedding(triple_list_data)