feat: minor knowledge base refactor

墨梓柒
2025-06-07 13:46:04 +08:00
parent 1a713ed0d9
commit 0038d9ff60
17 changed files with 6 additions and 6 deletions


@@ -0,0 +1,376 @@
import json
import os
import math
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
import faiss
from .llm_client import LLMClient
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
from src.chat.knowledge.utils import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
EMBEDDING_DATA_DIR = (
os.path.join(ROOT_PATH, "data", "embedding")
if global_config["persistence"]["embedding_data_dir"] is None
else os.path.join(ROOT_PATH, global_config["persistence"]["embedding_data_dir"])
)
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
TOTAL_EMBEDDING_TIMES = 3  # total number of embedding passes (paragraphs, entities, relations)
# Test strings for checking embedding-model consistency, taken from the dev group's chat log.
# Their embeddings are expected to be fixed and must not drift over time.
EMBEDDING_TEST_STRINGS = [
"阿卡伊真的太好玩了,神秘性感大女同等着你",
"你怎么知道我arc12.64了",
"我是蕾缪乐小姐的狗",
"关注Oct谢谢喵",
"不是w6我不草",
"关注千石可乐谢谢喵",
"来玩CLANNADAIR樱之诗樱之刻谢谢喵",
"关注墨梓柒谢谢喵",
"Ciallo~",
"来玩巧克甜恋谢谢喵",
"水印",
"我也在纠结晚饭,铁锅炒鸡听着就香!",
"test你妈喵",
]
EMBEDDING_TEST_FILE = os.path.join(ROOT_PATH, "data", "embedding_model_test.json")
EMBEDDING_SIM_THRESHOLD = 0.99
def cosine_similarity(a, b):
    # Compute the cosine similarity of two vectors
dot = sum(x * y for x, y in zip(a, b))
norm_a = math.sqrt(sum(x * x for x in a))
norm_b = math.sqrt(sum(x * x for x in b))
if norm_a == 0 or norm_b == 0:
return 0.0
return dot / (norm_a * norm_b)
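# For example: cosine_similarity([1.0, 0.0], [1.0, 1.0]) = 1 / sqrt(2) ≈ 0.7071,
# and cosine_similarity(v, v) == 1.0 for any non-zero vector v (up to float rounding).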
class EmbeddingStoreItem:
"""嵌入库中的项"""
def __init__(self, item_hash: str, embedding: List[float], content: str):
self.hash = item_hash
self.embedding = embedding
self.str = content
def to_dict(self) -> dict:
"""转为dict"""
return {
"hash": self.hash,
"embedding": self.embedding,
"str": self.str,
}
class EmbeddingStore:
def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
self.namespace = namespace
self.llm_client = llm_client
self.dir = dir_path
self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
self.index_file_path = dir_path + "/" + namespace + ".index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
self.store = dict()
self.faiss_index = None
self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
def get_test_file_path(self):
return EMBEDDING_TEST_FILE
def save_embedding_test_vectors(self):
"""保存测试字符串的嵌入到本地"""
test_vectors = {}
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
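        # The resulting JSON maps each test-string index to its embedding, e.g.
        # {"0": [0.0123, -0.0456, ...], "1": [...], ...} (actual values depend on the model).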
def load_embedding_test_vectors(self):
"""加载本地保存的测试字符串嵌入"""
path = self.get_test_file_path()
if not os.path.exists(path):
return None
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
    def check_embedding_model_consistency(self):
        """Check that the current model is consistent with the locally stored embedding model."""
        local_vectors = self.load_embedding_test_vectors()
        if local_vectors is None:
            logger.warning("Local embedding-model test file not found; saving test embeddings for the current model.")
            self.save_embedding_test_vectors()
            return True
        for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
            local_emb = local_vectors.get(str(idx))
            if local_emb is None:
                logger.warning("Local embedding-model test file is missing some test strings; re-saving.")
                self.save_embedding_test_vectors()
                return True
            new_emb = self._get_embedding(s)
            sim = cosine_similarity(local_emb, new_emb)
            if sim < EMBEDDING_SIM_THRESHOLD:
                logger.error(f"Embedding model consistency check failed (test string {idx}: similarity {sim:.4f} < {EMBEDDING_SIM_THRESHOLD})")
                return False
        logger.info("Embedding model consistency check passed.")
        return True
def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""向库中存入字符串"""
total = len(strs)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
            task = progress.add_task(f"Inserting into embedding store ({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
            for s in strs:
                # Deduplicate by namespaced content hash
                item_hash = self.namespace + "-" + get_sha256(s)
                if item_hash in self.store:
                    progress.update(task, advance=1)
                    continue
                # Request the embedding from the model
                embedding = self._get_embedding(s)
                # Store the new item
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
progress.update(task, advance=1)
    def save_to_file(self) -> None:
        """Save the store (and the Faiss index, if built) to disk."""
        logger.info(f"Saving the {self.namespace} embedding store to {self.embedding_file_path}")
        data = [item.to_dict() for item in self.store.values()]
        data_frame = pd.DataFrame(data)
        os.makedirs(self.dir, exist_ok=True)
        data_frame.to_parquet(self.embedding_file_path, engine="pyarrow", index=False)
        logger.info(f"{self.namespace} embedding store saved")
        if self.faiss_index is not None and self.idx2hash is not None:
            logger.info(f"Saving the Faiss index of the {self.namespace} store to {self.index_file_path}")
            faiss.write_index(self.faiss_index, self.index_file_path)
            logger.info(f"Faiss index of the {self.namespace} store saved")
            logger.info(f"Saving the idx2hash mapping of the {self.namespace} store to {self.idx2hash_file_path}")
            with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
                json.dump(self.idx2hash, f, ensure_ascii=False, indent=4)
            logger.info(f"idx2hash mapping of the {self.namespace} store saved")
    def load_from_file(self) -> None:
        """Load the store from disk, rebuilding the Faiss index if needed."""
        if not os.path.exists(self.embedding_file_path):
            raise FileNotFoundError(f"File {self.embedding_file_path} does not exist")
        logger.info(f"Loading the {self.namespace} embedding store from {self.embedding_file_path}")
        data_frame = pd.read_parquet(self.embedding_file_path, engine="pyarrow")
        total = len(data_frame)
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            MofNCompleteColumn(),
            "",
            TimeElapsedColumn(),
            "<",
            TimeRemainingColumn(),
            transient=False,
        ) as progress:
            task = progress.add_task("Loading embedding store", total=total)
            for _, row in data_frame.iterrows():
                self.store[row["hash"]] = EmbeddingStoreItem(row["hash"], row["embedding"], row["str"])
                progress.update(task, advance=1)
        logger.info(f"{self.namespace} embedding store loaded")
        try:
            if os.path.exists(self.index_file_path):
                logger.info(f"Loading the Faiss index of the {self.namespace} store from {self.index_file_path}")
                self.faiss_index = faiss.read_index(self.index_file_path)
                logger.info(f"Faiss index of the {self.namespace} store loaded")
            else:
                raise FileNotFoundError(f"File {self.index_file_path} does not exist")
            if os.path.exists(self.idx2hash_file_path):
                logger.info(f"Loading the idx2hash mapping of the {self.namespace} store from {self.idx2hash_file_path}")
                with open(self.idx2hash_file_path, "r", encoding="utf-8") as f:
                    self.idx2hash = json.load(f)
                logger.info(f"idx2hash mapping of the {self.namespace} store loaded")
            else:
                raise FileNotFoundError(f"File {self.idx2hash_file_path} does not exist")
        except Exception as e:
            logger.error(f"Error while loading the Faiss index of the {self.namespace} store: {e}")
            logger.warning("Rebuilding the Faiss index")
            self.build_faiss_index()
            logger.info(f"Faiss index of the {self.namespace} store rebuilt")
            self.save_to_file()
    def build_faiss_index(self) -> None:
        """(Re)build the Faiss index, using cosine similarity as the metric."""
        # Collect all embeddings and record the index -> hash mapping
        array = []
        self.idx2hash = dict()
        for key in self.store:
            array.append(self.store[key].embedding)
            self.idx2hash[str(len(array) - 1)] = key
        embeddings = np.array(array, dtype=np.float32)
        # L2-normalize so that inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        # Build the inner-product index
        self.faiss_index = faiss.IndexFlatIP(global_config["embedding"]["dimension"])
        self.faiss_index.add(embeddings)
    def search_top_k(self, query: List[float], k: int) -> Optional[List[Tuple[str, float]]]:
        """Search for the k most similar items, using cosine similarity as the metric.

        Args:
            query: the query embedding
            k: number of most similar items to return

        Returns:
            result: a list of (hash, cosine similarity) tuples for the k most similar items
        """
        if self.faiss_index is None:
            logger.warning("The Faiss index has not been built; returning None")
            return None
        if self.idx2hash is None:
            logger.warning("idx2hash has not been built; returning None")
            return None
        # L2-normalize the query in place; normalize_L2 modifies its argument,
        # so the normalized float32 array must be the one passed to search
        query_arr = np.array([query], dtype=np.float32)
        faiss.normalize_L2(query_arr)
        # Search the index
        distances, indices = self.faiss_index.search(query_arr, k)
        # Collect results; Faiss pads missing results with index -1, so filter out-of-range indices
        result = [
            (self.idx2hash[str(int(idx))], float(sim))
            for (idx, sim) in zip(indices.flatten(), distances.flatten())
            if 0 <= idx < len(self.idx2hash)
        ]
        return result
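    # Example (hypothetical usage): given a query embedding `q` from the same model,
    #     hits = store.search_top_k(q, k=5)
    # returns up to 5 (hash, similarity) pairs, ordered from most to least similar.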
class EmbeddingManager:
def __init__(self, llm_client: LLMClient):
self.paragraphs_embedding_store = EmbeddingStore(
llm_client,
PG_NAMESPACE,
EMBEDDING_DATA_DIR_STR,
)
self.entities_embedding_store = EmbeddingStore(
llm_client,
ENT_NAMESPACE,
EMBEDDING_DATA_DIR_STR,
)
self.relation_embedding_store = EmbeddingStore(
llm_client,
REL_NAMESPACE,
EMBEDDING_DATA_DIR_STR,
)
self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self):
"""对所有嵌入库做模型一致性校验"""
for store in [
self.paragraphs_embedding_store,
self.entities_embedding_store,
self.relation_embedding_store,
]:
if not store.check_embedding_model_consistency():
return False
return True
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将实体编码存入Embedding库"""
entities = set()
for triple_list in triple_list_data.values():
for triple in triple_list:
entities.add(triple[0])
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将关系编码存入Embedding库"""
graph_triples = [] # a list of unique relation triple (in tuple) from all chunks
for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples))
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
    def load_from_file(self):
        """Load all embedding stores from disk."""
        if not self.check_all_embedding_model_consistency():
            raise Exception("Embedding model is inconsistent with the local store; check the model settings or clear the embedding store and retry.")
self.paragraphs_embedding_store.load_from_file()
self.entities_embedding_store.load_from_file()
self.relation_embedding_store.load_from_file()
        # Collect the hashes already stored in the paragraph store
self.stored_pg_hashes = set(self.paragraphs_embedding_store.store.keys())
    def store_new_data_set(
        self,
        raw_paragraphs: Dict[str, str],
        triple_list_data: Dict[str, List[List[str]]],
    ):
        """Store a new data set."""
        if not self.check_all_embedding_model_consistency():
            raise Exception("Embedding model is inconsistent with the local store; check the model settings or clear the embedding store and retry.")
        self._store_pg_into_embedding(raw_paragraphs)
        self._store_ent_into_embedding(triple_list_data)
        self._store_rel_into_embedding(triple_list_data)
        self.stored_pg_hashes.update(raw_paragraphs.keys())
def save_to_file(self):
"""保存到文件"""
self.paragraphs_embedding_store.save_to_file()
self.entities_embedding_store.save_to_file()
self.relation_embedding_store.save_to_file()
def rebuild_faiss_index(self):
"""重建Faiss索引请在添加新数据后调用"""
self.paragraphs_embedding_store.build_faiss_index()
self.entities_embedding_store.build_faiss_index()
self.relation_embedding_store.build_faiss_index()
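# A minimal usage sketch (illustrative only; assumes an LLMClient instance and the
# embedding settings in lpmmconfig are configured; constructor arguments are hypothetical):
#
#     client = LLMClient(...)
#     manager = EmbeddingManager(client)
#     manager.store_new_data_set(raw_paragraphs, triple_list_data)
#     manager.rebuild_faiss_index()  # required before search_top_k sees the new data
#     manager.save_to_file()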