fix(embedding): fully resolve embedding-generation failures caused by event-loop conflicts

This fixes the event-loop problems in embedding generation through the following changes:
- In EmbeddingStore._get_embedding, switch to a synchronous create-use-destroy pattern with a fresh event loop, eliminating nested-event-loop conflicts for good (see the sketch after the commit metadata below)
- Rework the batch path _get_embeddings_batch_threaded so that every worker thread drives its own independent, short-lived event loop
- Add a force_new parameter so that LLM requests for embedding tasks force a fresh client instance, reducing reuse of objects across loops
- Add detailed logging around the OpenAI client's embedding calls to make network connection failures easier to diagnose
- Rebuild the LLMRequest on every get_embedding() call, lowering the chance of one instance hopping between event loops

Although this change "bends" the async interface into a synchronous style, it breaks none of the existing interfaces, and it keeps the vector database and the knowledge-retrieval features built on it stable. (It also moves the scripts folder back to where it was.)
minecraft1024a
2025-08-19 20:41:00 +08:00
parent f3b5836eee
commit 3bef6f4bab
16 changed files with 4695 additions and 23 deletions
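
The create-use-destroy pattern from the first bullet, in isolation, looks roughly like the sketch below. This is a minimal illustration, not code from this repository: fetch_embedding is a hypothetical stand-in for the real async embedding call.

import asyncio
from typing import List

async def fetch_embedding(text: str) -> List[float]:
    # Hypothetical stand-in for the real async embedding request.
    await asyncio.sleep(0)
    return [0.0, 1.0]

def get_embedding_sync(text: str) -> List[float]:
    # Build a fresh loop, use it once, then close it. This works from
    # plain synchronous code; if the calling thread already runs a loop,
    # run_until_complete raises RuntimeError, which is why the commit
    # keeps these calls off any running loop (or in their own threads).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(fetch_embedding(text))
    finally:
        loop.close()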


@@ -117,30 +117,36 @@ class EmbeddingStore:
         self.idx2hash = None

     def _get_embedding(self, s: str) -> List[float]:
-        """Get the embedding vector for a string, handling the async call"""
+        """Get the embedding vector for a string, fully synchronously, to avoid event-loop problems"""
+        # Create a new event loop and close it as soon as we are done
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
         try:
-            # Try to get the currently running event loop
-            asyncio.get_running_loop()
-            # If we are inside an event loop, run in a thread pool
-            import concurrent.futures
+            # Create a new LLMRequest instance
+            from src.llm_models.utils_model import LLMRequest
+            from src.config.config import model_config

-            def run_in_thread():
-                return asyncio.run(get_embedding(s))
+            llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")

-            with concurrent.futures.ThreadPoolExecutor() as executor:
-                future = executor.submit(run_in_thread)
-                result = future.result()
-                if result is None:
-                    logger.error(f"Failed to get embedding: {s}")
-                    return []
-                return result
-        except RuntimeError:
-            # No running event loop, just run directly
-            result = asyncio.run(get_embedding(s))
-            if result is None:
+            # Run the async method on the new event loop
+            embedding, _ = loop.run_until_complete(llm.get_embedding(s))
+            if embedding and len(embedding) > 0:
+                return embedding
+            else:
                 logger.error(f"Failed to get embedding: {s}")
                 return []
-            return result
+        except Exception as e:
+            logger.error(f"Exception while getting embedding: {s}, error: {e}")
+            return []
+        finally:
+            # Make sure the event loop is properly closed
+            try:
+                loop.close()
+            except Exception:
+                pass

     def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
         """Batch-fetch embedding vectors using multiple threads
@@ -181,8 +187,14 @@ class EmbeddingStore:
             for i, s in enumerate(chunk_strs):
                 try:
-                    # Call the async function directly
-                    embedding = asyncio.run(llm.get_embedding(s))
+                    # Create an independent event loop inside this thread
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                    try:
+                        embedding = loop.run_until_complete(llm.get_embedding(s))
+                    finally:
+                        loop.close()
                     if embedding and len(embedding) > 0:
                         chunk_results.append((start_idx + i, s, embedding[0]))  # embedding[0] is the actual vector
                     else:
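
Driven from a thread pool, the per-thread pattern in the hunk above generalizes roughly as follows; a minimal sketch in which the illustrative embed_one coroutine stands in for llm.get_embedding:

import asyncio
from concurrent.futures import ThreadPoolExecutor

async def embed_one(text: str) -> list:
    await asyncio.sleep(0)  # stand-in for the real async embedding call
    return [float(len(text))]

def embed_in_thread(text: str) -> list:
    # Each worker thread builds, uses, and closes its own loop,
    # so no loop object is ever shared across threads.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        return loop.run_until_complete(embed_one(text))
    finally:
        loop.close()

with ThreadPoolExecutor(max_workers=4) as pool:
    vectors = list(pool.map(embed_in_thread, ["a", "bb", "ccc"]))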


@@ -111,6 +111,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:

 async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
     """Get the embedding vector for a text"""
+    # Create a new LLMRequest instance on every call to avoid event-loop conflicts
     llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
     try:
         embedding, _ = await llm.get_embedding(text)


@@ -159,14 +159,23 @@ class ClientRegistry:
         return decorator

-    def get_client_class_instance(self, api_provider: APIProvider) -> BaseClient:
+    def get_client_class_instance(self, api_provider: APIProvider, force_new=False) -> BaseClient:
         """
         Get the registered API client instance

         Args:
             api_provider: an APIProvider instance
+            force_new: whether to force creation of a new instance (used to work around event-loop problems)

         Returns:
             BaseClient: the registered API client instance
         """
+        # When a new instance is forced, create it directly and bypass the cache
+        if force_new:
+            if client_class := self.client_registry.get(api_provider.client_type):
+                return client_class(api_provider)
+            else:
+                raise KeyError(f"No client registered for type '{api_provider.client_type}'")
+
+        # Normal caching logic
         if api_provider.name not in self.client_instance_cache:
             if client_class := self.client_registry.get(api_provider.client_type):
                 self.client_instance_cache[api_provider.name] = client_class(api_provider)
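
Call sites opt in per request; a hypothetical usage sketch, where provider stands for a configured APIProvider:

# Default: one cached instance per provider name, shared by all callers.
client = client_registry.get_client_class_instance(provider)

# Forced: a fresh instance every time, never cached -- what embedding
# requests now ask for so client objects are not reused across loops.
client = client_registry.get_client_class_instance(provider, force_new=True)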


@@ -380,6 +380,7 @@ class OpenaiClient(BaseClient):
             base_url=api_provider.base_url,
             api_key=api_provider.api_key,
             max_retries=0,
+            timeout=api_provider.timeout,
         )

     async def get_response(
@@ -512,6 +513,11 @@ class OpenaiClient(BaseClient):
                 extra_body=extra_params,
             )
         except APIConnectionError as e:
+            # Log detailed error information for debugging
+            logger.error(f"OpenAI API connection error (embedding model): {str(e)}")
+            logger.error(f"Error type: {type(e)}")
+            if hasattr(e, '__cause__') and e.__cause__:
+                logger.error(f"Underlying error: {str(e.__cause__)}")
             raise NetworkConnectionError() from e
         except APIStatusError as e:
             # Re-wrap APIError as RespNotOkException
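
The __cause__ chain that the new logging walks is standard PEP 3134 exception chaining, set up by "raise ... from e"; a minimal illustration:

try:
    try:
        raise ConnectionError("socket closed")  # the underlying failure
    except ConnectionError as inner:
        raise RuntimeError("request failed") from inner
except RuntimeError as outer:
    assert outer.__cause__ is not None
    print(outer.__cause__)  # -> socket closed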


@@ -416,7 +416,10 @@ class LLMRequest:
             )
         model_info = model_config.get_model_info(least_used_model_name)
         api_provider = model_config.get_provider(model_info.api_provider)
-        client = client_registry.get_client_class_instance(api_provider)
+        # For embedding tasks, force a new client instance to avoid event-loop problems
+        force_new_client = (self.request_type == "embedding")
+        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
         logger.debug(f"Selected request model: {model_info.name}")
         total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
         self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1)  # Increase the usage penalty to avoid back-to-back use