feat(embedding): optimize embedding handling, support NumPy array format, and reduce memory allocations
@@ -4,6 +4,7 @@
"""

import traceback
+from collections import OrderedDict
from datetime import datetime
from typing import Any, cast

@@ -18,15 +19,20 @@ from src.utils.json_parser import extract_and_parse_json

logger = get_logger("bot_interest_manager")

+# 🔧 Memory optimization settings
+MAX_EMBEDDING_CACHE_SIZE = 500  # maximum number of entries in the embedding cache (LRU eviction)
+MAX_EXPANDED_TAG_CACHE_SIZE = 200  # maximum number of entries in the expanded-tag cache


class BotInterestManager:
    """Manages the bot's interest tags"""

    def __init__(self):
        self.current_interests: BotPersonalityInterests | None = None
-        self.embedding_cache: dict[str, list[float]] = {}  # embedding cache
-        self.expanded_tag_cache: dict[str, str] = {}  # expanded-tag cache
-        self.expanded_embedding_cache: dict[str, list[float]] = {}  # embedding cache for expanded tags
+        # 🔧 Use OrderedDict as an LRU cache so the caches cannot grow without bound
+        self.embedding_cache: OrderedDict[str, np.ndarray] = OrderedDict()  # embedding cache (NumPy format)
+        self.expanded_tag_cache: OrderedDict[str, str] = OrderedDict()  # expanded-tag cache
+        self.expanded_embedding_cache: OrderedDict[str, np.ndarray] = OrderedDict()  # embedding cache for expanded tags
        self._initialized = False

        # Embedding client configuration
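For reference, the OrderedDict-based LRU pattern these caches rely on boils down to the following minimal, standalone sketch (not part of the diff; the constant name mirrors the one above, and the eviction step matches the next(iter(...)) / del approach used later in _get_embedding):

    from collections import OrderedDict

    import numpy as np

    MAX_EMBEDDING_CACHE_SIZE = 500

    cache: OrderedDict[str, np.ndarray] = OrderedDict()

    def cache_get(key: str) -> np.ndarray | None:
        # A hit is moved to the end, marking it most recently used.
        if key in cache:
            cache.move_to_end(key)
            return cache[key]
        return None

    def cache_put(key: str, value: np.ndarray) -> None:
        cache[key] = value
        cache.move_to_end(key)
        # Once over the limit, drop the oldest entry (the front of the OrderedDict).
        if len(cache) > MAX_EMBEDDING_CACHE_SIZE:
            oldest_key = next(iter(cache))
            del cache[oldest_key]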
@@ -358,7 +364,7 @@ class BotInterestManager:
            embedding_text = tag.tag_name
            embedding = await self._get_embedding(embedding_text)

-            if embedding:
+            if embedding is not None and embedding.size > 0:
                tag.embedding = embedding  # set on the tag object (in memory)
                self.embedding_cache[tag.tag_name] = embedding  # also cache it in memory
                generated_count += 1
@@ -376,16 +382,20 @@ class BotInterestManager:

        interests.last_updated = datetime.now()

-    async def _get_embedding(self, text: str, cache: bool = True) -> list[float]:
+    async def _get_embedding(self, text: str, cache: bool = True) -> np.ndarray:
        """Fetch the embedding vector for a text

        cache=False is used for message content, to avoid keeping large texts in embedding_cache long-term and bloating memory.

+        - Returns a NumPy array instead of list[float] to reduce object allocations
+        - Uses an LRU cache so the cache cannot grow without bound
        """
        if not hasattr(self, "embedding_request"):
            raise RuntimeError("Embedding request client is not initialized")

-        # Check the cache
+        # LRU lookup: move a hit to the end to mark it most recently used
        if cache and text in self.embedding_cache:
+            self.embedding_cache.move_to_end(text)
            return self.embedding_cache[text]

        # Fetch the embedding via LLMRequest
@@ -393,18 +403,42 @@ class BotInterestManager:
            raise RuntimeError("Embedding client is not initialized")
        embedding, model_name = await self.embedding_request.get_embedding(text)

-        if embedding and len(embedding) > 0:
-            if isinstance(embedding[0], list):
-                # If it's a list of lists, take the first one (though get_embedding(str) should return list[float])
-                embedding = embedding[0]
-
-            # Now we can safely cast to list[float] as we've handled the nested list case
-            embedding_float = cast(list[float], embedding)
+        if embedding is not None and (isinstance(embedding, np.ndarray) and embedding.size > 0 or isinstance(embedding, list) and len(embedding) > 0):
+            # Handle the different embedding return types
+            # The annotation guarantees an np.ndarray is returned
+            embedding_array: np.ndarray
+            if isinstance(embedding, np.ndarray):
+                # Already a NumPy array; check its dimensionality
+                if embedding.ndim == 1:
+                    # 1-D array, use as-is
+                    embedding_array = embedding
+                elif embedding.ndim == 2:
+                    # 2-D array (batch result), take the first row
+                    logger.warning(f"_get_embedding received a 2-D array {embedding.shape}; using the first row as the single vector")
+                    embedding_array = embedding[0]
+                else:
+                    raise RuntimeError(f"Unsupported embedding dimensionality: {embedding.ndim}, shape: {embedding.shape}")
+            elif isinstance(embedding, list):
+                if len(embedding) > 0 and isinstance(embedding[0], list):
+                    # Nested list, take the first entry
+                    embedding_array = np.array(embedding[0], dtype=np.float32)
+                else:
+                    # Plain list
+                    embedding_array = np.array(embedding, dtype=np.float32)
+            else:
+                raise RuntimeError(f"Unsupported embedding type: {type(embedding)}")

+            # 🔧 LRU cache write: automatically evicts the oldest entry
            if cache:
-                self.embedding_cache[text] = embedding_float
+                self.embedding_cache[text] = embedding_array
+                self.embedding_cache.move_to_end(text)
+                # Drop the oldest entry once over the limit
+                if len(self.embedding_cache) > MAX_EMBEDDING_CACHE_SIZE:
+                    oldest_key = next(iter(self.embedding_cache))
+                    del self.embedding_cache[oldest_key]
+                    logger.debug(f"LRU cache eviction: '{oldest_key}' (current size: {len(self.embedding_cache)})")

-            current_dim = len(embedding_float)
+            current_dim = embedding_array.shape[0]
            if self._detected_embedding_dimension is None:
                self._detected_embedding_dimension = current_dim
            if self.embedding_dimension and self.embedding_dimension != current_dim:
@@ -421,11 +455,11 @@ class BotInterestManager:
                    self.embedding_dimension,
                    current_dim,
                )
-            return embedding_float
+            return embedding_array
        else:
            raise RuntimeError(f"Returned embedding is empty: {embedding}")

-    async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]:
+    async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> np.ndarray:
        """Generate an embedding vector for a message"""
        # Combine the message text and keywords as the embedding input
        if keywords:
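The list[float] to np.ndarray switch in the return types above is about per-element overhead: a Python list keeps a pointer per element plus a boxed float object, while a float32 array packs 4 bytes per value into one contiguous buffer. A rough standalone measurement of the difference (illustrative only; exact byte counts vary by platform and Python version):

    import sys

    import numpy as np

    dim = 1536  # illustrative embedding width
    as_list = [float(i) for i in range(dim)]      # dim boxed floats plus a pointer array
    as_array = np.arange(dim, dtype=np.float32)   # one contiguous 4-byte-per-value buffer

    list_bytes = sys.getsizeof(as_list) + sum(sys.getsizeof(x) for x in as_list)
    array_bytes = sys.getsizeof(as_array)         # ndarray.__sizeof__ includes the owned buffer

    print(f"list[float]: ~{list_bytes} B, float32 ndarray: ~{array_bytes} B")
    # On CPython this prints roughly 49 KB vs 6 KB for 1536 values.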
@@ -439,8 +473,11 @@ class BotInterestManager:

    async def generate_embeddings_for_texts(
        self, text_map: dict[str, str], batch_size: int = 16
-    ) -> dict[str, list[float]]:
-        """Fetch embeddings for several texts in batches, for the caller to process uniformly."""
+    ) -> dict[str, np.ndarray]:
+        """Fetch embeddings for several texts in batches, for the caller to process uniformly.
+
+        Returns NumPy arrays instead of list[float] to reduce object allocations
+        """
        if not text_map:
            return {}

@@ -449,7 +486,7 @@ class BotInterestManager:

        batch_size = max(1, batch_size)
        keys = list(text_map.keys())
-        results: dict[str, list[float]] = {}
+        results: dict[str, np.ndarray] = {}

        for start in range(0, len(keys), batch_size):
            chunk_keys = keys[start : start + batch_size]
@@ -461,26 +498,48 @@ class BotInterestManager:
                logger.error(f"Batch embedding fetch failed (chunk {start // batch_size + 1}): {exc}")
                continue

-            if isinstance(chunk_embeddings, list) and chunk_embeddings and isinstance(chunk_embeddings[0], list):
-                normalized = chunk_embeddings
-            elif isinstance(chunk_embeddings, list):
-                normalized = [chunk_embeddings]
-            else:
-                normalized = []
+            # 🔧 Handle the different return types and normalize them into a list of NumPy arrays
+            normalized: list[np.ndarray] = []
+
+            if isinstance(chunk_embeddings, np.ndarray):
+                # NumPy array: check whether it is 1-D or 2-D
+                if chunk_embeddings.ndim == 1:
+                    # 1-D array (a single vector), wrap it in a list
+                    normalized = [chunk_embeddings]
+                elif chunk_embeddings.ndim == 2:
+                    # 2-D array (a batch of vectors), split it into a list
+                    normalized = [chunk_embeddings[i] for i in range(chunk_embeddings.shape[0])]  # type: ignore
+                else:
+                    logger.warning(f"Unexpected embedding dimensionality: {chunk_embeddings.ndim}, shape: {chunk_embeddings.shape}")
+                    normalized = []
+            elif isinstance(chunk_embeddings, list) and chunk_embeddings:
+                if isinstance(chunk_embeddings[0], np.ndarray):
+                    # Already a list of NumPy arrays
+                    normalized = chunk_embeddings  # type: ignore
+                elif isinstance(chunk_embeddings[0], list):
+                    # list[list[float]] format, convert to NumPy arrays
+                    normalized = [np.array(vec, dtype=np.float32) for vec in chunk_embeddings]
+                else:
+                    # A single vector, wrap it in a list
+                    normalized = [np.array(chunk_embeddings, dtype=np.float32)]

            for idx_offset, message_id in enumerate(chunk_keys):
-                vector = normalized[idx_offset] if idx_offset < len(normalized) else []
-                if isinstance(vector, list) and vector and isinstance(vector[0], float):
-                    results[message_id] = cast(list[float], vector)
+                if idx_offset < len(normalized):
+                    results[message_id] = normalized[idx_offset]
                else:
-                    results[message_id] = []
+                    # Return an empty array instead of an empty list
+                    results[message_id] = np.array([], dtype=np.float32)


        return results

    async def _calculate_similarity_scores(
-        self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
+        self, result: InterestMatchResult, message_embedding: np.ndarray, keywords: list[str]
    ):
-        """Compute similarity scores between the message and the interest tags"""
+        """Compute similarity scores between the message and the interest tags
+
+        🔧 Memory optimization: accepts a NumPy array argument to avoid type conversion
+        """
        try:
            if not self.current_interests:
                return
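A hypothetical call site for the batched helper above, run inside an async context (the manager variable and message ids are made up for illustration; it assumes the interest manager has been initialized):

    text_map = {"msg-1": "hello world", "msg-2": "numpy arrays reduce allocations"}
    vectors = await manager.generate_embeddings_for_texts(text_map, batch_size=16)

    for message_id, vec in vectors.items():
        if vec.size == 0:                       # an empty float32 array marks a failed embedding
            continue
        print(message_id, vec.shape, vec.dtype)  # e.g. (1536,) float32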
@@ -492,9 +551,12 @@ class BotInterestManager:
            logger.debug(f"Starting similarity computation against {len(active_tags)} interest tags")

            for tag in active_tags:
-                if tag.embedding:
+                if tag.embedding is not None:
+                    # Make sure tag.embedding is a NumPy array
+                    tag_embedding = tag.embedding if isinstance(tag.embedding, np.ndarray) else np.array(tag.embedding, dtype=np.float32)
+
                    # Compute cosine similarity
-                    similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
+                    similarity = self._calculate_cosine_similarity(message_embedding, tag_embedding)
                    weighted_score = similarity * tag.weight

                    # The similarity threshold is set to 0.3
@@ -508,7 +570,7 @@ class BotInterestManager:
            logger.error(f"Failed to compute similarity scores: {e}")

    async def calculate_interest_match(
-        self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
+        self, message_text: str, keywords: list[str] | None = None, message_embedding: np.ndarray | None = None
    ) -> InterestMatchResult:
        """Compute how well a message matches the bot's interests (optimized version - tag-expansion strategy)

@@ -540,7 +602,7 @@ class BotInterestManager:

        # Generate the message embedding
        logger.debug("Generating message embedding...")
-        if not message_embedding:
+        if message_embedding is None:
            # Message-text embeddings are not added to the global cache, so it does not grow with conversation history
            message_embedding = await self._get_embedding(message_text, cache=False)
        logger.debug(f"Message embedding generated, dimension: {len(message_embedding)}")
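The change from "if not message_embedding:" to "if message_embedding is None:" above matters because a NumPy array has no single truth value; a short standalone illustration of the pitfall and of the explicit checks this diff uses instead:

    import numpy as np

    vec = np.array([0.1, 0.2], dtype=np.float32)

    try:
        if not vec:  # ambiguous for arrays with more than one element
            pass
    except ValueError as e:
        print(e)  # "The truth value of an array with more than one element is ambiguous..."

    # The explicit checks used throughout this diff avoid that ambiguity:
    if vec is None or vec.size == 0:
        print("missing embedding")
    else:
        print("have embedding of size", vec.size)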
@@ -563,11 +625,11 @@ class BotInterestManager:
        logger.debug(f"Using tiered similarity thresholds: high={high_threshold}, medium={medium_threshold}, low={low_threshold}")

        for tag in active_tags:
-            if tag.embedding:
+            if tag.embedding is not None and (isinstance(tag.embedding, np.ndarray) and tag.embedding.size > 0 or isinstance(tag.embedding, list) and len(tag.embedding) > 0):
                # 🔧 Optimization: fetch the expanded tag's embedding (cached)
                expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name)

-                if expanded_embedding:
+                if expanded_embedding is not None and expanded_embedding.size > 0:
                    # Match against the expanded tag's embedding
                    similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding)

@@ -651,7 +713,7 @@ class BotInterestManager:

        return result

-    async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None:
+    async def _get_expanded_tag_embedding(self, tag_name: str) -> np.ndarray | None:
        """Fetch the expanded tag's embedding (cached)

        Uses the cache first; if there is no cached value, generates and caches the embedding
@@ -666,7 +728,7 @@ class BotInterestManager:
        # Generate the embedding
        try:
            embedding = await self._get_embedding(expanded_tag)
-            if embedding:
+            if embedding is not None and embedding.size > 0:
                # Cache the result
                self.expanded_tag_cache[tag_name] = expanded_tag
                self.expanded_embedding_cache[tag_name] = embedding
@@ -852,11 +914,26 @@ class BotInterestManager:

        return previous_row[-1]

-    def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
-        """Compute cosine similarity"""
+    def _calculate_cosine_similarity(self, vec1: np.ndarray | list[float], vec2: np.ndarray | list[float]) -> float:
+        """Compute cosine similarity
+
+        Accepts NumPy array arguments to avoid repeated conversion
+        """
        try:
-            np_vec1 = np.array(vec1)
-            np_vec2 = np.array(vec2)
+            # Make sure both inputs are NumPy arrays
+            np_vec1 = vec1 if isinstance(vec1, np.ndarray) else np.array(vec1, dtype=np.float32)
+            np_vec2 = vec2 if isinstance(vec2, np.ndarray) else np.array(vec2, dtype=np.float32)
+
+            # 🔧 Make sure both are 1-D arrays
+            np_vec1 = np_vec1.flatten()
+            np_vec2 = np_vec2.flatten()
+
+            # Check that the dimensions match
+            if np_vec1.shape[0] != np_vec2.shape[0]:
+                logger.warning(
+                    f"Vector dimension mismatch: vec1={np_vec1.shape[0]}, vec2={np_vec2.shape[0]}, returning 0.0"
+                )
+                return 0.0

            dot_product = np.dot(np_vec1, np_vec2)
            norm1 = np.linalg.norm(np_vec1)
@@ -866,7 +943,8 @@ class BotInterestManager:
                return 0.0

            similarity = dot_product / (norm1 * norm2)
-            return float(similarity)
+            # 🔧 Use item() to safely extract the scalar value
+            return float(similarity.item() if hasattr(similarity, 'item') else similarity)

        except Exception as e:
            logger.error(f"Failed to compute cosine similarity: {e}")
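As a quick sanity check of the cosine-similarity math above, a small standalone worked example with float32 vectors:

    import numpy as np

    a = np.array([1.0, 0.0], dtype=np.float32)
    b = np.array([1.0, 1.0], dtype=np.float32)

    similarity = float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
    print(round(similarity, 3))  # 0.707 — the vectors are 45 degrees apart, cos(45°) ≈ 0.707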
@@ -1056,8 +1134,11 @@ class BotInterestManager:
            logger.error("🔍 Error details:")
            traceback.print_exc()

-    async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
-        """Load the embedding cache from file"""
+    async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, np.ndarray] | None:
+        """Load the embedding cache from file
+
+        Memory optimization: entries are converted to NumPy array format
+        """
        try:
            from pathlib import Path

@@ -1089,11 +1170,14 @@ class BotInterestManager:
                logger.warning(f"⚠️ Embedding model changed ({cache_embedding_model} → {current_embedding_model}), ignoring the old cache")
                return None

-            embeddings = cache_data.get("embeddings", {})
+            # 🔧 Convert to NumPy array format
+            embeddings_raw = cache_data.get("embeddings", {})
+            embeddings = {key: np.array(value, dtype=np.float32) for key, value in embeddings_raw.items()}

            # Also load the embedding cache for expanded tags
-            expanded_embeddings = cache_data.get("expanded_embeddings", {})
-            if expanded_embeddings:
+            expanded_embeddings_raw = cache_data.get("expanded_embeddings", {})
+            if expanded_embeddings_raw:
+                expanded_embeddings = {key: np.array(value, dtype=np.float32) for key, value in expanded_embeddings_raw.items()}
                self.expanded_embedding_cache.update(expanded_embeddings)

            logger.info(f"Loaded {len(embeddings)} tag embedding cache entries from file (version: {cache_version}, model: {cache_embedding_model})")
@@ -1125,13 +1209,17 @@ class BotInterestManager:
            allowed_keys = {tag.tag_name for tag in self.current_interests.interest_tags}
            tag_embeddings = {key: value for key, value in self.embedding_cache.items() if key in allowed_keys}

+            # Convert NumPy arrays to lists so they can be serialized to JSON
+            tag_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in tag_embeddings.items()}
+            expanded_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in self.expanded_embedding_cache.items()}
+
            cache_data = {
                "version": 1,
                "personality_id": personality_id,
                "embedding_model": current_embedding_model,
                "last_updated": datetime.now().isoformat(),
-                "embeddings": tag_embeddings,
-                "expanded_embeddings": self.expanded_embedding_cache,  # also save the expanded-tag embeddings
+                "embeddings": tag_embeddings_serializable,
+                "expanded_embeddings": expanded_embeddings_serializable,  # also save the expanded-tag embeddings
            }

            # Write to file
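The serialization step above relies on the usual ndarray-to-list round trip for JSON; a minimal standalone sketch of that round trip (stdlib json is used here purely for illustration, regardless of which JSON library the project actually uses):

    import json

    import numpy as np

    vec = np.arange(4, dtype=np.float32)

    payload = json.dumps({"my_tag": vec.tolist()})                       # ndarray -> plain list for JSON
    restored = np.array(json.loads(payload)["my_tag"], dtype=np.float32)  # list -> ndarray on load

    assert np.array_equal(vec, restored)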
@@ -6,18 +6,24 @@
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any

+import numpy as np
+
from src.config.config import model_config
from . import BaseDataModel


@dataclass
class BotInterestTag(BaseDataModel):
-    """Bot interest tag"""
+    """Bot interest tag
+
+    The embedding field supports NumPy array format to reduce object allocations
+    """

    tag_name: str
    weight: float = 1.0  # weight, how strongly the bot favors this interest (0.0-1.0)
    expanded: str | None = None  # expanded description of the tag, used for more precise semantic matching
-    embedding: list[float] | None = None  # the tag's embedding vector
+    embedding: np.ndarray | list[float] | None = None  # the tag's embedding vector (NumPy array supported)
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)
    is_active: bool = True
@@ -1,10 +1,12 @@
import asyncio
+import base64
+import gc
import io
import re
from collections.abc import Callable, Coroutine, Iterable
from typing import Any, ClassVar

import numpy as np
import orjson
from json_repair import repair_json
from openai import (
@@ -588,13 +590,22 @@ class OpenaiClient(BaseClient):
        :param model_info: model info
        :param embedding_input: the input text(s) to embed
        :return: the embedding response
+
+        - The request specifies encoding_format="base64" so the OpenAI SDK does not call .tolist() automatically
+        - The base64 payload is decoded manually and kept as a NumPy array, reducing object creation
+        - Garbage collection is triggered after large batch requests
        """
        client = self._create_client()
        is_batch_request = isinstance(embedding_input, list)

+        # Key fix: specify encoding_format="base64" to avoid the SDK's automatic tolist() conversion
+        # When encoding_format is not specified, the OpenAI SDK calls np.frombuffer().tolist()
+        # which creates a huge number of Python float objects and causes a severe memory leak
        try:
            raw_response = await client.embeddings.create(
                model=model_info.model_identifier,
                input=embedding_input,
+                encoding_format="base64",  # use base64 encoding to avoid tolist()
                extra_body=extra_params,
            )
        except APIConnectionError as e:
@@ -615,9 +626,24 @@ class OpenaiClient(BaseClient):

        response = APIResponse()

-        # Parse the embedding response
+        # Decode base64 manually and convert to NumPy arrays, avoiding a flood of Python float objects
        if len(raw_response.data) > 0:
-            embeddings = [item.embedding for item in raw_response.data]
+            embeddings = []
+            for item in raw_response.data:
+                # item.embedding is now a base64-encoded string
+                if isinstance(item.embedding, str):
+                    # Decode base64 into a NumPy array (float32)
+                    embedding_array = np.frombuffer(
+                        base64.b64decode(item.embedding),
+                        dtype=np.float32
+                    )
+                    # Keep it as a NumPy array; do not call .tolist()
+                    embeddings.append(embedding_array)
+                else:
+                    # Fallback: the SDK did not return base64 (older versions or other cases)
+                    # Convert to a NumPy array
+                    embeddings.append(np.array(item.embedding, dtype=np.float32))

            response.embedding = embeddings if is_batch_request else embeddings[0]
        else:
            raise RespParseException(
@@ -625,6 +651,10 @@ class OpenaiClient(BaseClient):
                "Failed to parse the response: embedding data is missing.",
            )

+        # Trigger garbage collection after large batch requests (batch_size > 8)
+        if is_batch_request and len(embedding_input) > 8:
+            gc.collect()
+
        # Parse usage information
        if hasattr(raw_response, "usage"):
            response.usage = UsageRecord(
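The base64 path above assumes the API returns each embedding as a base64 string packing raw float32 values; a standalone round-trip sketch of that decode step, using a made-up vector rather than a real API response:

    import base64

    import numpy as np

    def decode_base64_embedding(b64: str) -> np.ndarray:
        # The payload is a packed float32 buffer; frombuffer avoids creating per-element float objects.
        return np.frombuffer(base64.b64decode(b64), dtype=np.float32)

    vec = np.array([0.1, -0.2, 0.3], dtype=np.float32)
    encoded = base64.b64encode(vec.tobytes()).decode("ascii")

    assert np.allclose(decode_base64_embedding(encoded), vec)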
@@ -28,6 +28,8 @@ from collections.abc import Callable, Coroutine
from enum import Enum
from typing import Any, ClassVar, Literal

+import numpy as np
+
from rich.traceback import install

from src.common.logger import get_logger
@@ -1170,7 +1172,8 @@ class LLMRequest:
        if not isinstance(embeddings, list):
            raise RuntimeError("Failed to get embeddings: the batch result has an unexpected format")

-        if embeddings and not isinstance(embeddings[0], list):
+        # embeddings should normally be list[vector]; only wrap when the provider returned a flat list (a single vector)
+        if embeddings and not isinstance(embeddings[0], (list, tuple, np.ndarray)):
            embeddings = [embeddings]  # type: ignore[list-item]

        # Batch requests return a two-dimensional list

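The widened isinstance check above is there to tell a single flat vector apart from a batch of vectors; a minimal standalone sketch of the intended behavior (the helper name is illustrative, not from the codebase):

    import numpy as np

    def ensure_batch(embeddings):
        # A flat vector (first element is a scalar) gets wrapped into a batch of one;
        # a real batch (first element is itself a list/tuple/ndarray) passes through unchanged.
        if embeddings and not isinstance(embeddings[0], (list, tuple, np.ndarray)):
            return [embeddings]
        return embeddings

    assert ensure_batch([0.1, 0.2, 0.3]) == [[0.1, 0.2, 0.3]]
    assert ensure_batch([[0.1, 0.2], [0.3, 0.4]]) == [[0.1, 0.2], [0.3, 0.4]]
    assert len(ensure_batch([np.zeros(3, dtype=np.float32)])) == 1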
@@ -117,7 +117,13 @@ class EmbeddingGenerator:
        # Call the API
        embedding_list, model_name = await self._llm_request.get_embedding(text)

-        if embedding_list and len(embedding_list) > 0:
+        # Accept either np.ndarray or plain Python list return values
+        if isinstance(embedding_list, np.ndarray):
+            if embedding_list.size > 0:
                embedding = np.array(embedding_list, dtype=np.float32)
                logger.debug(f"🌐 API generated embedding: {text[:30]}... -> {len(embedding)} dims (model: {model_name})")
                return embedding
+        elif embedding_list and len(embedding_list) > 0:
+            embedding = np.array(embedding_list, dtype=np.float32)
+            logger.debug(f"🌐 API generated embedding: {text[:30]}... -> {len(embedding)} dims (model: {model_name})")
+            return embedding
@@ -187,12 +193,17 @@ class EmbeddingGenerator:
            return None

        embeddings, model_name = await self._llm_request.get_embedding(texts)
-        if not embeddings:
+        if embeddings is None:
            return None

        results: list[np.ndarray | None] = []
        for emb in embeddings:
-            if emb:
+            if isinstance(emb, np.ndarray):
+                if emb.size > 0:
                    results.append(np.array(emb, dtype=np.float32))
+                else:
+                    results.append(None)
+            elif emb:
+                results.append(np.array(emb, dtype=np.float32))
            else:
                results.append(None)

@@ -178,6 +178,10 @@ class ChatterActionPlanner:
                message.interest_calculated = True
                interest_updates[message_id] = result.interest_value
                reply_updates[message_id] = result.should_reply
+
+                # Clear the embeddings dict once the batch has been processed
+                embeddings.clear()
+                text_map.clear()
            else:
                message.interest_calculated = False
