style: 统一代码风格并采用现代化类型注解
对整个代码库进行了一次全面的代码风格清理和现代化改造,主要包括: - 移除了所有文件中多余的行尾空格。 - 将类型提示更新为 PEP 585 和 PEP 604 引入的现代语法(例如,使用 `list` 代替 `List`,使用 `|` 代替 `Optional`)。 - 清理了多个模块中未被使用的导入语句。 - 移除了不含插值变量的冗余 f-string。 - 调整了部分 `__init__.py` 文件中的 `__all__` 导出顺序,以保持一致性。 这些改动旨在提升代码的可读性和可维护性,使其与现代 Python 最佳实践保持一致,但未修改任何核心逻辑。
This commit is contained in:
@@ -161,16 +161,16 @@ class EmbeddingStore:
|
||||
# 限制 chunk_size 和 max_workers 在合理范围内
|
||||
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
|
||||
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
|
||||
|
||||
|
||||
semaphore = asyncio.Semaphore(max_workers)
|
||||
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
|
||||
results = {}
|
||||
|
||||
|
||||
# 将字符串列表分成多个 chunk
|
||||
chunks = []
|
||||
for i in range(0, len(strs), chunk_size):
|
||||
chunks.append(strs[i : i + chunk_size])
|
||||
|
||||
|
||||
async def _process_chunk(chunk: list[str]):
|
||||
"""处理一个 chunk 的字符串(批量获取 embedding)"""
|
||||
async with semaphore:
|
||||
@@ -180,12 +180,12 @@ class EmbeddingStore:
|
||||
embedding = await EmbeddingStore._get_embedding_async(llm, s)
|
||||
embeddings.append(embedding)
|
||||
results[s] = embedding
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(len(chunk))
|
||||
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# 并发处理所有 chunks
|
||||
tasks = [_process_chunk(chunk) for chunk in chunks]
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -418,22 +418,22 @@ class EmbeddingStore:
|
||||
# 🔧 修复:检查所有 embedding 的维度是否一致
|
||||
dimensions = [len(emb) for emb in array]
|
||||
unique_dims = set(dimensions)
|
||||
|
||||
|
||||
if len(unique_dims) > 1:
|
||||
logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
|
||||
logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
|
||||
|
||||
|
||||
# 获取期望的维度(使用最常见的维度)
|
||||
from collections import Counter
|
||||
dim_counter = Counter(dimensions)
|
||||
expected_dim = dim_counter.most_common(1)[0][0]
|
||||
logger.warning(f"将使用最常见的维度: {expected_dim}")
|
||||
|
||||
|
||||
# 过滤掉维度不匹配的 embedding
|
||||
filtered_array = []
|
||||
filtered_idx2hash = {}
|
||||
skipped_count = 0
|
||||
|
||||
|
||||
for i, emb in enumerate(array):
|
||||
if len(emb) == expected_dim:
|
||||
filtered_array.append(emb)
|
||||
@@ -442,11 +442,11 @@ class EmbeddingStore:
|
||||
skipped_count += 1
|
||||
hash_key = self.idx2hash[str(i)]
|
||||
logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
|
||||
|
||||
|
||||
logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
|
||||
array = filtered_array
|
||||
self.idx2hash = filtered_idx2hash
|
||||
|
||||
|
||||
if not array:
|
||||
logger.error("过滤后没有可用的 embedding,无法构建索引")
|
||||
embedding_dim = expected_dim
|
||||
|
||||
Reference in New Issue
Block a user