diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 2e886e960..cf66c1754 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -37,7 +37,12 @@ video_events = {} video_lock_manager = asyncio.Lock() -def _extract_frames_worker(video_path: str, max_frames: int, frame_quality: int, max_image_size: int) -> List[Tuple[str, float]]: +def _extract_frames_worker(video_path: str, + max_frames: int, + frame_quality: int, + max_image_size: int, + frame_extraction_mode: str, + frame_interval_seconds: Optional[float]) -> List[Tuple[str, float]]: """线程池中提取视频帧的工作函数""" frames = [] try: @@ -46,50 +51,85 @@ def _extract_frames_worker(video_path: str, max_frames: int, frame_quality: int, total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) duration = total_frames / fps if fps > 0 else 0 - # 使用numpy优化帧间隔计算 - if duration > 0: - frame_interval = max(1, int(duration / max_frames * fps)) + if frame_extraction_mode == "time_interval": + # 新模式:按时间间隔抽帧 + time_interval = frame_interval_seconds + next_frame_time = 0.0 + + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + + current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0 + + if current_time >= next_frame_time: + # 转换为PIL图像并压缩 + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + pil_image = Image.fromarray(frame_rgb) + + # 调整图像大小 + if max(pil_image.size) > self.max_image_size: + ratio = self.max_image_size / max(pil_image.size) + new_size = tuple(int(dim * ratio) for dim in pil_image.size) + pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) + + # 转换为base64 + buffer = io.BytesIO() + pil_image.save(buffer, format='JPEG', quality=self.frame_quality) + frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') + + frames.append((frame_base64, current_time)) + extracted_count += 1 + + logger.debug(f"提取第{extracted_count}帧 (时间: {current_time:.2f}s)") + + next_frame_time += time_interval else: - frame_interval = 30 # 默认间隔 - - # 使用numpy计算目标帧位置 - target_frames = np.arange(0, min(max_frames, total_frames // frame_interval + 1)) * frame_interval - target_frames = target_frames[target_frames < total_frames].astype(int) - - for target_frame in target_frames: - # 跳转到目标帧 - cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame) - ret, frame = cap.read() - if not ret: - continue - - # 使用numpy优化图像处理 - frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - - # 转换为PIL图像并使用numpy进行尺寸计算 - height, width = frame_rgb.shape[:2] - max_dim = max(height, width) - - if max_dim > max_image_size: - # 使用numpy计算缩放比例 - ratio = max_image_size / max_dim - new_width = int(width * ratio) - new_height = int(height * ratio) - - # 使用opencv进行高效缩放 - frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4) - pil_image = Image.fromarray(frame_resized) + # 使用numpy优化帧间隔计算 + if duration > 0: + frame_interval = max(1, int(duration / max_frames * fps)) else: - pil_image = Image.fromarray(frame_rgb) + frame_interval = 30 # 默认间隔 + + # 使用numpy计算目标帧位置 + target_frames = np.arange(0, min(max_frames, total_frames // frame_interval + 1)) * frame_interval + target_frames = target_frames[target_frames < total_frames].astype(int) - # 转换为base64 - buffer = io.BytesIO() - pil_image.save(buffer, format='JPEG', quality=frame_quality) - frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') - - # 计算时间戳 - timestamp = target_frame / fps if fps > 0 else 0 - frames.append((frame_base64, timestamp)) + for target_frame in target_frames: + # 跳转到目标帧 + cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame) + ret, frame = cap.read() + if not ret: + continue + + # 使用numpy优化图像处理 + frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + + # 转换为PIL图像并使用numpy进行尺寸计算 + height, width = frame_rgb.shape[:2] + max_dim = max(height, width) + + if max_dim > max_image_size: + # 使用numpy计算缩放比例 + ratio = max_image_size / max_dim + new_width = int(width * ratio) + new_height = int(height * ratio) + + # 使用opencv进行高效缩放 + frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4) + pil_image = Image.fromarray(frame_resized) + else: + pil_image = Image.fromarray(frame_rgb) + + # 转换为base64 + buffer = io.BytesIO() + pil_image.save(buffer, format='JPEG', quality=frame_quality) + frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8') + + # 计算时间戳 + timestamp = target_frame / fps if fps > 0 else 0 + frames.append((frame_base64, timestamp)) cap.release() return frames @@ -120,50 +160,14 @@ class VideoAnalyzer: logger.warning(f"video_analysis配置不可用({e}),回退使用vlm配置") # 从配置文件读取参数,如果配置不存在则使用默认值 - try: - config = global_config.video_analysis - self.max_frames = config.max_frames - self.frame_quality = config.frame_quality - self.max_image_size = config.max_image_size - self.enable_frame_timing = config.enable_frame_timing - self.batch_analysis_prompt = config.batch_analysis_prompt - # 新增的线程池配置 - self.use_multiprocessing = getattr(config, 'use_multiprocessing', True) - self.max_workers = getattr(config, 'max_workers', 2) - - # 将配置文件中的模式映射到内部使用的模式名称 - config_mode = config.analysis_mode - if config_mode == "batch_frames": - self.analysis_mode = "batch" - elif config_mode == "frame_by_frame": - self.analysis_mode = "sequential" - elif config_mode == "auto": - self.analysis_mode = "auto" - else: - logger.warning(f"无效的分析模式: {config_mode},使用默认的auto模式") - self.analysis_mode = "auto" - - self.frame_analysis_delay = 0.3 # API调用间隔(秒) - self.frame_interval = 1.0 # 抽帧时间间隔(秒) - self.batch_size = 3 # 批处理时每批处理的帧数 - self.timeout = 60.0 # 分析超时时间(秒) - logger.info("✅ 从配置文件读取视频分析参数") - - except AttributeError as e: - # 如果配置不存在,使用代码中的默认值 - logger.warning(f"配置文件中缺少video_analysis配置({e}),使用默认值") - self.max_frames = 6 - self.frame_quality = 85 - self.max_image_size = 600 - self.analysis_mode = "auto" - self.frame_analysis_delay = 0.3 - self.frame_interval = 1.0 # 抽帧时间间隔(秒) - self.batch_size = 3 # 批处理时每批处理的帧数 - self.timeout = 60.0 # 分析超时时间(秒) - self.enable_frame_timing = True - self.use_multiprocessing = True # 默认启用线程池 - self.max_workers = 2 # 默认最大2个线程 - self.batch_analysis_prompt = """请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。 + config = global_config.video_analysis + + # 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值 + self.max_frames = getattr(config, 'max_frames', 6) + self.frame_quality = getattr(config, 'frame_quality', 85) + self.max_image_size = getattr(config, 'max_image_size', 600) + self.enable_frame_timing = getattr(config, 'enable_frame_timing', True) + self.batch_analysis_prompt = getattr(config, 'batch_analysis_prompt', """请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。 请提供详细的分析,包括: 1. 视频的整体内容和主题 @@ -173,7 +177,35 @@ class VideoAnalyzer: 5. 整体氛围和情感表达 6. 任何特殊的视觉效果或文字内容 -请用中文回答,分析要详细准确。""" +请用中文回答,分析要详细准确。""") + + # 新增的线程池配置 + self.use_multiprocessing = getattr(config, 'use_multiprocessing', True) + self.max_workers = getattr(config, 'max_workers', 2) + self.frame_extraction_mode = getattr(config, 'frame_extraction_mode', 'fixed_number') + self.frame_interval_seconds = getattr(config, 'frame_interval_seconds', 2.0) + + # 将配置文件中的模式映射到内部使用的模式名称 + config_mode = getattr(config, 'analysis_mode', 'auto') + if config_mode == "batch_frames": + self.analysis_mode = "batch" + elif config_mode == "frame_by_frame": + self.analysis_mode = "sequential" + elif config_mode == "auto": + self.analysis_mode = "auto" + else: + logger.warning(f"无效的分析模式: {config_mode},使用默认的auto模式") + self.analysis_mode = "auto" + + self.frame_analysis_delay = 0.3 # API调用间隔(秒) + self.frame_interval = 1.0 # 抽帧时间间隔(秒) + self.batch_size = 3 # 批处理时每批处理的帧数 + self.timeout = 60.0 # 分析超时时间(秒) + + if config: + logger.info("✅ 从配置文件读取视频分析参数") + else: + logger.warning("配置文件中缺少video_analysis配置,使用默认值") # 系统提示词 self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。" @@ -292,7 +324,9 @@ class VideoAnalyzer: video_path, self.max_frames, self.frame_quality, - self.max_image_size + self.max_image_size, + self.frame_extraction_mode, + self.frame_interval_seconds ) # 检查是否有错误 @@ -314,6 +348,7 @@ class VideoAnalyzer: async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]: """帧提取的降级方法 - 原始异步版本""" frames = [] + extracted_count = 0 cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))