feat: generate text embeddings in batches, optimize the interest-matching computation, and support batch updates of message interest values

Windpicker-owo
2025-11-19 16:30:44 +08:00
parent a11d251ec1
commit 14133410e6
15 changed files with 231 additions and 323 deletions
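
The interest-matching changes named in the commit title are not part of the hunk shown below. As background, here is a minimal sketch of how a batched `get_embedding` call can replace per-message requests in that kind of scoring loop, assuming cosine similarity and hypothetical names (`score_messages`, `message_texts`, `interest_vector`); the project's actual matching code may differ:

```python
import math

def cosine(a: list[float], b: list[float]) -> float:
    # Cosine similarity between two vectors; 0.0 if either vector is all zeros.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

async def score_messages(llm, message_texts: list[str], interest_vector: list[float]) -> list[float]:
    # One batched embedding request instead of one request per message.
    embeddings, _model_name = await llm.get_embedding(message_texts)
    return [cosine(vec, interest_vector) for vec in embeddings]
```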


@@ -1111,15 +1111,15 @@ class LLMRequest:
         return response.content or "", (response.reasoning_content or "", model_info.name, response.tool_calls)
 
-    async def get_embedding(self, embedding_input: str) -> tuple[list[float], str]:
+    async def get_embedding(self, embedding_input: str | list[str]) -> tuple[list[float] | list[list[float]], str]:
         """
-        Get the embedding vector.
+        Get embedding vectors; supports batched text input.
         Args:
-            embedding_input (str): The text to embed.
+            embedding_input (str | list[str]): The text, or list of texts, to embed.
         Returns:
-            (Tuple[List[float], str]): (embedding vector, name of the model used)
+            (Tuple[Union[List[float], List[List[float]]], str]): embedding result(s) and the name of the model used
         """
         start_time = time.time()
         response, model_info = await self._strategy.execute_with_failover(
@@ -1128,10 +1128,25 @@ class LLMRequest:
         await self._record_usage(model_info, response.usage, time.time() - start_time, "/embeddings")
 
-        if not response.embedding:
+        if response.embedding is None:
             raise RuntimeError("Failed to get embedding")
-        return response.embedding, model_info.name
+
+        embeddings = response.embedding
+        is_batch_request = isinstance(embedding_input, list)
+
+        if is_batch_request:
+            # Batch request: result must be a list; if the provider returned a single
+            # flat vector, wrap it so callers always receive list[list[float]].
+            if not isinstance(embeddings, list):
+                raise RuntimeError("Failed to get embedding: unexpected batch result format")
+            if embeddings and not isinstance(embeddings[0], list):
+                embeddings = [embeddings]  # type: ignore[list-item]
+            return embeddings, model_info.name
+
+        # Single-text request: unwrap a nested one-element result to a flat vector.
+        if isinstance(embeddings, list) and embeddings and isinstance(embeddings[0], list):
+            return embeddings[0], model_info.name
+        return embeddings, model_info.name
 
     async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
         """