feat: Batch-generate text embeddings, optimize the interest-matching computation logic, and support batch updates of message interest values
@@ -1111,15 +1111,15 @@ class LLMRequest:
         return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)

-    async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
+    async def get_embedding(self, embedding_input: str | list[str]) -> tuple[list[float] | list[list[float]], str]:
         """
-        Get an embedding vector.
+        Get embedding vectors, with support for batch text input.

         Args:
-            embedding_input (str): the text to embed
+            embedding_input (str | list[str]): the text, or list of texts, to embed

         Returns:
-            (Tuple[List[float], str]): (the embedding vector, the name of the model used)
+            (Tuple[Union[List[float], List[List[float]]], str]): the embedding result(s) and the name of the model used
         """
         start_time = time.time()
         response, model_info = await self._strategy.execute_with_failover(
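With the widened signature above, one call path now covers both single-text and batch embedding. A minimal usage sketch, assuming an LLMRequest instance named llm; the variable names and surrounding setup are illustrative and not part of this diff:

    # Single text: one embedding vector plus the name of the model that produced it.
    vector, model_name = await llm.get_embedding("hello world")

    # Batch of texts: one vector per input (assuming the backend returns one embedding per text).
    vectors, model_name = await llm.get_embedding(["first message", "second message"])
    assert len(vectors) == 2 and isinstance(vectors[0], list)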
@@ -1128,10 +1128,25 @@ class LLMRequest:
         await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")

-        if not response.embedding:
+        if response.embedding is None:
             raise RuntimeError("Failed to get embedding")

-        return response.embedding, model_info.name
+        embeddings = response.embedding
+        is_batch_request = isinstance(embedding_input, list)
+
+        if is_batch_request:
+            if not isinstance(embeddings, list):
+                raise RuntimeError("Failed to get embedding: unexpected batch result format")
+
+            if embeddings and not isinstance(embeddings[0], list):
+                embeddings = [embeddings]  # type: ignore[list-item]
+
+            return embeddings, model_info.name
+
+        if isinstance(embeddings, list) and embeddings and isinstance(embeddings[0], list):
+            return embeddings[0], model_info.name
+
+        return embeddings, model_info.name

     async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
         """
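The interest-matching optimization and the batch update of message interest values mentioned in the commit message live outside this hunk. The sketch below only illustrates how the batch embedding call could feed such a computation; cosine_similarity, score_messages, and llm are hypothetical names, not code from this repository:

    import math

    def cosine_similarity(a: list[float], b: list[float]) -> float:
        # Plain-Python cosine similarity between two embedding vectors.
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        if norm_a == 0.0 or norm_b == 0.0:
            return 0.0
        return dot / (norm_a * norm_b)

    async def score_messages(llm, interest_text: str, messages: list[str]) -> list[float]:
        # One batch request for all messages instead of one request per message.
        message_vectors, _ = await llm.get_embedding(messages)
        interest_vector, _ = await llm.get_embedding(interest_text)
        # Each score could then be written back as that message's interest value.
        return [cosine_similarity(interest_vector, v) for v in message_vectors]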