@@ -1,6 +1,7 @@
from dataclasses import dataclass
import json
import os
import math
from typing import Dict , List , Tuple
import numpy as np
@@ -25,9 +26,39 @@ from rich.progress import (
)
install ( extra_lines = 3 )
ROOT_PATH = os . path . abspath ( os . path . join ( os . path . dirname ( __file__ ) , " .. " , " .. " , " .. " , " .. " ) )
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
# 嵌入模型测试字符串,测试模型一致性,来自开发群的聊天记录
# 这些字符串的嵌入结果应该是固定的,不能随时间变化
EMBEDDING_TEST_STRINGS = [
" 阿卡伊真的太好玩了,神秘性感大女同等着你 " ,
" 你怎么知道我arc12.64了 " ,
" 我是蕾缪乐小姐的狗 " ,
" 关注Oct谢谢喵 " ,
" 不是w6我不草 " ,
" 关注千石可乐谢谢喵 " ,
" 来玩CLANNAD, AIR, 樱之诗, 樱之刻谢谢喵 " ,
" 关注墨梓柒谢谢喵 " ,
" Ciallo~ " ,
" 来玩巧克甜恋谢谢喵 " ,
" 水印 " ,
" 我也在纠结晚饭,铁锅炒鸡听着就香! " ,
" test你妈喵 " ,
]
EMBEDDING_TEST_FILE = os . path . join ( ROOT_PATH , " data " , " embedding_model_test.json " )
EMBEDDING_SIM_THRESHOLD = 0.99
def cosine_similarity ( a , b ) :
# 计算余弦相似度
dot = sum ( x * y for x , y in zip ( a , b ) )
norm_a = math . sqrt ( sum ( x * x for x in a ) )
norm_b = math . sqrt ( sum ( x * x for x in b ) )
if norm_a == 0 or norm_b == 0 :
return 0.0
return dot / ( norm_a * norm_b )
@dataclass
class EmbeddingStoreItem :
@@ -64,6 +95,46 @@ class EmbeddingStore:
def _get_embedding ( self , s : str ) - > List [ float ] :
return self . llm_client . send_embedding_request ( global_config [ " embedding " ] [ " model " ] , s )
def get_test_file_path ( self ) :
return EMBEDDING_TEST_FILE
def save_embedding_test_vectors ( self ) :
""" 保存测试字符串的嵌入到本地 """
test_vectors = { }
for idx , s in enumerate ( EMBEDDING_TEST_STRINGS ) :
test_vectors [ str ( idx ) ] = self . _get_embedding ( s )
with open ( self . get_test_file_path ( ) , " w " , encoding = " utf-8 " ) as f :
json . dump ( test_vectors , f , ensure_ascii = False , indent = 2 )
def load_embedding_test_vectors ( self ) :
""" 加载本地保存的测试字符串嵌入 """
path = self . get_test_file_path ( )
if not os . path . exists ( path ) :
return None
with open ( path , " r " , encoding = " utf-8 " ) as f :
return json . load ( f )
def check_embedding_model_consistency ( self ) :
""" 校验当前模型与本地嵌入模型是否一致 """
local_vectors = self . load_embedding_test_vectors ( )
if local_vectors is None :
logger . warning ( " 未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。 " )
self . save_embedding_test_vectors ( )
return True
for idx , s in enumerate ( EMBEDDING_TEST_STRINGS ) :
local_emb = local_vectors . get ( str ( idx ) )
if local_emb is None :
logger . warning ( " 本地嵌入模型测试文件缺失部分测试字符串,将重新保存。 " )
self . save_embedding_test_vectors ( )
return True
new_emb = self . _get_embedding ( s )
sim = cosine_similarity ( local_emb , new_emb )
if sim < EMBEDDING_SIM_THRESHOLD :
logger . error ( " 嵌入模型一致性校验失败 " )
return False
logger . info ( " 嵌入模型一致性校验通过。 " )
return True
def batch_insert_strs ( self , strs : List [ str ] , times : int ) - > None :
""" 向库中存入字符串 """
total = len ( strs )
@@ -216,6 +287,17 @@ class EmbeddingManager:
)
self . stored_pg_hashes = set ( )
def check_all_embedding_model_consistency ( self ) :
""" 对所有嵌入库做模型一致性校验 """
for store in [
self . paragraphs_embedding_store ,
self . entities_embedding_store ,
self . relation_embedding_store ,
] :
if not store . check_embedding_model_consistency ( ) :
return False
return True
def _store_pg_into_embedding ( self , raw_paragraphs : Dict [ str , str ] ) :
""" 将段落编码存入Embedding库 """
self . paragraphs_embedding_store . batch_insert_strs ( list ( raw_paragraphs . values ( ) ) , times = 1 )
@@ -239,6 +321,8 @@ class EmbeddingManager:
def load_from_file ( self ) :
""" 从文件加载 """
if not self . check_all_embedding_model_consistency ( ) :
raise Exception ( " 嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。 " )
self . paragraphs_embedding_store . load_from_file ( )
self . entities_embedding_store . load_from_file ( )
self . relation_embedding_store . load_from_file ( )
@@ -250,6 +334,8 @@ class EmbeddingManager:
raw_paragraphs : Dict [ str , str ] ,
triple_list_data : Dict [ str , List [ List [ str ] ] ] ,
) :
if not self . check_all_embedding_model_consistency ( ) :
raise Exception ( " 嵌入模型与本地存储不一致,请检查模型设置或清空嵌入库后重试。 " )
""" 存储新的数据集 """
self . _store_pg_into_embedding ( raw_paragraphs )
self . _store_ent_into_embedding ( triple_list_data )