feat(embedding): optimize embedding handling, support NumPy array format and reduce memory allocations
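At its core, the change stores every embedding as one contiguous float32 NumPy buffer instead of a list[float] holding one boxed Python float per dimension. A rough, self-contained comparison of the two representations (illustrative dimension, not measured from this commit):

```python
import sys
import numpy as np

dim = 1536  # hypothetical embedding size, for illustration only
as_list = [i * 0.001 for i in range(dim)]    # list[float]: ~24-byte float object plus an 8-byte pointer per value
as_array = np.zeros(dim, dtype=np.float32)   # float32 ndarray: 4 bytes per value in a single buffer

list_bytes = sys.getsizeof(as_list) + sum(sys.getsizeof(v) for v in as_list)
print(f"list[float] ≈ {list_bytes} bytes")                 # on CPython roughly dim * 32 plus list overhead
print(f"float32 array buffer = {as_array.nbytes} bytes")   # 1536 * 4 = 6144 bytes
```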
@@ -4,6 +4,7 @@
 """
 
 import traceback
+from collections import OrderedDict
 from datetime import datetime
 from typing import Any, cast
 
@@ -18,15 +19,20 @@ from src.utils.json_parser import extract_and_parse_json
 
 logger = get_logger("bot_interest_manager")
 
+# 🔧 Memory-optimization settings
+MAX_EMBEDDING_CACHE_SIZE = 500  # maximum number of entries in the embedding cache (LRU eviction)
+MAX_EXPANDED_TAG_CACHE_SIZE = 200  # maximum number of entries in the expanded-tag cache
 
 
 class BotInterestManager:
     """Manager for the bot's interest tags"""
 
     def __init__(self):
         self.current_interests: BotPersonalityInterests | None = None
-        self.embedding_cache: dict[str, list[float]] = {}  # embedding cache
-        self.expanded_tag_cache: dict[str, str] = {}  # expanded-tag cache
-        self.expanded_embedding_cache: dict[str, list[float]] = {}  # embedding cache for expanded tags
+        # 🔧 Use OrderedDict as an LRU cache so the caches cannot grow without bound
+        self.embedding_cache: OrderedDict[str, np.ndarray] = OrderedDict()  # embedding cache (NumPy format)
+        self.expanded_tag_cache: OrderedDict[str, str] = OrderedDict()  # expanded-tag cache
+        self.expanded_embedding_cache: OrderedDict[str, np.ndarray] = OrderedDict()  # embedding cache for expanded tags
         self._initialized = False
 
         # Embedding client configuration
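The plain dict caches above become OrderedDict-based LRU caches; later hunks inline the move-to-end-on-hit and evict-oldest-on-insert steps. A minimal sketch of that same pattern, with illustrative names rather than the class used in the diff:

```python
from collections import OrderedDict

import numpy as np

MAX_EMBEDDING_CACHE_SIZE = 500  # same cap as the constant introduced above


class LRUEmbeddingCache:
    """Illustrative OrderedDict-based LRU cache for embedding vectors."""

    def __init__(self, max_size: int = MAX_EMBEDDING_CACHE_SIZE) -> None:
        self._data: OrderedDict[str, np.ndarray] = OrderedDict()
        self._max_size = max_size

    def get(self, key: str) -> np.ndarray | None:
        vec = self._data.get(key)
        if vec is not None:
            self._data.move_to_end(key)  # mark as most recently used
        return vec

    def put(self, key: str, vec: np.ndarray) -> None:
        self._data[key] = vec
        self._data.move_to_end(key)
        if len(self._data) > self._max_size:
            self._data.popitem(last=False)  # evict the least recently used entry
```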
@@ -358,7 +364,7 @@ class BotInterestManager:
             embedding_text = tag.tag_name
             embedding = await self._get_embedding(embedding_text)
 
-            if embedding:
+            if embedding is not None and embedding.size > 0:
                 tag.embedding = embedding  # set on the tag object (in memory)
                 self.embedding_cache[tag.tag_name] = embedding  # also cache in memory
                 generated_count += 1
@@ -376,16 +382,20 @@ class BotInterestManager:
 
         interests.last_updated = datetime.now()
 
-    async def _get_embedding(self, text: str, cache: bool = True) -> list[float]:
+    async def _get_embedding(self, text: str, cache: bool = True) -> np.ndarray:
         """Get the embedding vector for a text.
 
         cache=False is used for message content, so large texts are not kept in embedding_cache long-term and memory does not balloon.
 
+        - Returns a NumPy array instead of list[float], reducing object allocations
+        - Implements an LRU cache so the cache cannot grow without bound
         """
         if not hasattr(self, "embedding_request"):
             raise RuntimeError("Embedding请求客户端未初始化")
 
-        # Check the cache
+        # LRU cache lookup: move the entry to the end to mark it as most recently used
         if cache and text in self.embedding_cache:
+            self.embedding_cache.move_to_end(text)
             return self.embedding_cache[text]
 
         # Fetch the embedding through LLMRequest
@@ -393,18 +403,42 @@ class BotInterestManager:
             raise RuntimeError("Embedding客户端未初始化")
         embedding, model_name = await self.embedding_request.get_embedding(text)
 
-        if embedding and len(embedding) > 0:
-            if isinstance(embedding[0], list):
-                # If it's a list of lists, take the first one (though get_embedding(str) should return list[float])
-                embedding = embedding[0]
-            # Now we can safely cast to list[float] as we've handled the nested list case
-            embedding_float = cast(list[float], embedding)
+        if embedding is not None and (isinstance(embedding, np.ndarray) and embedding.size > 0 or isinstance(embedding, list) and len(embedding) > 0):
+            # Handle the different embedding return types;
+            # the type annotation ensures an np.ndarray is returned
+            embedding_array: np.ndarray
+            if isinstance(embedding, np.ndarray):
+                # Already a NumPy array: check its dimensionality
+                if embedding.ndim == 1:
+                    # 1-D array, use it directly
+                    embedding_array = embedding
+                elif embedding.ndim == 2:
+                    # 2-D array (batch result): take the first row
+                    logger.warning(f"_get_embedding 收到二维数组 {embedding.shape},取第一行作为单个向量")
+                    embedding_array = embedding[0]
+                else:
+                    raise RuntimeError(f"不支持的 embedding 维度: {embedding.ndim},形状: {embedding.shape}")
+            elif isinstance(embedding, list):
+                if len(embedding) > 0 and isinstance(embedding[0], list):
+                    # Nested list: take the first vector
+                    embedding_array = np.array(embedding[0], dtype=np.float32)
+                else:
+                    # Plain list
+                    embedding_array = np.array(embedding, dtype=np.float32)
+            else:
+                raise RuntimeError(f"不支持的 embedding 类型: {type(embedding)}")
 
+            # 🔧 LRU cache write: evict the oldest entry automatically
             if cache:
-                self.embedding_cache[text] = embedding_float
+                self.embedding_cache[text] = embedding_array
+                self.embedding_cache.move_to_end(text)
+                # Drop the oldest entry once the size limit is exceeded
+                if len(self.embedding_cache) > MAX_EMBEDDING_CACHE_SIZE:
+                    oldest_key = next(iter(self.embedding_cache))
+                    del self.embedding_cache[oldest_key]
+                    logger.debug(f"LRU缓存淘汰: '{oldest_key}' (当前大小: {len(self.embedding_cache)})")
 
-            current_dim = len(embedding_float)
+            current_dim = embedding_array.shape[0]
             if self._detected_embedding_dimension is None:
                 self._detected_embedding_dimension = current_dim
             if self.embedding_dimension and self.embedding_dimension != current_dim:
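The rewritten branch above normalizes whatever the provider hands back, a 1-D array, a 2-D batch, a flat list, or a nested list, into a single 1-D float32 vector before it is cached. The same logic as a standalone sketch, assuming those are the only shapes that can arrive:

```python
import numpy as np


def normalize_embedding(embedding) -> np.ndarray:
    """Illustrative normalization, mirroring the branch structure in the hunk above."""
    if isinstance(embedding, np.ndarray):
        if embedding.ndim == 1:
            return embedding                 # already a single vector
        if embedding.ndim == 2:
            return embedding[0]              # batch result: take the first row
        raise RuntimeError(f"unsupported ndim: {embedding.ndim}")
    if isinstance(embedding, list):
        if embedding and isinstance(embedding[0], list):
            return np.array(embedding[0], dtype=np.float32)  # nested list: first vector
        return np.array(embedding, dtype=np.float32)         # plain list of floats
    raise RuntimeError(f"unsupported embedding type: {type(embedding)}")


print(normalize_embedding([[0.1, 0.2, 0.3]]).dtype)  # float32
```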
@@ -421,11 +455,11 @@ class BotInterestManager:
                 self.embedding_dimension,
                 current_dim,
             )
-            return embedding_float
+            return embedding_array
         else:
             raise RuntimeError(f"返回的embedding为空: {embedding}")
 
-    async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> list[float]:
+    async def _generate_message_embedding(self, message_text: str, keywords: list[str]) -> np.ndarray:
         """Generate an embedding vector for a message"""
         # Combine the message text and keywords into the embedding input
         if keywords:
@@ -439,8 +473,11 @@ class BotInterestManager:
 
     async def generate_embeddings_for_texts(
         self, text_map: dict[str, str], batch_size: int = 16
-    ) -> dict[str, list[float]]:
-        """Fetch embeddings for multiple texts in batches, for upstream callers to consume."""
+    ) -> dict[str, np.ndarray]:
+        """Fetch embeddings for multiple texts in batches, for upstream callers to consume.
+
+        Returns NumPy arrays instead of list[float], reducing object allocations.
+        """
         if not text_map:
             return {}
 
@@ -449,7 +486,7 @@ class BotInterestManager:
 
         batch_size = max(1, batch_size)
         keys = list(text_map.keys())
-        results: dict[str, list[float]] = {}
+        results: dict[str, np.ndarray] = {}
 
         for start in range(0, len(keys), batch_size):
             chunk_keys = keys[start : start + batch_size]
@@ -461,26 +498,48 @@ class BotInterestManager:
                 logger.error(f"批量获取embedding失败 (chunk {start // batch_size + 1}): {exc}")
                 continue
 
-            if isinstance(chunk_embeddings, list) and chunk_embeddings and isinstance(chunk_embeddings[0], list):
-                normalized = chunk_embeddings
-            elif isinstance(chunk_embeddings, list):
-                normalized = [chunk_embeddings]
-            else:
-                normalized = []
+            # 🔧 Handle the different return types and normalize everything to a list of NumPy arrays
+            normalized: list[np.ndarray] = []
+
+            if isinstance(chunk_embeddings, np.ndarray):
+                # NumPy array: check whether it is 1-D or 2-D
+                if chunk_embeddings.ndim == 1:
+                    # 1-D array (a single vector): wrap it in a list
+                    normalized = [chunk_embeddings]
+                elif chunk_embeddings.ndim == 2:
+                    # 2-D array (a batch of vectors): split it into a list
+                    normalized = [chunk_embeddings[i] for i in range(chunk_embeddings.shape[0])]  # type: ignore
+                else:
+                    logger.warning(f"意外的 embedding 维度: {chunk_embeddings.ndim},形状: {chunk_embeddings.shape}")
+                    normalized = []
+            elif isinstance(chunk_embeddings, list) and chunk_embeddings:
+                if isinstance(chunk_embeddings[0], np.ndarray):
+                    # Already a list of NumPy arrays
+                    normalized = chunk_embeddings  # type: ignore
+                elif isinstance(chunk_embeddings[0], list):
+                    # list[list[float]] format: convert to NumPy arrays
+                    normalized = [np.array(vec, dtype=np.float32) for vec in chunk_embeddings]
+                else:
+                    # A single vector: wrap it in a list
+                    normalized = [np.array(chunk_embeddings, dtype=np.float32)]
 
             for idx_offset, message_id in enumerate(chunk_keys):
-                vector = normalized[idx_offset] if idx_offset < len(normalized) else []
-                if isinstance(vector, list) and vector and isinstance(vector[0], float):
-                    results[message_id] = cast(list[float], vector)
+                if idx_offset < len(normalized):
+                    results[message_id] = normalized[idx_offset]
                 else:
-                    results[message_id] = []
+                    # Return an empty array instead of an empty list
+                    results[message_id] = np.array([], dtype=np.float32)
 
 
         return results
 
     async def _calculate_similarity_scores(
-        self, result: InterestMatchResult, message_embedding: list[float], keywords: list[str]
+        self, result: InterestMatchResult, message_embedding: np.ndarray, keywords: list[str]
     ):
-        """Compute similarity scores between a message and the interest tags"""
+        """Compute similarity scores between a message and the interest tags.
+
+        🔧 Memory optimization: accepts NumPy array arguments to avoid type conversions.
+        """
         try:
             if not self.current_interests:
                 return
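When the batch result is a 2-D array, it is split row by row. NumPy basic indexing returns views, so this split adds no copies of the vectors; a quick check with illustrative shapes:

```python
import numpy as np

batch = np.random.rand(3, 1536).astype(np.float32)  # hypothetical batch: one row per requested text
rows = [batch[i] for i in range(batch.shape[0])]     # same split as the 2-D branch above

assert rows[0].base is batch  # each row is a view sharing the batch's buffer, not a copy
```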
@@ -492,9 +551,12 @@ class BotInterestManager:
             logger.debug(f"开始计算与 {len(active_tags)} 个兴趣标签的相似度")
 
             for tag in active_tags:
-                if tag.embedding:
+                if tag.embedding is not None:
+                    # Make sure tag.embedding is a NumPy array
+                    tag_embedding = tag.embedding if isinstance(tag.embedding, np.ndarray) else np.array(tag.embedding, dtype=np.float32)
+
                     # Compute cosine similarity
-                    similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
+                    similarity = self._calculate_cosine_similarity(message_embedding, tag_embedding)
                     weighted_score = similarity * tag.weight
 
                     # Use 0.3 as the similarity threshold
@@ -508,7 +570,7 @@ class BotInterestManager:
             logger.error(f"计算相似度分数失败: {e}")
 
     async def calculate_interest_match(
-        self, message_text: str, keywords: list[str] | None = None, message_embedding: list[float] | None = None
+        self, message_text: str, keywords: list[str] | None = None, message_embedding: np.ndarray | None = None
     ) -> InterestMatchResult:
         """Compute how well a message matches the bot's interests (optimized: tag-expansion strategy)
 
@@ -540,7 +602,7 @@ class BotInterestManager:
 
         # Generate the message embedding
         logger.debug("正在生成消息 embedding...")
-        if not message_embedding:
+        if message_embedding is None:
            # Message-text embeddings are not added to the global cache, so the cache does not grow with the conversation history
            message_embedding = await self._get_embedding(message_text, cache=False)
        logger.debug(f"消息 embedding 生成成功, 维度: {len(message_embedding)}")
@@ -563,11 +625,11 @@ class BotInterestManager:
         logger.debug(f"使用分级相似度阈值: 高={high_threshold}, 中={medium_threshold}, 低={low_threshold}")
 
         for tag in active_tags:
-            if tag.embedding:
+            if tag.embedding is not None and (isinstance(tag.embedding, np.ndarray) and tag.embedding.size > 0 or isinstance(tag.embedding, list) and len(tag.embedding) > 0):
                 # 🔧 Optimization: fetch the expanded tag's embedding (cached)
                 expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name)
 
-                if expanded_embedding:
+                if expanded_embedding is not None and expanded_embedding.size > 0:
                     # Match against the expanded tag's embedding
                     similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding)
 
@@ -651,7 +713,7 @@ class BotInterestManager:
 
         return result
 
-    async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None:
+    async def _get_expanded_tag_embedding(self, tag_name: str) -> np.ndarray | None:
         """Get the embedding for an expanded tag (cached).
 
         Uses the cache when possible; otherwise generates the embedding and caches it.
@@ -666,7 +728,7 @@ class BotInterestManager:
         # Generate the embedding
         try:
             embedding = await self._get_embedding(expanded_tag)
-            if embedding:
+            if embedding is not None and embedding.size > 0:
                 # Cache the result
                 self.expanded_tag_cache[tag_name] = expanded_tag
                 self.expanded_embedding_cache[tag_name] = embedding
@@ -852,11 +914,26 @@ class BotInterestManager:
 
         return previous_row[-1]
 
-    def _calculate_cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
-        """Compute cosine similarity"""
+    def _calculate_cosine_similarity(self, vec1: np.ndarray | list[float], vec2: np.ndarray | list[float]) -> float:
+        """Compute cosine similarity.
+
+        Accepts NumPy array arguments to avoid repeated conversions.
+        """
         try:
-            np_vec1 = np.array(vec1)
-            np_vec2 = np.array(vec2)
+            # Make sure both vectors are NumPy arrays
+            np_vec1 = vec1 if isinstance(vec1, np.ndarray) else np.array(vec1, dtype=np.float32)
+            np_vec2 = vec2 if isinstance(vec2, np.ndarray) else np.array(vec2, dtype=np.float32)
+
+            # 🔧 Make sure both are one-dimensional
+            np_vec1 = np_vec1.flatten()
+            np_vec2 = np_vec2.flatten()
+
+            # Check that the dimensions match
+            if np_vec1.shape[0] != np_vec2.shape[0]:
+                logger.warning(
+                    f"向量维度不匹配: vec1={np_vec1.shape[0]}, vec2={np_vec2.shape[0]},返回0.0"
+                )
+                return 0.0
 
             dot_product = np.dot(np_vec1, np_vec2)
             norm1 = np.linalg.norm(np_vec1)
@@ -866,7 +943,8 @@ class BotInterestManager:
                 return 0.0
 
             similarity = dot_product / (norm1 * norm2)
-            return float(similarity)
+            # 🔧 Use item() to extract the scalar value safely
+            return float(similarity.item() if hasattr(similarity, 'item') else similarity)
 
         except Exception as e:
             logger.error(f"计算余弦相似度失败: {e}")
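Taken together, the new _calculate_cosine_similarity accepts arrays or lists, flattens them, rejects mismatched dimensions, and guards against zero norms. An equivalent standalone sketch:

```python
import numpy as np


def cosine_similarity(vec1, vec2) -> float:
    """Illustrative version of the guarded cosine similarity above."""
    a = np.asarray(vec1, dtype=np.float32).flatten()
    b = np.asarray(vec2, dtype=np.float32).flatten()
    if a.shape[0] != b.shape[0]:
        return 0.0  # mismatched dimensions
    norm_a, norm_b = np.linalg.norm(a), np.linalg.norm(b)
    if norm_a == 0 or norm_b == 0:
        return 0.0  # avoid division by zero
    return float(np.dot(a, b) / (norm_a * norm_b))


print(cosine_similarity([1.0, 0.0], [1.0, 0.0]), cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 1.0 0.0
```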
@@ -1056,8 +1134,11 @@ class BotInterestManager:
             logger.error("🔍 错误详情:")
             traceback.print_exc()
 
-    async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
-        """Load the embedding cache from file"""
+    async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, np.ndarray] | None:
+        """Load the embedding cache from file.
+
+        Memory optimization: entries are converted to NumPy array format.
+        """
         try:
             from pathlib import Path
 
@@ -1089,11 +1170,14 @@ class BotInterestManager:
                 logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model} → {current_embedding_model}),忽略旧缓存")
                 return None
 
-            embeddings = cache_data.get("embeddings", {})
+            # 🔧 Convert to NumPy array format
+            embeddings_raw = cache_data.get("embeddings", {})
+            embeddings = {key: np.array(value, dtype=np.float32) for key, value in embeddings_raw.items()}
 
             # Also load the embedding cache for expanded tags
-            expanded_embeddings = cache_data.get("expanded_embeddings", {})
-            if expanded_embeddings:
+            expanded_embeddings_raw = cache_data.get("expanded_embeddings", {})
+            if expanded_embeddings_raw:
+                expanded_embeddings = {key: np.array(value, dtype=np.float32) for key, value in expanded_embeddings_raw.items()}
                 self.expanded_embedding_cache.update(expanded_embeddings)
 
             logger.info(f"成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
@@ -1125,13 +1209,17 @@ class BotInterestManager:
         allowed_keys = {tag.tag_name for tag in self.current_interests.interest_tags}
         tag_embeddings = {key: value for key, value in self.embedding_cache.items() if key in allowed_keys}
 
+        # Convert NumPy arrays to lists so they can be JSON-serialized
+        tag_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in tag_embeddings.items()}
+        expanded_embeddings_serializable = {key: value.tolist() if isinstance(value, np.ndarray) else value for key, value in self.expanded_embedding_cache.items()}
+
         cache_data = {
             "version": 1,
             "personality_id": personality_id,
             "embedding_model": current_embedding_model,
             "last_updated": datetime.now().isoformat(),
-            "embeddings": tag_embeddings,
-            "expanded_embeddings": self.expanded_embedding_cache,  # also save the expanded tags' embeddings
+            "embeddings": tag_embeddings_serializable,
+            "expanded_embeddings": expanded_embeddings_serializable,  # also save the expanded tags' embeddings
         }
 
         # Write the file
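The cache file keeps JSON-friendly lists on disk while the in-memory caches hold float32 arrays, which is why the load path wraps values in np.array(..., dtype=np.float32) and the save path calls tolist(). A small round-trip sketch with a hypothetical key and size:

```python
import json

import numpy as np

cache = {"tag_embedding": np.random.rand(8).astype(np.float32)}  # hypothetical in-memory entry

blob = json.dumps({key: value.tolist() for key, value in cache.items()})                        # ndarray -> list for JSON
restored = {key: np.array(value, dtype=np.float32) for key, value in json.loads(blob).items()}  # list -> ndarray

assert np.allclose(cache["tag_embedding"], restored["tag_embedding"])
```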
@@ -6,18 +6,24 @@
 from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Any
 
+import numpy as np
+
 from src.config.config import model_config
 from . import BaseDataModel
 
 
 @dataclass
 class BotInterestTag(BaseDataModel):
-    """Bot interest tag"""
+    """Bot interest tag.
+
+    The embedding field also supports NumPy array format, reducing object allocations.
+    """
 
     tag_name: str
     weight: float = 1.0  # weight: how much the bot likes this interest (0.0-1.0)
     expanded: str | None = None  # expanded description of the tag, used for more precise semantic matching
-    embedding: list[float] | None = None  # the tag's embedding vector
+    embedding: np.ndarray | list[float] | None = None  # the tag's embedding vector (NumPy array supported)
     created_at: datetime = field(default_factory=datetime.now)
     updated_at: datetime = field(default_factory=datetime.now)
     is_active: bool = True

@@ -1,10 +1,12 @@
 import asyncio
 import base64
+import gc
 import io
 import re
 from collections.abc import Callable, Coroutine, Iterable
 from typing import Any, ClassVar
 
+import numpy as np
 import orjson
 from json_repair import repair_json
 from openai import (
@@ -588,13 +590,22 @@ class OpenaiClient(BaseClient):
         :param model_info: model info
         :param embedding_input: embedding input text
         :return: embedding response
+
+        - Request with encoding_format="base64" so the OpenAI SDK does not call .tolist() automatically
+        - Decode the base64 manually and keep the NumPy array format, reducing object creation
+        - Trigger garbage collection after large batch requests
         """
         client = self._create_client()
         is_batch_request = isinstance(embedding_input, list)
 
+        # Key fix: request encoding_format="base64" so the SDK does not convert via tolist().
+        # When encoding_format is unspecified, the OpenAI SDK calls np.frombuffer().tolist(),
+        # which creates huge numbers of Python float objects and leads to severe memory bloat
        try:
             raw_response = await client.embeddings.create(
                 model=model_info.model_identifier,
                 input=embedding_input,
+                encoding_format="base64",  # use base64 encoding to avoid tolist()
                 extra_body=extra_params,
             )
         except APIConnectionError as e:
@@ -615,15 +626,34 @@ class OpenaiClient(BaseClient):
 
         response = APIResponse()
 
-        # Parse the embedding response
+        # Decode the base64 manually and convert to NumPy arrays, avoiding large numbers of Python float objects
         if len(raw_response.data) > 0:
-            embeddings = [item.embedding for item in raw_response.data]
+            embeddings = []
+            for item in raw_response.data:
+                # item.embedding is now a base64-encoded string
+                if isinstance(item.embedding, str):
+                    # Decode the base64 into a NumPy array (float32)
+                    embedding_array = np.frombuffer(
+                        base64.b64decode(item.embedding),
+                        dtype=np.float32
+                    )
+                    # Keep it as a NumPy array; do not call .tolist()
+                    embeddings.append(embedding_array)
+                else:
+                    # Fallback: if the SDK did not return base64 (older versions or other cases),
+                    # convert to a NumPy array
+                    embeddings.append(np.array(item.embedding, dtype=np.float32))
+
             response.embedding = embeddings if is_batch_request else embeddings[0]
         else:
             raise RespParseException(
                 raw_response,
                 "响应解析失败,缺失嵌入数据。",
             )
 
+        # Trigger garbage collection after large batch requests (batch_size > 8)
+        if is_batch_request and len(embedding_input) > 8:
+            gc.collect()
+
         # Parse usage information
         if hasattr(raw_response, "usage"):
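With encoding_format="base64", each embedding in the response is the base64 of the vector's raw float32 bytes, so np.frombuffer can rebuild it in a single allocation instead of materializing one Python float per dimension. An offline round trip that mimics that decode path (no API call, illustrative size):

```python
import base64

import numpy as np

original = np.random.rand(1536).astype(np.float32)              # stand-in for a provider-side vector
payload = base64.b64encode(original.tobytes()).decode("ascii")   # what a base64-encoded response field would carry

decoded = np.frombuffer(base64.b64decode(payload), dtype=np.float32)  # the decode used in the hunk above
assert np.array_equal(original, decoded)
```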
@@ -28,6 +28,8 @@ from collections.abc import Callable, Coroutine
 from enum import Enum
 from typing import Any, ClassVar, Literal
 
+import numpy as np
+
 from rich.traceback import install
 
 from src.common.logger import get_logger
@@ -1170,7 +1172,8 @@ class LLMRequest:
         if not isinstance(embeddings, list):
             raise RuntimeError("获取embedding失败,批量结果格式异常")
 
-        if embeddings and not isinstance(embeddings[0], list):
+        # embeddings should normally be a list of vectors; only wrap it when the provider returned a flat list (a single vector)
+        if embeddings and not isinstance(embeddings[0], (list, tuple, np.ndarray)):
             embeddings = [embeddings]  # type: ignore[list-item]
 
         # Batch requests return a two-dimensional list

@@ -117,7 +117,13 @@ class EmbeddingGenerator:
             # Call the API
             embedding_list, model_name = await self._llm_request.get_embedding(text)
 
-            if embedding_list and len(embedding_list) > 0:
+            # Accept either an np.ndarray or a Python list return value
+            if isinstance(embedding_list, np.ndarray):
+                if embedding_list.size > 0:
+                    embedding = np.array(embedding_list, dtype=np.float32)
+                    logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
+                    return embedding
+            elif embedding_list and len(embedding_list) > 0:
                 embedding = np.array(embedding_list, dtype=np.float32)
                 logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
                 return embedding
@@ -187,12 +193,17 @@ class EmbeddingGenerator:
             return None
 
         embeddings, model_name = await self._llm_request.get_embedding(texts)
-        if not embeddings:
+        if embeddings is None:
             return None
 
         results: list[np.ndarray | None] = []
         for emb in embeddings:
-            if emb:
+            if isinstance(emb, np.ndarray):
+                if emb.size > 0:
+                    results.append(np.array(emb, dtype=np.float32))
+                else:
+                    results.append(None)
+            elif emb:
                 results.append(np.array(emb, dtype=np.float32))
             else:
                 results.append(None)

@@ -178,6 +178,10 @@ class ChatterActionPlanner:
                     message.interest_calculated = True
                     interest_updates[message_id] = result.interest_value
                     reply_updates[message_id] = result.should_reply
 
+                    # Clear the embeddings dict once the batch has been processed
+                    embeddings.clear()
+                    text_map.clear()
                 else:
                     message.interest_calculated = False
 