From 423525ead594b65a8430716ca85fdb49e21f8226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A2=A8=E6=A2=93=E6=9F=92?= <1787882683@qq.com> Date: Sat, 2 Aug 2025 23:52:41 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=A4=9A=E7=BA=BF=E7=A8=8B?= =?UTF-8?q?=E5=A4=84=E7=90=86=EF=BC=8C=E8=B0=83=E6=95=B4=E5=B5=8C=E5=85=A5?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E5=92=8C=E5=AD=98=E5=82=A8=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=BC=BA=E6=A8=A1=E5=9E=8B=E4=B8=80=E8=87=B4?= =?UTF-8?q?=E6=80=A7=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- scripts/info_extraction.py | 6 +- src/chat/knowledge/embedding_store.py | 271 ++++++++++++++++++++++---- src/chat/knowledge/knowledge_lib.py | 54 +---- 3 files changed, 238 insertions(+), 93 deletions(-) diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index c36a77892..cb545a44d 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -25,7 +25,7 @@ from rich.progress import ( TextColumn, ) from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from dotenv import load_dotenv @@ -96,11 +96,11 @@ open_ie_doc_lock = Lock() shutdown_event = Event() lpmm_entity_extract_llm = LLMRequest( - model=global_config.model.lpmm_entity_extract, + model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" ) lpmm_rdf_build_llm = LLMRequest( - model=global_config.model.lpmm_rdf_build, + model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build" ) def process_single_text(pg_hash, raw_data): diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index d732683ae..447ef8e7e 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -3,6 +3,7 @@ import json import os import math import asyncio +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Tuple import numpy as np @@ -26,12 +27,20 @@ from rich.progress import ( SpinnerColumn, TextColumn, ) -from src.manager.local_store_manager import local_storage from src.chat.utils.utils import get_embedding from src.config.config import global_config install(extra_lines=3) + +# 多线程embedding配置常量 +DEFAULT_MAX_WORKERS = 10 # 默认最大线程数 +DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小 +MIN_CHUNK_SIZE = 1 # 最小分块大小 +MAX_CHUNK_SIZE = 50 # 最大分块大小 +MIN_WORKERS = 1 # 最小线程数 +MAX_WORKERS = 20 # 最大线程数 + ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") @@ -87,13 +96,23 @@ class EmbeddingStoreItem: class EmbeddingStore: - def __init__(self, namespace: str, dir_path: str): + def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): self.namespace = namespace self.dir = dir_path self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.index_file_path = f"{dir_path}/{namespace}.index" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" + # 多线程配置参数验证和设置 + self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers)) + self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size)) + + # 如果配置值被调整,记录日志 + if self.max_workers != max_workers: + logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})") + if self.chunk_size != chunk_size: + logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})") + self.store = {} self.faiss_index = None @@ -125,16 +144,134 @@ class EmbeddingStore: return [] return result + def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]: + """使用多线程批量获取嵌入向量 + + Args: + strs: 要获取嵌入的字符串列表 + chunk_size: 每个线程处理的数据块大小 + max_workers: 最大线程数 + progress_callback: 进度回调函数,接收一个参数表示完成的数量 + + Returns: + 包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致 + """ + if not strs: + return [] + + # 分块 + chunks = [] + for i in range(0, len(strs), chunk_size): + chunk = strs[i:i + chunk_size] + chunks.append((i, chunk)) # 保存起始索引以维持顺序 + + # 结果存储,使用字典按索引存储以保证顺序 + results = {} + + def process_chunk(chunk_data): + """处理单个数据块的函数""" + start_idx, chunk_strs = chunk_data + chunk_results = [] + + # 为每个线程创建独立的LLMRequest实例 + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config + + try: + # 创建线程专用的LLM实例 + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + + for i, s in enumerate(chunk_strs): + try: + # 直接使用异步函数 + embedding = asyncio.run(llm.get_embedding(s)) + if embedding and len(embedding) > 0: + chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 + else: + logger.error(f"获取嵌入失败: {s}") + chunk_results.append((start_idx + i, s, [])) + + # 每完成一个嵌入立即更新进度 + if progress_callback: + progress_callback(1) + + except Exception as e: + logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") + chunk_results.append((start_idx + i, s, [])) + + # 即使失败也要更新进度 + if progress_callback: + progress_callback(1) + + except Exception as e: + logger.error(f"创建LLM实例失败: {e}") + # 如果创建LLM实例失败,返回空结果 + for i, s in enumerate(chunk_strs): + chunk_results.append((start_idx + i, s, [])) + # 即使失败也要更新进度 + if progress_callback: + progress_callback(1) + + return chunk_results + + # 使用线程池处理 + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交所有任务 + future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} + + # 收集结果(进度已在process_chunk中实时更新) + for future in as_completed(future_to_chunk): + try: + chunk_results = future.result() + for idx, s, embedding in chunk_results: + results[idx] = (s, embedding) + except Exception as e: + chunk = future_to_chunk[future] + logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}") + # 为失败的块添加空结果 + start_idx, chunk_strs = chunk + for i, s in enumerate(chunk_strs): + results[start_idx + i] = (s, []) + + # 按原始顺序返回结果 + ordered_results = [] + for i in range(len(strs)): + if i in results: + ordered_results.append(results[i]) + else: + # 防止遗漏 + ordered_results.append((strs[i], [])) + + return ordered_results + def get_test_file_path(self): return EMBEDDING_TEST_FILE def save_embedding_test_vectors(self): - """保存测试字符串的嵌入到本地""" + """保存测试字符串的嵌入到本地(使用多线程优化)""" + logger.info("开始保存测试字符串的嵌入向量...") + + # 使用多线程批量获取测试字符串的嵌入 + embedding_results = self._get_embeddings_batch_threaded( + EMBEDDING_TEST_STRINGS, + chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), + max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) + ) + + # 构建测试向量字典 test_vectors = {} - for idx, s in enumerate(EMBEDDING_TEST_STRINGS): - test_vectors[str(idx)] = self._get_embedding(s) + for idx, (s, embedding) in enumerate(embedding_results): + if embedding: + test_vectors[str(idx)] = embedding + else: + logger.error(f"获取测试字符串嵌入失败: {s}") + # 使用原始单线程方法作为后备 + test_vectors[str(idx)] = self._get_embedding(s) + with open(self.get_test_file_path(), "w", encoding="utf-8") as f: json.dump(test_vectors, f, ensure_ascii=False, indent=2) + + logger.info("测试字符串嵌入向量保存完成") def load_embedding_test_vectors(self): """加载本地保存的测试字符串嵌入""" @@ -145,29 +282,64 @@ class EmbeddingStore: return json.load(f) def check_embedding_model_consistency(self): - """校验当前模型与本地嵌入模型是否一致""" + """校验当前模型与本地嵌入模型是否一致(使用多线程优化)""" local_vectors = self.load_embedding_test_vectors() if local_vectors is None: logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") self.save_embedding_test_vectors() return True - for idx, s in enumerate(EMBEDDING_TEST_STRINGS): - local_emb = local_vectors.get(str(idx)) - if local_emb is None: + + # 检查本地向量完整性 + for idx in range(len(EMBEDDING_TEST_STRINGS)): + if local_vectors.get(str(idx)) is None: logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") self.save_embedding_test_vectors() return True - new_emb = self._get_embedding(s) + + logger.info("开始检验嵌入模型一致性...") + + # 使用多线程批量获取当前模型的嵌入 + embedding_results = self._get_embeddings_batch_threaded( + EMBEDDING_TEST_STRINGS, + chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), + max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) + ) + + # 检查一致性 + for idx, (s, new_emb) in enumerate(embedding_results): + local_emb = local_vectors.get(str(idx)) + if not new_emb: + logger.error(f"获取测试字符串嵌入失败: {s}") + return False + sim = cosine_similarity(local_emb, new_emb) if sim < EMBEDDING_SIM_THRESHOLD: - logger.error("嵌入模型一致性校验失败") + logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}") return False + logger.info("嵌入模型一致性校验通过。") return True def batch_insert_strs(self, strs: List[str], times: int) -> None: - """向库中存入字符串""" + """向库中存入字符串(使用多线程优化)""" + if not strs: + return + total = len(strs) + + # 过滤已存在的字符串 + new_strs = [] + for s in strs: + item_hash = self.namespace + "-" + get_sha256(s) + if item_hash not in self.store: + new_strs.append(s) + + if not new_strs: + logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理") + return + + logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串") + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -181,19 +353,38 @@ class EmbeddingStore: transient=False, ) as progress: task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) - for s in strs: - # 计算hash去重 - item_hash = self.namespace + "-" + get_sha256(s) - if item_hash in self.store: - progress.update(task, advance=1) - continue - - # 获取embedding - embedding = self._get_embedding(s) - - # 存入 - self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) - progress.update(task, advance=1) + + # 首先更新已存在项的进度 + already_processed = total - len(new_strs) + if already_processed > 0: + progress.update(task, advance=already_processed) + + if new_strs: + # 使用实例配置的参数,智能调整分块和线程数 + optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size)) + optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1)) + + logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}") + + # 定义进度更新回调函数 + def update_progress(count): + progress.update(task, advance=count) + + # 批量获取嵌入,并实时更新进度 + embedding_results = self._get_embeddings_batch_threaded( + new_strs, + chunk_size=optimal_chunk_size, + max_workers=optimal_max_workers, + progress_callback=update_progress + ) + + # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了) + for s, embedding in embedding_results: + item_hash = self.namespace + "-" + get_sha256(s) + if embedding: # 只有成功获取到嵌入才存入 + self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) + else: + logger.warning(f"跳过存储失败的嵌入: {s[:50]}...") def save_to_file(self) -> None: """保存到文件""" @@ -316,31 +507,37 @@ class EmbeddingStore: class EmbeddingManager: - def __init__(self): + def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): + """ + 初始化EmbeddingManager + + Args: + max_workers: 最大线程数 + chunk_size: 每个线程处理的数据块大小 + """ self.paragraphs_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "paragraph", # type: ignore EMBEDDING_DATA_DIR_STR, + max_workers=max_workers, + chunk_size=chunk_size, ) self.entities_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "entity", # type: ignore EMBEDDING_DATA_DIR_STR, + max_workers=max_workers, + chunk_size=chunk_size, ) self.relation_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "relation", # type: ignore EMBEDDING_DATA_DIR_STR, + max_workers=max_workers, + chunk_size=chunk_size, ) self.stored_pg_hashes = set() def check_all_embedding_model_consistency(self): """对所有嵌入库做模型一致性校验""" - for store in [ - self.paragraphs_embedding_store, - self.entities_embedding_store, - self.relation_embedding_store, - ]: - if not store.check_embedding_model_consistency(): - return False - return True + return self.paragraphs_embedding_store.check_embedding_model_consistency() def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): """将段落编码存入Embedding库""" diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 1e87d3824..31cc20c1d 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -6,7 +6,6 @@ from src.chat.knowledge.qa_manager import QAManager from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.global_logger import logger from src.config.config import global_config as bot_global_config -from src.manager.local_store_manager import local_storage import os INVALID_ENTITY = [ @@ -21,9 +20,6 @@ INVALID_ENTITY = [ "她们", "它们", ] -PG_NAMESPACE = "paragraph" -ENT_NAMESPACE = "entity" -REL_NAMESPACE = "relation" RAG_GRAPH_NAMESPACE = "rag-graph" RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" @@ -34,54 +30,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", DATA_PATH = os.path.join(ROOT_PATH, "data") -def _initialize_knowledge_local_storage(): - """ - 初始化知识库相关的本地存储配置 - 使用字典批量设置,避免重复的if判断 - """ - # 定义所有需要初始化的配置项 - default_configs = { - # 路径配置 - "root_path": ROOT_PATH, - "data_path": f"{ROOT_PATH}/data", - # 实体和命名空间配置 - "lpmm_invalid_entity": INVALID_ENTITY, - "pg_namespace": PG_NAMESPACE, - "ent_namespace": ENT_NAMESPACE, - "rel_namespace": REL_NAMESPACE, - # RAG相关命名空间配置 - "rag_graph_namespace": RAG_GRAPH_NAMESPACE, - "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE, - "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE, - } - - # 日志级别映射:重要配置用info,其他用debug - important_configs = {"root_path", "data_path"} - - # 批量设置配置项 - initialized_count = 0 - for key, default_value in default_configs.items(): - if local_storage[key] is None: - local_storage[key] = default_value - - # 根据重要性选择日志级别 - if key in important_configs: - logger.info(f"设置{key}: {default_value}") - else: - logger.debug(f"设置{key}: {default_value}") - - initialized_count += 1 - - if initialized_count > 0: - logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") - else: - logger.debug("知识库本地存储配置已存在,跳过初始化") - - -# 初始化本地存储路径 -# sourcery skip: dict-comprehension -_initialize_knowledge_local_storage() - qa_manager = None inspire_manager = None @@ -120,7 +68,7 @@ if bot_global_config.lpmm_knowledge.enable: # 数据比对:Embedding库与KG的段落hash集合 for pg_hash in kg_manager.stored_paragraph_hashes: - key = f"{PG_NAMESPACE}-{pg_hash}" + key = f"paragraph-{pg_hash}" if key not in embed_manager.stored_pg_hashes: logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")