Fix code formatting and filename casing issues
@@ -32,11 +32,11 @@ install(extra_lines=3)
 
 # Multithreaded embedding configuration constants
 DEFAULT_MAX_WORKERS = 10  # default maximum number of worker threads
-DEFAULT_CHUNK_SIZE = 10 # default number of items each thread processes per chunk
-MIN_CHUNK_SIZE = 1 # minimum chunk size
-MAX_CHUNK_SIZE = 50 # maximum chunk size
-MIN_WORKERS = 1 # minimum number of worker threads
-MAX_WORKERS = 20 # maximum number of worker threads
+DEFAULT_CHUNK_SIZE = 10  # default number of items each thread processes per chunk
+MIN_CHUNK_SIZE = 1  # minimum chunk size
+MAX_CHUNK_SIZE = 50  # maximum chunk size
+MIN_WORKERS = 1  # minimum number of worker threads
+MAX_WORKERS = 20  # maximum number of worker threads
 
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
@@ -93,7 +93,13 @@ class EmbeddingStoreItem:
 
 
 class EmbeddingStore:
-    def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
+    def __init__(
+        self,
+        namespace: str,
+        dir_path: str,
+        max_workers: int = DEFAULT_MAX_WORKERS,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+    ):
         self.namespace = namespace
         self.dir = dir_path
         self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
@@ -103,12 +109,16 @@ class EmbeddingStore:
         # Validate and apply the multithreading configuration
         self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
         self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
 
         # Log a warning if a configured value had to be adjusted
         if self.max_workers != max_workers:
-            logger.warning(f"max_workers adjusted from {max_workers} to {self.max_workers} (range: {MIN_WORKERS}-{MAX_WORKERS})")
+            logger.warning(
+                f"max_workers adjusted from {max_workers} to {self.max_workers} (range: {MIN_WORKERS}-{MAX_WORKERS})"
+            )
         if self.chunk_size != chunk_size:
-            logger.warning(f"chunk_size adjusted from {chunk_size} to {self.chunk_size} (range: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
+            logger.warning(
+                f"chunk_size adjusted from {chunk_size} to {self.chunk_size} (range: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})"
+            )
 
         self.store = {}
 
@@ -144,45 +154,48 @@ class EmbeddingStore:
             # Make sure the event loop gets closed properly
             try:
                 loop.close()
-            except Exception: ...
+            except Exception:
+                ...
 
-    def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
+    def _get_embeddings_batch_threaded(
+        self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None
+    ) -> List[Tuple[str, List[float]]]:
         """Fetch embeddings for a batch of strings using multiple threads
 
         Args:
             strs: the strings to embed
             chunk_size: number of items each thread processes per chunk
             max_workers: maximum number of worker threads
             progress_callback: progress callback; receives the number of newly completed items
 
         Returns:
             A list of (original string, embedding) tuples, in the same order as the input
         """
         if not strs:
             return []
 
         # Split the input into chunks
         chunks = []
         for i in range(0, len(strs), chunk_size):
-            chunk = strs[i:i + chunk_size]
+            chunk = strs[i : i + chunk_size]
             chunks.append((i, chunk))  # keep the start index to preserve ordering
 
         # Store results keyed by index so the original order can be restored
         results = {}
 
         def process_chunk(chunk_data):
             """Process a single chunk of strings"""
             start_idx, chunk_strs = chunk_data
             chunk_results = []
 
             # Create an independent LLMRequest instance for each thread
             from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
 
             try:
                 # Create a thread-local LLM instance
                 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
 
                 for i, s in enumerate(chunk_strs):
                     try:
                         # Create an independent event loop inside this thread
@@ -198,19 +211,19 @@ class EmbeddingStore:
                         else:
                             logger.error(f"Failed to get the embedding for: {s}")
                             chunk_results.append((start_idx + i, s, []))
 
                         # Advance the progress as soon as each embedding finishes
                         if progress_callback:
                             progress_callback(1)
 
                     except Exception as e:
                         logger.error(f"Exception while getting the embedding for: {s}, error: {e}")
                         chunk_results.append((start_idx + i, s, []))
 
                         # Advance the progress even on failure
                         if progress_callback:
                             progress_callback(1)
 
             except Exception as e:
                 logger.error(f"Failed to create the LLM instance: {e}")
                 # If the LLM instance cannot be created, return empty results
@@ -219,14 +232,14 @@ class EmbeddingStore:
                     # Advance the progress even on failure
                     if progress_callback:
                         progress_callback(1)
 
             return chunk_results
 
         # Process the chunks with a thread pool
         with ThreadPoolExecutor(max_workers=max_workers) as executor:
             # Submit all tasks
             future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
 
             # Collect the results (progress is already advanced inside process_chunk)
             for future in as_completed(future_to_chunk):
                 try:
@@ -240,7 +253,7 @@ class EmbeddingStore:
                     start_idx, chunk_strs = chunk
                     for i, s in enumerate(chunk_strs):
                         results[start_idx + i] = (s, [])
 
         # Return the results in the original input order
         ordered_results = []
         for i in range(len(strs)):
@@ -249,7 +262,7 @@ class EmbeddingStore:
             else:
                 # Guard against anything that slipped through
                 ordered_results.append((strs[i], []))
 
         return ordered_results
 
     def get_test_file_path(self):
@@ -258,14 +271,14 @@ class EmbeddingStore:
     def save_embedding_test_vectors(self):
         """Save embeddings of the test strings locally (multithreaded)"""
         logger.info("Start saving embeddings of the test strings...")
 
         # Fetch the test-string embeddings with multiple threads
         embedding_results = self._get_embeddings_batch_threaded(
             EMBEDDING_TEST_STRINGS,
             chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
-            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
+            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
         )
 
         # Build the test-vector dictionary
         test_vectors = {}
         for idx, (s, embedding) in enumerate(embedding_results):
@@ -275,12 +288,9 @@ class EmbeddingStore:
                 logger.error(f"Failed to get the embedding for test string: {s}")
                 # Fall back to the original single-threaded method
                 test_vectors[str(idx)] = self._get_embedding(s)
 
 
         with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
-            f.write(orjson.dumps(
-                test_vectors,
-                option=orjson.OPT_INDENT_2
-            ).decode('utf-8'))
+            f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
 
         logger.info("Finished saving the test-string embeddings")
@@ -299,35 +309,35 @@ class EmbeddingStore:
             logger.warning("Local embedding-model test file not found; saving test embeddings for the current model.")
             self.save_embedding_test_vectors()
             return True
 
         # Check the integrity of the local vectors
         for idx in range(len(EMBEDDING_TEST_STRINGS)):
             if local_vectors.get(str(idx)) is None:
                 logger.warning("The local embedding-model test file is missing some test strings; saving again.")
                 self.save_embedding_test_vectors()
                 return True
 
         logger.info("Start checking embedding model consistency...")
 
         # Fetch the current model's embeddings with multiple threads
         embedding_results = self._get_embeddings_batch_threaded(
             EMBEDDING_TEST_STRINGS,
             chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
-            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
+            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)),
         )
 
         # Check consistency
         for idx, (s, new_emb) in enumerate(embedding_results):
             local_emb = local_vectors.get(str(idx))
             if not new_emb:
                 logger.error(f"Failed to get the embedding for test string: {s}")
                 return False
 
             sim = cosine_similarity(local_emb, new_emb)
             if sim < EMBEDDING_SIM_THRESHOLD:
                 logger.error(f"Embedding model consistency check failed, string: {s}, similarity: {sim:.4f}")
                 return False
 
         logger.info("Embedding model consistency check passed.")
         return True
 
@@ -335,22 +345,22 @@ class EmbeddingStore:
         """Store strings into the embedding store (multithreaded)"""
         if not strs:
             return
 
         total = len(strs)
 
         # Filter out strings that already exist
         new_strs = []
         for s in strs:
             item_hash = self.namespace + "-" + get_sha256(s)
             if item_hash not in self.store:
                 new_strs.append(s)
 
         if not new_strs:
             logger.info(f"All strings already exist in the {self.namespace} embedding store; skipping")
             return
 
         logger.info(f"Need to process {len(new_strs)}/{total} new strings")
 
         with Progress(
             SpinnerColumn(),
             TextColumn("[progress.description]{task.description}"),
@@ -364,31 +374,39 @@ class EmbeddingStore:
             transient=False,
         ) as progress:
             task = progress.add_task(f"Storing into embedding store: ({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
 
             # First advance the progress for items that already exist
             already_processed = total - len(new_strs)
             if already_processed > 0:
                 progress.update(task, advance=already_processed)
 
             if new_strs:
                 # Use the instance configuration, adapting chunk size and thread count to the workload
-                optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
-                optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
+                optimal_chunk_size = max(
+                    MIN_CHUNK_SIZE,
+                    min(
+                        self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size
+                    ),
+                )
+                optimal_max_workers = min(
+                    self.max_workers,
+                    max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1),
+                )
 
                 logger.debug(f"Using multithreaded processing: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
 
                 # Progress-update callback
                 def update_progress(count):
                     progress.update(task, advance=count)
 
                 # Fetch the embeddings in batches, updating the progress in real time
                 embedding_results = self._get_embeddings_batch_threaded(
-                    new_strs,
-                    chunk_size=optimal_chunk_size,
+                    new_strs,
+                    chunk_size=optimal_chunk_size,
                     max_workers=optimal_max_workers,
-                    progress_callback=update_progress
+                    progress_callback=update_progress,
                 )
 
                 # Store the results (progress is already updated in the callback)
                 for s, embedding in embedding_results:
                     item_hash = self.namespace + "-" + get_sha256(s)
@@ -419,9 +437,7 @@ class EmbeddingStore:
         logger.info(f"FaissIndex of the {self.namespace} embedding store saved successfully")
         logger.info(f"Saving the idx2hash mapping of the {self.namespace} embedding store to {self.idx2hash_file_path}")
         with open(self.idx2hash_file_path, "w", encoding="utf-8") as f:
-            f.write(orjson.dumps(
-                self.idx2hash, option=orjson.OPT_INDENT_2
-            ).decode('utf-8'))
+            f.write(orjson.dumps(self.idx2hash, option=orjson.OPT_INDENT_2).decode("utf-8"))
         logger.info(f"idx2hash mapping of the {self.namespace} embedding store saved successfully")
 
     def load_from_file(self) -> None:
@@ -523,7 +539,7 @@ class EmbeddingManager:
     def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
         """
         Initialize the EmbeddingManager
 
         Args:
             max_workers: maximum number of worker threads
             chunk_size: number of items each thread processes per chunk
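
For reference, the threading pattern these hunks reformat reduces to a small, self-contained sketch: clamp the configuration into the MIN/MAX bounds, split the input into indexed chunks, embed each chunk in a worker thread, report progress per item, and reassemble the results in input order. This is a minimal illustration of the technique under stated assumptions, not the project's actual API: embed_one, clamp, and the demo at the bottom are hypothetical placeholders.

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Dict, List, Optional, Tuple

MIN_WORKERS, MAX_WORKERS = 1, 20  # illustrative bounds, mirroring the diff's constants
MIN_CHUNK_SIZE, MAX_CHUNK_SIZE = 1, 50


def clamp(value: int, low: int, high: int) -> int:
    # The same clamping the constructor applies to max_workers / chunk_size.
    return max(low, min(high, value))


def embed_batch_threaded(
    strs: List[str],
    embed_one: Callable[[str], List[float]],  # placeholder for the real embedding call
    chunk_size: int = 10,
    max_workers: int = 10,
    progress_callback: Optional[Callable[[int], None]] = None,
) -> List[Tuple[str, List[float]]]:
    """Embed strings in parallel chunks, preserving the input order."""
    if not strs:
        return []

    chunk_size = clamp(chunk_size, MIN_CHUNK_SIZE, MAX_CHUNK_SIZE)
    max_workers = clamp(max_workers, MIN_WORKERS, MAX_WORKERS)

    # (start_index, chunk) pairs; the index lets us restore ordering later.
    chunks = [(i, strs[i : i + chunk_size]) for i in range(0, len(strs), chunk_size)]
    results: Dict[int, Tuple[str, List[float]]] = {}

    def process_chunk(start_idx: int, chunk_strs: List[str]):
        out = []
        for offset, s in enumerate(chunk_strs):
            try:
                out.append((start_idx + offset, s, embed_one(s)))
            except Exception:
                out.append((start_idx + offset, s, []))  # an empty vector marks a failure
            if progress_callback:
                progress_callback(1)  # advance per string, as the diff's callback does
        return out

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_chunk, i, c) for i, c in chunks]
        for future in as_completed(futures):
            for idx, s, emb in future.result():
                results[idx] = (s, emb)

    # Reassemble in input order, backfilling anything a crashed chunk dropped.
    return [results.get(i, (strs[i], [])) for i in range(len(strs))]


if __name__ == "__main__":
    # Toy embedder: a length-1 "vector" built from the string length.
    vectors = embed_batch_threaded(["a", "bb", "ccc"], lambda s: [float(len(s))])
    print(vectors)  # [('a', [1.0]), ('bb', [2.0]), ('ccc', [3.0])]

Keying results by each chunk's start index is what lets as_completed() consume futures in whatever order they finish while the caller still receives input-order output.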