fix:修复LPMM学习问题
This commit is contained in:
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -99,7 +100,30 @@ class EmbeddingStore:
|
||||
self.idx2hash = None
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
return get_embedding(s)
|
||||
"""获取字符串的嵌入向量,处理异步调用"""
|
||||
try:
|
||||
# 尝试获取当前事件循环
|
||||
asyncio.get_running_loop()
|
||||
# 如果在事件循环中,使用线程池执行
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
return asyncio.run(get_embedding(s))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
return result
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,直接运行
|
||||
result = asyncio.run(get_embedding(s))
|
||||
if result is None:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
return result
|
||||
|
||||
def get_test_file_path(self):
|
||||
return EMBEDDING_TEST_FILE
|
||||
|
||||
Reference in New Issue
Block a user