some typing
This commit is contained in:
@@ -106,10 +106,10 @@ class EmbeddingStore:
|
||||
asyncio.get_running_loop()
|
||||
# 如果在事件循环中,使用线程池执行
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def run_in_thread():
|
||||
return asyncio.run(get_embedding(s))
|
||||
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
result = future.result()
|
||||
@@ -294,10 +294,10 @@ class EmbeddingStore:
|
||||
"""
|
||||
if self.faiss_index is None:
|
||||
logger.debug("FaissIndex尚未构建,返回None")
|
||||
return None
|
||||
return []
|
||||
if self.idx2hash is None:
|
||||
logger.warning("idx2hash尚未构建,返回None")
|
||||
return None
|
||||
return []
|
||||
|
||||
# L2归一化
|
||||
faiss.normalize_L2(np.array([query], dtype=np.float32))
|
||||
@@ -318,15 +318,15 @@ class EmbeddingStore:
|
||||
class EmbeddingManager:
|
||||
def __init__(self):
|
||||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
local_storage['pg_namespace'],
|
||||
local_storage["pg_namespace"], # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.entities_embedding_store = EmbeddingStore(
|
||||
local_storage['pg_namespace'],
|
||||
local_storage["pg_namespace"], # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.relation_embedding_store = EmbeddingStore(
|
||||
local_storage['pg_namespace'],
|
||||
local_storage["pg_namespace"], # type: ignore
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.stored_pg_hashes = set()
|
||||
|
||||
Reference in New Issue
Block a user