Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
@@ -3,9 +3,7 @@ import datetime
 import os
 import shutil
 import sys
-from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
-from threading import Lock
 
 import aiofiles
 import orjson
@@ -38,7 +36,26 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 RAW_DATA_PATH = os.path.join(ROOT_PATH, "data", "lpmm_raw_data")
 OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
 TEMP_DIR = os.path.join(ROOT_PATH, "temp", "lpmm_cache")
-file_lock = Lock()
+
+# ========== 性能配置参数 ==========
+#
+# 知识提取(步骤2:txt转json)并发控制
+# - 控制同时进行的LLM提取请求数量
+# - 推荐值: 3-10,取决于API速率限制
+# - 过高可能触发429错误(速率限制)
+MAX_EXTRACTION_CONCURRENCY = 5
+
+# 数据导入(步骤3:生成embedding)性能配置
+# - max_workers: 并发批次数(每批次并行处理)
+# - chunk_size: 每批次包含的字符串数
+# - 理论并发 = max_workers × chunk_size
+# - 推荐配置:
+#   * 高性能API(OpenAI): max_workers=20-30, chunk_size=30-50
+#   * 中等API: max_workers=10-15, chunk_size=20-30
+#   * 本地/慢速API: max_workers=5-10, chunk_size=10-20
+EMBEDDING_MAX_WORKERS = 20  # 并发批次数
+EMBEDDING_CHUNK_SIZE = 30  # 每批次字符串数
+# ===================================
 
 
 # --- 缓存清理 ---
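A quick way to sanity-check these knobs is to multiply them out, since the comment block above defines theoretical concurrency as max_workers × chunk_size. A minimal standalone sketch (the configuration names below are illustrative, not taken from the repo):

```python
# Illustrative sketch: reproduces the "理论并发 = max_workers × chunk_size"
# arithmetic from the comment block above for a few candidate settings.
CANDIDATE_CONFIGS = {
    "high-throughput API": {"max_workers": 25, "chunk_size": 40},
    "mid-tier API":        {"max_workers": 12, "chunk_size": 25},
    "local / slow API":    {"max_workers": 8,  "chunk_size": 15},
}

for name, cfg in CANDIDATE_CONFIGS.items():
    theoretical = cfg["max_workers"] * cfg["chunk_size"]  # strings in flight per wave
    print(f"{name}: {cfg['max_workers']} x {cfg['chunk_size']} = ~{theoretical} concurrent strings")
```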
@@ -155,26 +172,41 @@ def get_extraction_prompt(paragraph: str) -> str:
 
 
 async def extract_info_async(pg_hash, paragraph, llm_api):
+    """
+    异步提取单个段落的信息(带缓存支持)
+
+    Args:
+        pg_hash: 段落哈希值
+        paragraph: 段落文本
+        llm_api: LLM请求实例
+
+    Returns:
+        tuple: (doc_item或None, failed_hash或None)
+    """
     temp_file_path = os.path.join(TEMP_DIR, f"{pg_hash}.json")
-    with file_lock:
+    # 🔧 优化:使用异步文件检查,避免阻塞
     if os.path.exists(temp_file_path):
         try:
             async with aiofiles.open(temp_file_path, "rb") as f:
                 content = await f.read()
             return orjson.loads(content), None
         except orjson.JSONDecodeError:
+            # 缓存文件损坏,删除并重新生成
+            try:
                 os.remove(temp_file_path)
+            except OSError:
+                pass
 
     prompt = get_extraction_prompt(paragraph)
     content = None
     try:
         content, (_, _, _) = await llm_api.generate_response_async(prompt)
 
-        # 改进点:调用封装好的函数处理JSON解析和修复
+        # 调用封装好的函数处理JSON解析和修复
         extracted_data = _parse_and_repair_json(content)
 
         if extracted_data is None:
-            # 如果解析失败,抛出异常以触发统一的错误处理逻辑
             raise ValueError("无法从LLM输出中解析有效的JSON数据")
 
         doc_item = {
@@ -183,9 +215,11 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
             "extracted_entities": extracted_data.get("entities", []),
             "extracted_triples": extracted_data.get("triples", []),
         }
-        with file_lock:
+        # 保存到缓存(异步写入)
         async with aiofiles.open(temp_file_path, "wb") as f:
             await f.write(orjson.dumps(doc_item))
 
         return doc_item, None
     except Exception as e:
         logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
@@ -194,23 +228,50 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
         return None, pg_hash
 
 
-def extract_info_sync(pg_hash, paragraph, model_set):
-    llm_api = LLMRequest(model_set=model_set)
-    return asyncio.run(extract_info_async(pg_hash, paragraph, llm_api))
-
-
-def extract_information(paragraphs_dict, model_set):
+async def extract_information(paragraphs_dict, model_set):
+    """
+    🔧 优化:使用真正的异步并发代替多线程
+
+    这样可以:
+    1. 避免 event loop closed 错误
+    2. 更高效地利用 I/O 资源
+    3. 与我们优化的 LLM 请求层无缝集成
+
+    并发控制:
+    - 使用信号量限制最大并发数为 5,防止触发 API 速率限制
+
+    Args:
+        paragraphs_dict: {hash: paragraph} 字典
+        model_set: 模型配置
+    """
     logger.info("--- 步骤 2: 开始信息提取 ---")
     os.makedirs(OPENIE_OUTPUT_DIR, exist_ok=True)
     os.makedirs(TEMP_DIR, exist_ok=True)
 
     failed_hashes, open_ie_docs = [], []
 
-    with ThreadPoolExecutor(max_workers=3) as executor:
-        f_to_hash = {
-            executor.submit(extract_info_sync, p_hash, p, model_set): p_hash
-            for p_hash, p in paragraphs_dict.items()
-        }
+    # 🔧 关键修复:创建单个 LLM 请求实例,复用连接
+    llm_api = LLMRequest(model_set=model_set, request_type="lpmm_extraction")
+
+    # 🔧 并发控制:限制最大并发数,防止速率限制
+    semaphore = asyncio.Semaphore(MAX_EXTRACTION_CONCURRENCY)
+
+    async def extract_with_semaphore(pg_hash, paragraph):
+        """带信号量控制的提取函数"""
+        async with semaphore:
+            return await extract_info_async(pg_hash, paragraph, llm_api)
+
+    # 创建所有异步任务(带并发控制)
+    tasks = [
+        extract_with_semaphore(p_hash, paragraph)
+        for p_hash, paragraph in paragraphs_dict.items()
+    ]
+
+    total = len(tasks)
+    completed = 0
+
+    logger.info(f"开始提取 {total} 个段落的信息(最大并发: {MAX_EXTRACTION_CONCURRENCY})")
+
     with Progress(
         SpinnerColumn(),
         TextColumn("[progress.description]{task.description}"),
@@ -222,13 +283,18 @@ def extract_information(paragraphs_dict, model_set):
         "<",
         TimeRemainingColumn(),
     ) as progress:
-        task = progress.add_task("[cyan]正在提取信息...", total=len(paragraphs_dict))
-        for future in as_completed(f_to_hash):
-            doc_item, failed_hash = future.result()
+        task = progress.add_task("[cyan]正在提取信息...", total=total)
+        # 🔧 优化:使用 asyncio.as_completed 并发执行所有任务,按完成顺序处理结果
+        # extract_info_async 内部已捕获异常,单个失败不会影响其他任务
+        for coro in asyncio.as_completed(tasks):
+            doc_item, failed_hash = await coro
             if failed_hash:
                 failed_hashes.append(failed_hash)
             elif doc_item:
                 open_ie_docs.append(doc_item)
+
+            completed += 1
             progress.update(task, advance=1)
 
     if open_ie_docs:
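The concurrency pattern introduced in the hunks above — one shared `asyncio.Semaphore` wrapped around each task, results drained with `asyncio.as_completed` — can be exercised in isolation. A minimal sketch, with `fake_llm_call` standing in for the real LLM request (hypothetical names, not from the repo):

```python
import asyncio
import random

MAX_CONCURRENCY = 5  # mirrors MAX_EXTRACTION_CONCURRENCY above

async def fake_llm_call(item: str) -> str:
    """Placeholder for llm_api.generate_response_async(...)."""
    await asyncio.sleep(random.uniform(0.05, 0.2))
    return f"extracted:{item}"

async def run_all(items: list[str]) -> list[str]:
    semaphore = asyncio.Semaphore(MAX_CONCURRENCY)

    async def bounded(item: str) -> str:
        async with semaphore:  # at most MAX_CONCURRENCY calls in flight
            return await fake_llm_call(item)

    results = []
    # Consume results as they finish, like the progress loop in the hunk above.
    for coro in asyncio.as_completed([bounded(i) for i in items]):
        results.append(await coro)
    return results

if __name__ == "__main__":
    print(asyncio.run(run_all([f"p{i}" for i in range(12)])))
```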
@@ -244,6 +310,7 @@ def extract_information(paragraphs_dict, model_set):
         with open(output_path, "wb") as f:
             f.write(orjson.dumps(openie_obj._to_dict()))
         logger.info(f"信息提取结果已保存到: {output_path}")
+        logger.info(f"成功提取 {len(open_ie_docs)} 个段落的信息")
 
     if failed_hashes:
         logger.error(f"以下 {len(failed_hashes)} 个段落提取失败: {failed_hashes}")
@@ -263,7 +330,10 @@ async def import_data(openie_obj: OpenIE | None = None):
             默认为 None.
     """
     logger.info("--- 步骤 3: 开始数据导入 ---")
-    embed_manager, kg_manager = EmbeddingManager(), KGManager()
+    # 使用配置的并发参数以加速 embedding 生成
+    # max_workers: 并发批次数,chunk_size: 每批次处理的字符串数
+    embed_manager = EmbeddingManager(max_workers=EMBEDDING_MAX_WORKERS, chunk_size=EMBEDDING_CHUNK_SIZE)
+    kg_manager = KGManager()
 
     logger.info("正在加载现有的 Embedding 库...")
     try:
@@ -340,6 +410,23 @@ def import_from_specific_file():
 # --- 主函数 ---
 
 
+def rebuild_faiss_only():
+    """仅重建 FAISS 索引,不重新导入数据"""
+    logger.info("--- 重建 FAISS 索引 ---")
+    # 重建索引不需要并发参数(不涉及 embedding 生成)
+    embed_manager = EmbeddingManager()
+
+    logger.info("正在加载现有的 Embedding 库...")
+    try:
+        embed_manager.load_from_file()
+        logger.info("开始重建 FAISS 索引...")
+        embed_manager.rebuild_faiss_index()
+        embed_manager.save_to_file()
+        logger.info("✅ FAISS 索引重建完成!")
+    except Exception as e:
+        logger.error(f"重建 FAISS 索引时发生错误: {e}", exc_info=True)
+
+
 def main():
     # 使用 os.path.relpath 创建相对于项目根目录的友好路径
     raw_data_relpath = os.path.relpath(RAW_DATA_PATH, os.path.join(ROOT_PATH, ".."))
@@ -352,27 +439,32 @@ def main():
         print("4. [全流程] -> 按顺序执行 1 -> 2 -> 3")
         print("5. [指定导入] -> 从特定的 openie.json 文件导入知识")
         print("6. [清理缓存] -> 删除所有已提取信息的缓存")
+        print("7. [重建索引] -> 仅重建 FAISS 索引(数据已导入时使用)")
         print("0. [退出]")
         print("-" * 30)
-        choice = input("请输入你的选择 (0-5): ").strip()
+        choice = input("请输入你的选择 (0-7): ").strip()
 
         if choice == "1":
             preprocess_raw_data()
        elif choice == "2":
            paragraphs = preprocess_raw_data()
            if paragraphs:
-                extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
+                # 🔧 修复:使用 asyncio.run 调用异步函数
+                asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa))
        elif choice == "3":
            asyncio.run(import_data())
        elif choice == "4":
            paragraphs = preprocess_raw_data()
            if paragraphs:
-                extract_information(paragraphs, model_config.model_task_config.lpmm_qa)
+                # 🔧 修复:使用 asyncio.run 调用异步函数
+                asyncio.run(extract_information(paragraphs, model_config.model_task_config.lpmm_qa))
            asyncio.run(import_data())
        elif choice == "5":
            import_from_specific_file()
        elif choice == "6":
            clear_cache()
+        elif choice == "7":
+            rebuild_faiss_only()
        elif choice == "0":
            sys.exit(0)
        else:
@@ -30,12 +30,12 @@ from .utils.hash import get_sha256
 install(extra_lines=3)
 
 # 多线程embedding配置常量
-DEFAULT_MAX_WORKERS = 1  # 默认最大线程数
-DEFAULT_CHUNK_SIZE = 5  # 默认每个线程处理的数据块大小
+DEFAULT_MAX_WORKERS = 10  # 默认最大并发批次数(提升并发能力)
+DEFAULT_CHUNK_SIZE = 20  # 默认每个批次处理的数据块大小(批量请求)
 MIN_CHUNK_SIZE = 1  # 最小分块大小
-MAX_CHUNK_SIZE = 50  # 最大分块大小
+MAX_CHUNK_SIZE = 100  # 最大分块大小(提升批量能力)
 MIN_WORKERS = 1  # 最小线程数
-MAX_WORKERS = 20  # 最大线程数
+MAX_WORKERS = 50  # 最大线程数(提升并发上限)
 
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
@@ -145,7 +145,12 @@ class EmbeddingStore:
     ) -> list[tuple[str, list[float]]]:
         """
         异步、并发地批量获取嵌入向量。
-        使用asyncio.Semaphore来控制并发数,确保所有操作在同一个事件循环中。
+        使用 chunk_size 进行分批,max_workers 控制并发批次数。
+
+        优化策略:
+        1. 将字符串分成多个 chunk,每个 chunk 包含 chunk_size 个字符串
+        2. 使用 asyncio.Semaphore 控制同时处理的 chunk 数量
+        3. 每个 chunk 内的字符串在同一个任务中依次请求 embedding(批次之间并发)
         """
         if not strs:
             return []
@@ -153,18 +158,36 @@ class EmbeddingStore:
         from src.config.config import model_config
         from src.llm_models.utils_model import LLMRequest
 
+        # 限制 chunk_size 和 max_workers 在合理范围内
+        chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
+        max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
+
         semaphore = asyncio.Semaphore(max_workers)
         llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
         results = {}
 
-        async def _get_embedding_with_semaphore(s: str):
-            async with semaphore:
-                embedding = await EmbeddingStore._get_embedding_async(llm, s)
-                results[s] = embedding
-                if progress_callback:
-                    progress_callback(1)
-
-        tasks = [_get_embedding_with_semaphore(s) for s in strs]
+        # 将字符串列表分成多个 chunk
+        chunks = []
+        for i in range(0, len(strs), chunk_size):
+            chunks.append(strs[i : i + chunk_size])
+
+        async def _process_chunk(chunk: list[str]):
+            """处理一个 chunk 的字符串(批量获取 embedding)"""
+            async with semaphore:
+                # 依次获取 chunk 内每个字符串的 embedding(批次之间并发执行)
+                embeddings = []
+                for s in chunk:
+                    embedding = await EmbeddingStore._get_embedding_async(llm, s)
+                    embeddings.append(embedding)
+                    results[s] = embedding
+
+                if progress_callback:
+                    progress_callback(len(chunk))
+
+                return embeddings
+
+        # 并发处理所有 chunks
+        tasks = [_process_chunk(chunk) for chunk in chunks]
         await asyncio.gather(*tasks)
 
         # 按照原始顺序返回结果
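The chunking strategy described in the docstring above — split the input into `chunk_size` groups, cap the number of concurrently processed groups with a semaphore, then restore the original order — reduces to a few lines when written as a self-contained sketch (`fake_embed` is a stand-in, not the repo's API):

```python
import asyncio

async def fake_embed(text: str) -> list[float]:
    """Stand-in for EmbeddingStore._get_embedding_async(llm, text)."""
    await asyncio.sleep(0.01)
    return [float(len(text))]

async def embed_all(strs: list[str], max_workers: int = 10, chunk_size: int = 20):
    semaphore = asyncio.Semaphore(max_workers)
    results: dict[str, list[float]] = {}
    chunks = [strs[i : i + chunk_size] for i in range(0, len(strs), chunk_size)]

    async def process_chunk(chunk: list[str]) -> None:
        async with semaphore:  # limit the number of chunks processed at once
            for s in chunk:
                results[s] = await fake_embed(s)

    await asyncio.gather(*(process_chunk(c) for c in chunks))
    # Return pairs in the original input order.
    return [(s, results[s]) for s in strs]

if __name__ == "__main__":
    print(asyncio.run(embed_all([f"str-{i}" for i in range(7)], chunk_size=3)))
```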
@@ -392,15 +415,56 @@ class EmbeddingStore:
             self.faiss_index = faiss.IndexFlatIP(embedding_dim)
             return
 
+        # 🔧 修复:检查所有 embedding 的维度是否一致
+        dimensions = [len(emb) for emb in array]
+        unique_dims = set(dimensions)
+
+        if len(unique_dims) > 1:
+            logger.error(f"检测到不一致的 embedding 维度: {unique_dims}")
+            logger.error(f"维度分布: {dict(zip(*np.unique(dimensions, return_counts=True)))}")
+
+            # 获取期望的维度(使用最常见的维度)
+            from collections import Counter
+            dim_counter = Counter(dimensions)
+            expected_dim = dim_counter.most_common(1)[0][0]
+            logger.warning(f"将使用最常见的维度: {expected_dim}")
+
+            # 过滤掉维度不匹配的 embedding
+            filtered_array = []
+            filtered_idx2hash = {}
+            skipped_count = 0
+
+            for i, emb in enumerate(array):
+                if len(emb) == expected_dim:
+                    filtered_array.append(emb)
+                    filtered_idx2hash[str(len(filtered_array) - 1)] = self.idx2hash[str(i)]
+                else:
+                    skipped_count += 1
+                    hash_key = self.idx2hash[str(i)]
+                    logger.warning(f"跳过维度不匹配的 embedding: {hash_key}, 维度={len(emb)}, 期望={expected_dim}")
+
+            logger.warning(f"已过滤 {skipped_count} 个维度不匹配的 embedding")
+            array = filtered_array
+            self.idx2hash = filtered_idx2hash
+
+            if not array:
+                logger.error("过滤后没有可用的 embedding,无法构建索引")
+                embedding_dim = expected_dim
+                self.faiss_index = faiss.IndexFlatIP(embedding_dim)
+                return
+
         embeddings = np.array(array, dtype=np.float32)
         # L2归一化
         faiss.normalize_L2(embeddings)
         # 构建索引
         embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
         if not embedding_dim:
-            embedding_dim = global_config.lpmm_knowledge.embedding_dimension
+            # 🔧 修复:使用实际检测到的维度
+            embedding_dim = embeddings.shape[1]
+            logger.info(f"使用实际检测到的 embedding 维度: {embedding_dim}")
         self.faiss_index = faiss.IndexFlatIP(embedding_dim)
         self.faiss_index.add(embeddings)
+        logger.info(f"✅ 成功构建 Faiss 索引: {len(embeddings)} 个向量, 维度={embedding_dim}")
 
     def search_top_k(self, query: list[float], k: int) -> list[tuple[str, float]]:
         """搜索最相似的k个项,以余弦相似度为度量
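The dimension guard added in the last hunk reduces to: find the majority embedding dimension, drop outliers, and remap indices before handing the matrix to Faiss. A compact sketch of that filtering step on plain lists (hypothetical data, independent of the `EmbeddingStore` class):

```python
from collections import Counter

import numpy as np

def filter_to_majority_dim(vectors: list[list[float]]) -> tuple[np.ndarray, list[int]]:
    """Keep only vectors whose length matches the most common dimension."""
    dims = [len(v) for v in vectors]
    expected_dim = Counter(dims).most_common(1)[0][0]

    kept_indices = [i for i, v in enumerate(vectors) if len(v) == expected_dim]
    kept = np.array([vectors[i] for i in kept_indices], dtype=np.float32)
    return kept, kept_indices

if __name__ == "__main__":
    vecs = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6, 0.7]]  # one 3-d outlier
    kept, idx = filter_to_majority_dim(vecs)
    print(kept.shape, idx)  # (2, 2) [0, 1]
```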