feat(embedding): 优化嵌入处理,支持 NumPy 数组格式并减少内存分配

This commit is contained in:
Windpicker-owo
2025-12-10 11:00:46 +08:00
parent 487e49c1c1
commit 410d85fb26
6 changed files with 202 additions and 60 deletions

View File

@@ -4,6 +4,7 @@
""" """
import traceback import traceback
from collections import OrderedDict
from datetime import datetime from datetime import datetime
from typing import Any, cast from typing import Any, cast
@@ -18,15 +19,20 @@ from src.utils.json_parser import extract_and_parse_json
logger = get_logger("bot_interest_manager") logger = get_logger("bot_interest_manager")
# 🔧 内存优化配置
MAX_EMBEDDING_CACHE_SIZE = 500 # embedding 缓存最大条目数LRU淘汰
MAX_EXPANDED_TAG_CACHE_SIZE = 200 # 扩展标签缓存最大条目数
class BotInterestManager: class BotInterestManager:
"""机器人兴趣标签管理器""" """机器人兴趣标签管理器"""
def __init__(self): def __init__(self):
self.current_interests: BotPersonalityInterests | None = None self.current_interests: BotPersonalityInterests | None = None
self.embedding_cache: dict[str, list[float]] = {} # embedding缓存 # 🔧 使用 OrderedDict 实现 LRU 缓存,避免无限增长
self.expanded_tag_cache: dict[str, str] = {} # 扩展标签缓存 self.embedding_cache: OrderedDict[str, np.ndarray] = OrderedDict() # embedding缓存NumPy格式
self.expanded_embedding_cache: dict[str, list[float]] = {} # 扩展标签的embedding缓存 self.expanded_tag_cache: OrderedDict[str, str] = OrderedDict() # 扩展标签缓存
self.expanded_embedding_cache: OrderedDict[str, np.ndarray] = OrderedDict() # 扩展标签的embedding缓存
self._initialized = False self._initialized = False
# Embedding客户端配置 # Embedding客户端配置
@@ -358,7 +364,7 @@ class BotInterestManager:
embedding_text = tag.tag_name embedding_text = tag.tag_name
embedding = await self._get_embedding(embedding_text) embedding = await self._get_embedding(embedding_text)
if embedding: if embedding is not None and embedding.size > 0:
tag.embedding = embedding # 设置到 tag 对象(内存中) tag.embedding = embedding # 设置到 tag 对象(内存中)
self.embedding_cache[tag.tag_name] = embedding # 同时缓存到内存 self.embedding_cache[tag.tag_name] = embedding # 同时缓存到内存
generated_count += 1 generated_count += 1
@@ -376,16 +382,20 @@ class BotInterestManager:
interests.last_updated = datetime.now() interests.last_updated = datetime.now()
async def _get_embedding(self, text: str, cache: bool = True) -> list[float]: async def _get_embedding(self, text: str, cache: bool = True) -> np.ndarray:
"""获取文本的embedding向量 """获取文本的embedding向量
cache=False 用于消息内容,避免在 embedding_cache 中长期保留大文本导致内存膨胀。 cache=False 用于消息内容,避免在 embedding_cache 中长期保留大文本导致内存膨胀。
- 返回 NumPy 数组而非 list[float],减少对象分配
- 实现 LRU 缓存,防止缓存无限增长
""" """
if not hasattr(self, "embedding_request"): if not hasattr(self, "embedding_request"):
raise RuntimeError("Embedding请求客户端未初始化") raise RuntimeError("Embedding请求客户端未初始化")
# 检查缓存 # LRU 缓存查找:移到末尾表示最近使用
if cache and text in self.embedding_cache: if cache and text in self.embedding_cache:
self.embedding_cache.move_to_end(text)
return self.embedding_cache[text] return self.embedding_cache[text]
# 使用LLMRequest获取embedding # 使用LLMRequest获取embedding
@@ -393,18 +403,42 @@ class BotInterestManager:
raise RuntimeError("Embedding客户端未初始化") raise RuntimeError("Embedding客户端未初始化")
embedding, model_name = await self.embedding_request.get_embedding(text) embedding, model_name = await self.embedding_request.get_embedding(text)
if embedding and len(embedding) > 0: if embedding is not None and (isinstance(embedding, np.ndarray) and embedding.size > 0 or isinstance(embedding, list) and len(embedding) > 0):
if isinstance(embedding[0], list): # 处理不同类型的 embedding 返回值
# If it's a list of lists, take the first one (though get_embedding(str) should return list[float]) # 类型注解确保返回 np.ndarray
embedding = embedding[0] embedding_array: np.ndarray
if isinstance(embedding, np.ndarray):
# Now we can safely cast to list[float] as we've handled the nested list case # 已经是 NumPy 数组,检查维度
embedding_float = cast(list[float], embedding) if embedding.ndim == 1:
# 一维数组,直接使用
embedding_array = embedding
elif embedding.ndim == 2:
# 二维数组(批量结果),取第一行
logger.warning(f"_get_embedding 收到二维数组 {embedding.shape},取第一行作为单个向量")
embedding_array = embedding[0]
else:
raise RuntimeError(f"不支持的 embedding 维度: {embedding.ndim},形状: {embedding.shape}")
elif isinstance(embedding, list):
if len(embedding) > 0 and isinstance(embedding[0], list):
# 嵌套列表,取第一个
embedding_array = np.array(embedding[0], dtype=np.float32)
else:
# 普通列表
embedding_array = np.array(embedding, dtype=np.float32)
else:
raise RuntimeError(f"不支持的 embedding 类型: {type(embedding)}")
# 🔧 LRU 缓存写入:自动淘汰最旧条目
if cache: if cache:
self.embedding_cache[text] = embedding_float self.embedding_cache[text] = embedding_array
self.embedding_cache.move_to_end(text)
# 超过限制时删除最旧条目
if len(self.embedding_cache) > MAX_EMBEDDING_CACHE_SIZE:
oldest_key = next(iter(self.embedding_cache))
del self.embedding_cache[oldest_key]
logger.debug(f"LRU缓存淘汰: '{oldest_key}' (当前大小: {len(self.embedding_cache)})")
current_dim = len(embedding_float) current_dim = embedding_array.shape[0]
if self._detected_embedding_dimension is None: if self._detected_embedding_dimension is None:
self._detected_embedding_dimension = current_dim self._detected_embedding_dimension = current_dim
if self.embedding_dimension and self.embedding_dimension != current_dim: if self.embedding_dimension and self.embedding_dimension != current_dim:
@@ -421,11 +455,11 @@ class BotInterestManager:
self.embedding_dimension, self.embedding_dimension,
current_dim, current_dim,
) )
return embedding_float return embedding_array
else: else:
raise RuntimeError(f"返回的embedding为空: {embedding}") raise RuntimeError(f"返回的embedding为空: {embedding}")
async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]: async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> np.ndarray:
"""为消息生成embedding向量""" """为消息生成embedding向量"""
# 组合消息文本和关键词作为embedding输入 # 组合消息文本和关键词作为embedding输入
if keywords: if keywords:
@@ -439,8 +473,11 @@ class BotInterestManager:
async def generate_embeddings_for_texts( async def generate_embeddings_for_texts(
self, text_map: dict[str, str], batch_size: int = 16 self, text_map: dict[str, str], batch_size: int = 16
) -> dict[str, list[float]]: ) -> dict[str, np.ndarray]:
"""批量获取多段文本的embedding供上层统一处理。""" """批量获取多段文本的embedding供上层统一处理。
返回 NumPy 数组而非 list[float],减少对象分配
"""
if not text_map: if not text_map:
return {} return {}
@@ -449,7 +486,7 @@ class BotInterestManager:
batch_size = max(1, batch_size) batch_size = max(1, batch_size)
keys = list(text_map.keys()) keys = list(text_map.keys())
results: dict[str, list[float]] = {} results: dict[str, np.ndarray] = {}
for start in range(0, len(keys), batch_size): for start in range(0, len(keys), batch_size):
chunk_keys = keys[start : start + batch_size] chunk_keys = keys[start : start + batch_size]
@@ -461,26 +498,48 @@ class BotInterestManager:
logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}") logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}")
continue continue
if isinstance(chunk_embeddings, list) and chunk_embeddings and isinstance(chunk_embeddings[0], list): # 🔧 处理不同类型的返回值,统一转换为 NumPy 数组列表
normalized = chunk_embeddings normalized: list[np.ndarray] = []
elif isinstance(chunk_embeddings, list):
normalized = [chunk_embeddings] if isinstance(chunk_embeddings, np.ndarray):
else: # NumPy 数组:检查是一维还是二维
normalized = [] if chunk_embeddings.ndim == 1:
# 一维数组(单个向量),包装为列表
normalized = [chunk_embeddings]
elif chunk_embeddings.ndim == 2:
# 二维数组(批量向量),拆分为列表
normalized = [chunk_embeddings[i] for i in range(chunk_embeddings.shape[0])] # type: ignore
else:
logger.warning(f"意外的 embedding 维度: {chunk_embeddings.ndim},形状: {chunk_embeddings.shape}")
normalized = []
elif isinstance(chunk_embeddings, list) and chunk_embeddings:
if isinstance(chunk_embeddings[0], np.ndarray):
# 已经是 NumPy 数组列表
normalized = chunk_embeddings # type: ignore
elif isinstance(chunk_embeddings[0], list):
# list[list[float]] 格式,转换为 NumPy 数组
normalized = [np.array(vec, dtype=np.float32) for vec in chunk_embeddings]
else:
# 单个向量,包装为列表
normalized = [np.array(chunk_embeddings, dtype=np.float32)]
for idx_offset, message_id in enumerate(chunk_keys): for idx_offset, message_id in enumerate(chunk_keys):
vector = normalized[idx_offset] if idx_offset < len(normalized) else [] if idx_offset < len(normalized):
if isinstance(vector, list) and vector and isinstance(vector[0], float): results[message_id] = normalized[idx_offset]
results[message_id] = cast(list[float], vector)
else: else:
results[message_id] = [] # 返回空数组而非空列表
results[message_id] = np.array([], dtype=np.float32)
return results return results
async def _calculate_similarity_scores( async def _calculate_similarity_scores(
self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str] self, result: InterestMatchResult, message_embedding: np.ndarray, keywords: list[str]
): ):
"""计算消息与兴趣标签的相似度分数""" """计算消息与兴趣标签的相似度分数
🔧 内存优化:接受 NumPy 数组参数,避免类型转换
"""
try: try:
if not self.current_interests: if not self.current_interests:
return return
@@ -492,9 +551,12 @@ class BotInterestManager:
logger.debug(f"开始计算与 {len(active_tags)} 个兴趣标签的相似度") logger.debug(f"开始计算与 {len(active_tags)} 个兴趣标签的相似度")
for tag in active_tags: for tag in active_tags:
if tag.embedding: if tag.embedding is not None:
# 确保 tag.embedding 是 NumPy 数组
tag_embedding = tag.embedding if isinstance(tag.embedding, np.ndarray) else np.array(tag.embedding, dtype=np.float32)
# 计算余弦相似度 # 计算余弦相似度
similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding) similarity = self._calculate_cosine_similarity(message_embedding, tag_embedding)
weighted_score = similarity * tag.weight weighted_score = similarity * tag.weight
# 设置相似度阈值为0.3 # 设置相似度阈值为0.3
@@ -508,7 +570,7 @@ class BotInterestManager:
logger.error(f"计算相似度分数失败: {e}") logger.error(f"计算相似度分数失败: {e}")
async def calculate_interest_match( async def calculate_interest_match(
self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None self, message_text: str, keywords: list[str] | None = None, message_embedding: np.ndarray | None = None
) -> InterestMatchResult: ) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略) """计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)
@@ -540,7 +602,7 @@ class BotInterestManager:
# 生成消息的embedding # 生成消息的embedding
logger.debug("正在生成消息 embedding...") logger.debug("正在生成消息 embedding...")
if not message_embedding: if message_embedding is None:
# 消息文本embedding不入全局缓存避免缓存随着对话历史无限增长 # 消息文本embedding不入全局缓存避免缓存随着对话历史无限增长
message_embedding = await self._get_embedding(message_text, cache=False) message_embedding = await self._get_embedding(message_text, cache=False)
logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}") logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
@@ -563,11 +625,11 @@ class BotInterestManager:
logger.debug(f"使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}") logger.debug(f"使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}")
for tag in active_tags: for tag in active_tags:
if tag.embedding: if tag.embedding is not None and (isinstance(tag.embedding, np.ndarray) and tag.embedding.size > 0 or isinstance(tag.embedding, list) and len(tag.embedding) > 0):
# 🔧 优化:获取扩展标签的 embedding带缓存 # 🔧 优化:获取扩展标签的 embedding带缓存
expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name) expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name)
if expanded_embedding: if expanded_embedding is not None and expanded_embedding.size > 0:
# 使用扩展标签的 embedding 进行匹配 # 使用扩展标签的 embedding 进行匹配
similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding) similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding)
@@ -651,7 +713,7 @@ class BotInterestManager:
return result return result
async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None: async def _get_expanded_tag_embedding(self, tag_name: str) -> np.ndarray | None:
"""获取扩展标签的 embedding带缓存 """获取扩展标签的 embedding带缓存
优先使用缓存,如果没有则生成并缓存 优先使用缓存,如果没有则生成并缓存
@@ -666,7 +728,7 @@ class BotInterestManager:
# 生成 embedding # 生成 embedding
try: try:
embedding = await self._get_embedding(expanded_tag) embedding = await self._get_embedding(expanded_tag)
if embedding: if embedding is not None and embedding.size > 0:
# 缓存结果 # 缓存结果
self.expanded_tag_cache[tag_name] = expanded_tag self.expanded_tag_cache[tag_name] = expanded_tag
self.expanded_embedding_cache[tag_name] = embedding self.expanded_embedding_cache[tag_name] = embedding
@@ -852,11 +914,26 @@ class BotInterestManager:
return previous_row[-1] return previous_row[-1]
def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float: def _calculate_cosine_similarity(self, vec1: np.ndarray | list[float], vec2: np.ndarray | list[float]) -> float:
"""计算余弦相似度""" """计算余弦相似度
支持 NumPy 数组参数,避免重复转换
"""
try: try:
np_vec1 = np.array(vec1) # 确保是 NumPy 数组
np_vec2 = np.array(vec2) np_vec1 = vec1 if isinstance(vec1, np.ndarray) else np.array(vec1, dtype=np.float32)
np_vec2 = vec2 if isinstance(vec2, np.ndarray) else np.array(vec2, dtype=np.float32)
# 🔧 确保是一维数组
np_vec1 = np_vec1.flatten()
np_vec2 = np_vec2.flatten()
# 检查维度是否匹配
if np_vec1.shape[0] != np_vec2.shape[0]:
logger.warning(
f"向量维度不匹配: vec1={np_vec1.shape[0]}, vec2={np_vec2.shape[0]}返回0.0"
)
return 0.0
dot_product = np.dot(np_vec1, np_vec2) dot_product = np.dot(np_vec1, np_vec2)
norm1 = np.linalg.norm(np_vec1) norm1 = np.linalg.norm(np_vec1)
@@ -866,7 +943,8 @@ class BotInterestManager:
return 0.0 return 0.0
similarity = dot_product / (norm1 * norm2) similarity = dot_product / (norm1 * norm2)
return float(similarity) # 🔧 使用 item() 方法安全地提取标量值
return float(similarity.item() if hasattr(similarity, 'item') else similarity)
except Exception as e: except Exception as e:
logger.error(f"计算余弦相似度失败: {e}") logger.error(f"计算余弦相似度失败: {e}")
@@ -1056,8 +1134,11 @@ class BotInterestManager:
logger.error("🔍 错误详情:") logger.error("🔍 错误详情:")
traceback.print_exc() traceback.print_exc()
async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None: async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, np.ndarray] | None:
"""从文件加载embedding缓存""" """从文件加载embedding缓存
内存优化:转换为 NumPy 数组格式
"""
try: try:
from pathlib import Path from pathlib import Path
@@ -1089,11 +1170,14 @@ class BotInterestManager:
logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存") logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存")
return None return None
embeddings = cache_data.get("embeddings", {}) # 🔧 转换为 NumPy 数组格式
embeddings_raw = cache_data.get("embeddings", {})
embeddings = {key: np.array(value, dtype=np.float32) for key, value in embeddings_raw.items()}
# 同时加载扩展标签的embedding缓存 # 同时加载扩展标签的embedding缓存
expanded_embeddings = cache_data.get("expanded_embeddings", {}) expanded_embeddings_raw = cache_data.get("expanded_embeddings", {})
if expanded_embeddings: if expanded_embeddings_raw:
expanded_embeddings = {key: np.array(value, dtype=np.float32) for key, value in expanded_embeddings_raw.items()}
self.expanded_embedding_cache.update(expanded_embeddings) self.expanded_embedding_cache.update(expanded_embeddings)
logger.info(f"成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})") logger.info(f"成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
@@ -1125,13 +1209,17 @@ class BotInterestManager:
allowed_keys = {tag.tag_name for tag in self.current_interests.interest_tags} allowed_keys = {tag.tag_name for tag in self.current_interests.interest_tags}
tag_embeddings = {key: value for key, value in self.embedding_cache.items() if key in allowed_keys} tag_embeddings = {key: value for key, value in self.embedding_cache.items() if key in allowed_keys}
# 将 NumPy 数组转换为列表以便 JSON 序列化
tag_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in tag_embeddings.items()}
expanded_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in self.expanded_embedding_cache.items()}
cache_data = { cache_data = {
"version": 1, "version": 1,
"personality_id": personality_id, "personality_id": personality_id,
"embedding_model": current_embedding_model, "embedding_model": current_embedding_model,
"last_updated": datetime.now().isoformat(), "last_updated": datetime.now().isoformat(),
"embeddings": tag_embeddings, "embeddings": tag_embeddings_serializable,
"expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding "expanded_embeddings": expanded_embeddings_serializable, # 同时保存扩展标签的embedding
} }
# 写入文件 # 写入文件

View File

@@ -6,18 +6,24 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
import numpy as np
from src.config.config import model_config from src.config.config import model_config
from . import BaseDataModel from . import BaseDataModel
@dataclass @dataclass
class BotInterestTag(BaseDataModel): class BotInterestTag(BaseDataModel):
"""机器人兴趣标签""" """机器人兴趣标签
embedding 字段支持 NumPy 数组格式,减少对象分配
"""
tag_name: str tag_name: str
weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0) weight: float = 1.0 # 权重,表示对这个兴趣的喜好程度 (0.0-1.0)
expanded: str | None = None # 标签的扩展描述,用于更精准的语义匹配 expanded: str | None = None # 标签的扩展描述,用于更精准的语义匹配
embedding: list[float] | None = None # 标签的embedding向量 embedding: np.ndarray | list[float] | None = None # 标签的embedding向量(支持 NumPy 数组)
created_at: datetime = field(default_factory=datetime.now) created_at: datetime = field(default_factory=datetime.now)
updated_at: datetime = field(default_factory=datetime.now) updated_at: datetime = field(default_factory=datetime.now)
is_active: bool = True is_active: bool = True

View File

@@ -1,10 +1,12 @@
import asyncio import asyncio
import base64 import base64
import gc
import io import io
import re import re
from collections.abc import Callable, Coroutine, Iterable from collections.abc import Callable, Coroutine, Iterable
from typing import Any, ClassVar from typing import Any, ClassVar
import numpy as np
import orjson import orjson
from json_repair import repair_json from json_repair import repair_json
from openai import ( from openai import (
@@ -588,13 +590,22 @@ class OpenaiClient(BaseClient):
:param model_info: 模型信息 :param model_info: 模型信息
:param embedding_input: 嵌入输入文本 :param embedding_input: 嵌入输入文本
:return: 嵌入响应 :return: 嵌入响应
- 请求时指定 encoding_format="base64",避免 OpenAI SDK 自动调用 .tolist()
- 手动解码 base64 并保持 NumPy 数组格式,减少对象创建
- 大批量请求后触发垃圾回收
""" """
client = self._create_client() client = self._create_client()
is_batch_request = isinstance(embedding_input, list) is_batch_request = isinstance(embedding_input, list)
# 关键修复:指定 encoding_format="base64" 避免 SDK 自动 tolist() 转换
# OpenAI SDK 在不指定 encoding_format 时会调用 np.frombuffer().tolist()
# 这会创建大量 Python float 对象,导致严重的内存泄露
try: try:
raw_response = await client.embeddings.create( raw_response = await client.embeddings.create(
model=model_info.model_identifier, model=model_info.model_identifier,
input=embedding_input, input=embedding_input,
encoding_format="base64", # 使用 base64 编码避免 tolist()
extra_body=extra_params, extra_body=extra_params,
) )
except APIConnectionError as e: except APIConnectionError as e:
@@ -615,15 +626,34 @@ class OpenaiClient(BaseClient):
response = APIResponse() response = APIResponse()
# 解析嵌入响应 # 手动解码 base64 并转换为 NumPy 数组,避免创建大量 Python float 对象
if len(raw_response.data) > 0: if len(raw_response.data) > 0:
embeddings = [item.embedding for item in raw_response.data] embeddings = []
for item in raw_response.data:
# item.embedding 现在是 base64 编码的字符串
if isinstance(item.embedding, str):
# 解码 base64 为 NumPy 数组(float32)
embedding_array = np.frombuffer(
base64.b64decode(item.embedding),
dtype=np.float32
)
# 保持为 NumPy 数组,不调用 .tolist()
embeddings.append(embedding_array)
else:
# 兜底:如果 SDK 返回的不是 base64(旧版或其他情况)
# 转换为 NumPy 数组
embeddings.append(np.array(item.embedding, dtype=np.float32))
response.embedding = embeddings if is_batch_request else embeddings[0] response.embedding = embeddings if is_batch_request else embeddings[0]
else: else:
raise RespParseException( raise RespParseException(
raw_response, raw_response,
"响应解析失败,缺失嵌入数据。", "响应解析失败,缺失嵌入数据。",
) )
# 大批量请求后触发垃圾回收(batch_size > 8)
if is_batch_request and len(embedding_input) > 8:
gc.collect()
# 解析使用情况 # 解析使用情况
if hasattr(raw_response, "usage"): if hasattr(raw_response, "usage"):

View File

@@ -28,6 +28,8 @@ from collections.abc import Callable, Coroutine
from enum import Enum from enum import Enum
from typing import Any, ClassVar, Literal from typing import Any, ClassVar, Literal
import numpy as np
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -1170,7 +1172,8 @@ class LLMRequest:
if not isinstance(embeddings, list): if not isinstance(embeddings, list):
raise RuntimeError("获取embedding失败批量结果格式异常") raise RuntimeError("获取embedding失败批量结果格式异常")
if embeddings and not isinstance(embeddings[0], list): # embeddings 正常应该是 list[vector];如果 provider 返回了一维列表(单向量),只在这种情况下套一层
if embeddings and not isinstance(embeddings[0], (list, tuple, np.ndarray)):
embeddings = [embeddings] # type: ignore[list-item] embeddings = [embeddings] # type: ignore[list-item]
# 批量请求返回二维列表 # 批量请求返回二维列表

View File

@@ -117,7 +117,13 @@ class EmbeddingGenerator:
# 调用 API # 调用 API
embedding_list, model_name = await self._llm_request.get_embedding(text) embedding_list, model_name = await self._llm_request.get_embedding(text)
if embedding_list and len(embedding_list) > 0: # 兼容返回 np.ndarray 或 Python list
if isinstance(embedding_list, np.ndarray):
if embedding_list.size > 0:
embedding = np.array(embedding_list, dtype=np.float32)
logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
return embedding
elif embedding_list and len(embedding_list) > 0:
embedding = np.array(embedding_list, dtype=np.float32) embedding = np.array(embedding_list, dtype=np.float32)
logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})") logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
return embedding return embedding
@@ -187,12 +193,17 @@ class EmbeddingGenerator:
return None return None
embeddings, model_name = await self._llm_request.get_embedding(texts) embeddings, model_name = await self._llm_request.get_embedding(texts)
if not embeddings: if embeddings is None:
return None return None
results: list[np.ndarray | None] = [] results: list[np.ndarray | None] = []
for emb in embeddings: for emb in embeddings:
if emb: if isinstance(emb, np.ndarray):
if emb.size > 0:
results.append(np.array(emb, dtype=np.float32))
else:
results.append(None)
elif emb:
results.append(np.array(emb, dtype=np.float32)) results.append(np.array(emb, dtype=np.float32))
else: else:
results.append(None) results.append(None)

View File

@@ -178,6 +178,10 @@ class ChatterActionPlanner:
message.interest_calculated = True message.interest_calculated = True
interest_updates[message_id] = result.interest_value interest_updates[message_id] = result.interest_value
reply_updates[message_id] = result.should_reply reply_updates[message_id] = result.should_reply
# 批量处理后清理 embeddings 字典
embeddings.clear()
text_map.clear()
else: else:
message.interest_calculated = False message.interest_calculated = False