Still fixing pyright, meow meow meow~
@@ -2,6 +2,7 @@ import asyncio
 import math
 import os
 from dataclasses import dataclass
+from typing import Any
 
 # import tqdm
 import aiofiles
@@ -121,7 +122,7 @@ class EmbeddingStore:
         self.store = {}
 
-        self.faiss_index = None
+        self.faiss_index: Any = None
         self.idx2hash = None
 
     @staticmethod
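A note on the `faiss_index: Any = None` change: when an attribute is only ever initialized to `None`, pyright infers its type as `None` and then rejects a later assignment of an index object. The sketch below is illustrative only (the class name and the `Optional` alternative are assumptions, not the repository's code):

```python
from typing import Any, Optional


class EmbeddingStoreSketch:
    """Illustrative stand-in for EmbeddingStore, not the repository class."""

    def __init__(self) -> None:
        # Bare `self.faiss_index = None` makes pyright infer the type `None`,
        # so a later assignment of a faiss index object is reported as an
        # incompatible assignment. An explicit annotation fixes that:
        self.faiss_index: Any = None                     # what the commit does
        self.idx2hash: Optional[dict[int, str]] = None   # stricter alternative style
```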
@@ -158,6 +159,8 @@ class EmbeddingStore:
         from src.config.config import model_config
         from src.llm_models.utils_model import LLMRequest
 
+        assert model_config is not None
+
         # Clamp chunk_size and max_workers to a reasonable range
         chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
         max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
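Two small patterns in this hunk are worth spelling out: `assert model_config is not None` narrows an `Optional` value for pyright, and `max(lo, min(x, hi))` clamps a value into a range. A minimal sketch, with illustrative bounds that are not taken from the repository:

```python
MIN_CHUNK_SIZE, MAX_CHUNK_SIZE = 1, 1024   # assumed bounds for illustration
MIN_WORKERS, MAX_WORKERS = 1, 16


def clamp(value: int, lo: int, hi: int) -> int:
    # max(lo, min(value, hi)) pins value into [lo, hi]
    return max(lo, min(value, hi))


print(clamp(5000, MIN_CHUNK_SIZE, MAX_CHUNK_SIZE))  # 1024
print(clamp(0, MIN_WORKERS, MAX_WORKERS))           # 1

config: dict[str, str] | None = {"model": "demo"}   # stand-in for model_config
assert config is not None   # after this line pyright treats config as non-None
print(config["model"])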
@@ -402,6 +405,7 @@ class EmbeddingStore:
 
     def build_faiss_index(self) -> None:
         """Rebuild the Faiss index, using cosine similarity as the metric"""
+        assert global_config is not None
         # Collect all embeddings
         array = []
         self.idx2hash = {}
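The docstring says the index is rebuilt with cosine similarity as the metric. The repository's implementation is not shown in this hunk; a common way to get cosine similarity from faiss is to L2-normalize the vectors and use an inner-product index, sketched below as an assumption rather than the project's actual code:

```python
import faiss          # assumed dependency; not shown in this hunk
import numpy as np


def build_cosine_index(embeddings: list[list[float]]) -> tuple[faiss.IndexFlatIP, dict[int, int]]:
    """Sketch only: normalized vectors make inner product equal cosine similarity."""
    array = np.asarray(embeddings, dtype=np.float32)
    faiss.normalize_L2(array)                  # in-place L2 normalization
    index = faiss.IndexFlatIP(array.shape[1])  # inner-product (dot-product) index
    index.add(array)
    # index position -> original position, analogous in spirit to self.idx2hash
    idx2pos = {i: i for i in range(array.shape[0])}
    return index, idx2pos
```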
@@ -344,14 +344,15 @@ class ImageManager:
         # --- New frame-selection logic: uniformly sample 4 frames ---
         num_frames = len(all_frames)
         if num_frames <= 4:
-            # If the total frame count is <= 4, select them all
+            # If the total width is <= 4, select them all
             selected_frames = all_frames
+            indices = list(range(num_frames))
         else:
             # Use linspace to compute 4 evenly spaced indices
             indices = np.linspace(0, num_frames - 1, 4, dtype=int)
             selected_frames = [all_frames[i] for i in indices]
 
-        logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices if num_frames > 4 else list(range(num_frames))}")
+        logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices}")
         # --- End of frame-selection logic ---
 
         # If no frames remain after selection (e.g. the GIF had only one frame and later processing failed?) or the original GIF had no frames at all, also return None
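The added `indices = list(range(num_frames))` in the `<= 4` branch is what lets the simplified `logger.debug` line reference `indices` unconditionally (and keeps pyright from reporting it as possibly unbound). A self-contained sketch of the selection logic:

```python
import numpy as np

all_frames = [f"frame{i}" for i in range(10)]   # stand-in for decoded GIF frames
num_frames = len(all_frames)

if num_frames <= 4:
    selected_frames = all_frames
    indices = list(range(num_frames))           # now defined on both paths
else:
    # 4 evenly spaced indices between 0 and num_frames - 1, endpoints included
    indices = np.linspace(0, num_frames - 1, 4, dtype=int)
    selected_frames = [all_frames[i] for i in indices]

print(list(indices), selected_frames)
# [0, 3, 6, 9] ['frame0', 'frame3', 'frame6', 'frame9']
```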
@@ -37,7 +37,7 @@ _locks_guard = asyncio.Lock()
 
 logger = get_logger("utils_video")
 
-from inkfox import video
+from inkfox import video  # type: ignore
 
 
 class VideoAnalyzer:
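If `inkfox` ships without type stubs (an assumption here), pyright flags the import; the trailing `# type: ignore` silences it. A pyright-specific alternative that documents which diagnostic is being suppressed:

```python
# Blanket suppression, as in the commit:
from inkfox import video  # type: ignore

# Narrower, pyright-only form that names the silenced rule:
# from inkfox import video  # pyright: ignore[reportMissingImports]
```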
@@ -123,7 +123,6 @@ class VideoAnalyzer:
     # ---- Batch analysis ----
     async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
         from src.llm_models.payload_content.message import MessageBuilder, RoleType
-        from src.llm_models.utils_model import RequestType
 
         prompt = self.batch_analysis_prompt.format(
             personality_core=self.personality_core, personality_side=self.personality_side
@@ -139,12 +138,7 @@ class VideoAnalyzer:
         for b64, _ in frames:
             mb.add_image_content("jpeg", b64)
         message = mb.build()
-        model_info, api_provider, client = self.video_llm._select_model()
-        resp = await self.video_llm._execute_request(
-            api_provider=api_provider,
-            client=client,
-            request_type=RequestType.RESPONSE,
-            model_info=model_info,
+        resp = await self.video_llm.execute_with_messages(
             message_list=[message],
             temperature=None,
             max_tokens=None,
@@ -31,9 +31,9 @@ def _extract_frames_worker(
     max_image_size: int,
     frame_extraction_mode: str,
    frame_interval_seconds: float | None,
-) -> list[Any] | list[tuple[str, str]]:
+) -> list[tuple[str, float]] | list[tuple[str, str]]:
     """Worker function for extracting video frames in a thread pool"""
-    frames = []
+    frames: list[tuple[str, float]] = []
     try:
         cap = cv2.VideoCapture(video_path)
         fps = cap.get(cv2.CAP_PROP_FPS)
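The `frames: list[tuple[str, float]] = []` annotation matters because pyright infers a partially unknown element type from a bare `[]`, which then conflicts with the declared return type. A tiny sketch of the accepted form:

```python
def collect_pairs() -> list[tuple[str, float]]:
    # Without the annotation, the empty list's element type stays unknown to pyright
    frames: list[tuple[str, float]] = []
    frames.append(("b64data", 0.0))   # (base64 frame, timestamp) pairs, as in the worker
    return frames
```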
@@ -42,7 +42,7 @@ def _extract_frames_worker(
 
         if frame_extraction_mode == "time_interval":
             # New mode: sample frames at a fixed time interval
-            time_interval = frame_interval_seconds
+            time_interval = frame_interval_seconds or 2.0
             next_frame_time = 0.0
             extracted_count = 0  # Initialize the extracted-frame counter
 
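`frame_interval_seconds or 2.0` gives the time-interval mode a fallback when the caller passes `None`, which also satisfies pyright (a `float | None` cannot be used in arithmetic directly). Note that `or` would also replace an explicit `0.0`; an `is None` check is the stricter spelling, sketched here with an assumed 2.0-second default:

```python
def pick_interval(frame_interval_seconds: float | None) -> float:
    # Same intent as `frame_interval_seconds or 2.0`, but only falls back on None
    return 2.0 if frame_interval_seconds is None else frame_interval_seconds


print(pick_interval(None))   # 2.0
print(pick_interval(0.5))    # 0.5
```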
@@ -61,7 +61,7 @@ def _extract_frames_worker(
             # Resize the image
             if max(pil_image.size) > max_image_size:
                 ratio = max_image_size / max(pil_image.size)
-                new_size = tuple(int(dim * ratio) for dim in pil_image.size)
+                new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
                 pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
 
             # Convert to base64
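Why the generator-based `tuple(...)` is replaced: pyright types it as `tuple[int, ...]`, while Pillow's `resize` expects a fixed `(width, height)` pair, so the variadic tuple is reported as incompatible. A minimal sketch of the accepted form:

```python
from PIL import Image


def shrink_to_fit(pil_image: Image.Image, max_image_size: int) -> Image.Image:
    if max(pil_image.size) > max_image_size:
        ratio = max_image_size / max(pil_image.size)
        # Explicit two-element tuple -> typed as tuple[int, int], which resize() accepts
        new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
        pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
    return pil_image
```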
@@ -240,6 +240,7 @@ class LegacyVideoAnalyzer:
             estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
         else:
             estimated_frames = self.max_frames
+            frame_interval = 1
 
         logger.info(f"Computed frame interval: {frame_interval} (about {estimated_frames} frames will be extracted)")
 
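The added `frame_interval = 1` in the `else` branch is the usual fix for pyright's "possibly unbound" diagnostic: the name is now assigned on every path before the `logger.info` call that formats it. A standalone sketch with illustrative names:

```python
def plan_extraction(total_frames: int, max_frames: int, computed_interval: int | None) -> tuple[int, int]:
    if computed_interval:
        frame_interval = computed_interval
        estimated_frames = min(max_frames, total_frames // frame_interval + 1)
    else:
        estimated_frames = max_frames
        frame_interval = 1   # ensures frame_interval is bound on this path too
    return frame_interval, estimated_frames


print(plan_extraction(300, 10, 30))    # (30, 10)
print(plan_extraction(300, 10, None))  # (1, 10)
```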
@@ -276,7 +277,7 @@ class LegacyVideoAnalyzer:
                 return await self._extract_frames_fallback(video_path)
 
             logger.info(f"✅ Successfully extracted {len(frames)} frames (thread-pool mode)")
-            return frames
+            return frames  # type: ignore
 
         except Exception as e:
             logger.error(f"Thread-pool frame extraction failed: {e}")
@@ -315,7 +316,7 @@ class LegacyVideoAnalyzer:
             # Resize the image
             if max(pil_image.size) > self.max_image_size:
                 ratio = self.max_image_size / max(pil_image.size)
-                new_size = tuple(int(dim * ratio) for dim in pil_image.size)
+                new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
                 pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
 
             # Convert to base64
@@ -463,11 +464,11 @@ class LegacyVideoAnalyzer:
             # logger.info(f"✅ Multi-frame message built, containing {len(frames)} images")
 
             # Get the model info and client
-            model_info, api_provider, client = self.video_llm._select_model()
+            model_info, api_provider, client = self.video_llm._select_model()  # type: ignore
             # logger.info(f"Using model: {model_info.name} for multi-frame analysis")
 
             # Execute the multi-image request directly
-            api_response = await self.video_llm._execute_request(
+            api_response = await self.video_llm._execute_request(  # type: ignore
                 api_provider=api_provider,
                 client=client,
                 request_type=RequestType.RESPONSE,