feat: 添加视频处理多线程优化和消息切片重组功能
- 新增视频帧提取的线程池支持,提升大视频文件处理性能 - 集成消息切片重组器,支持长消息的自动重组处理 - 优化视频帧提取算法,使用numpy进行数值计算优化 - 重构权限管理插件,修复属性访问和方法签名问题 - 清理未使用的导入和代码,提升代码质量 - 默认启用插件管理功能
This commit is contained in:
@@ -12,10 +12,14 @@ import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Optional, Dict
|
||||
import io
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
@@ -33,6 +37,68 @@ video_events = {}
|
||||
video_lock_manager = asyncio.Lock()
|
||||
|
||||
|
||||
def _extract_frames_worker(video_path: str, max_frames: int, frame_quality: int, max_image_size: int) -> List[Tuple[str, float]]:
|
||||
"""线程池中提取视频帧的工作函数"""
|
||||
frames = []
|
||||
try:
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
|
||||
# 使用numpy优化帧间隔计算
|
||||
if duration > 0:
|
||||
frame_interval = max(1, int(duration / max_frames * fps))
|
||||
else:
|
||||
frame_interval = 30 # 默认间隔
|
||||
|
||||
# 使用numpy计算目标帧位置
|
||||
target_frames = np.arange(0, min(max_frames, total_frames // frame_interval + 1)) * frame_interval
|
||||
target_frames = target_frames[target_frames < total_frames].astype(int)
|
||||
|
||||
for target_frame in target_frames:
|
||||
# 跳转到目标帧
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
continue
|
||||
|
||||
# 使用numpy优化图像处理
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 转换为PIL图像并使用numpy进行尺寸计算
|
||||
height, width = frame_rgb.shape[:2]
|
||||
max_dim = max(height, width)
|
||||
|
||||
if max_dim > max_image_size:
|
||||
# 使用numpy计算缩放比例
|
||||
ratio = max_image_size / max_dim
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
|
||||
# 使用opencv进行高效缩放
|
||||
frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
|
||||
pil_image = Image.fromarray(frame_resized)
|
||||
else:
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format='JPEG', quality=frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
# 计算时间戳
|
||||
timestamp = target_frame / fps if fps > 0 else 0
|
||||
frames.append((frame_base64, timestamp))
|
||||
|
||||
cap.release()
|
||||
return frames
|
||||
|
||||
except Exception as e:
|
||||
# 返回错误信息
|
||||
return [("ERROR", str(e))]
|
||||
|
||||
|
||||
class VideoAnalyzer:
|
||||
"""优化的视频分析器类"""
|
||||
|
||||
@@ -61,6 +127,9 @@ class VideoAnalyzer:
|
||||
self.max_image_size = config.max_image_size
|
||||
self.enable_frame_timing = config.enable_frame_timing
|
||||
self.batch_analysis_prompt = config.batch_analysis_prompt
|
||||
# 新增的线程池配置
|
||||
self.use_multiprocessing = getattr(config, 'use_multiprocessing', True)
|
||||
self.max_workers = getattr(config, 'max_workers', 2)
|
||||
|
||||
# 将配置文件中的模式映射到内部使用的模式名称
|
||||
config_mode = config.analysis_mode
|
||||
@@ -92,6 +161,8 @@ class VideoAnalyzer:
|
||||
self.batch_size = 3 # 批处理时每批处理的帧数
|
||||
self.timeout = 60.0 # 分析超时时间(秒)
|
||||
self.enable_frame_timing = True
|
||||
self.use_multiprocessing = True # 默认启用线程池
|
||||
self.max_workers = 2 # 默认最大2个线程
|
||||
self.batch_analysis_prompt = """请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。
|
||||
|
||||
请提供详细的分析,包括:
|
||||
@@ -107,7 +178,7 @@ class VideoAnalyzer:
|
||||
# 系统提示词
|
||||
self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。"
|
||||
|
||||
logger.info(f"✅ 视频分析器初始化完成,分析模式: {self.analysis_mode}")
|
||||
logger.info(f"✅ 视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}")
|
||||
|
||||
def _calculate_video_hash(self, video_data: bytes) -> str:
|
||||
"""计算视频文件的hash值"""
|
||||
@@ -182,7 +253,66 @@ class VideoAnalyzer:
|
||||
logger.warning(f"无效的分析模式: {mode}")
|
||||
|
||||
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
"""提取视频帧"""
|
||||
"""提取视频帧 - 支持多进程和单线程模式"""
|
||||
# 先获取视频信息
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
||||
duration = total_frames / fps if fps > 0 else 0
|
||||
cap.release()
|
||||
|
||||
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")
|
||||
|
||||
# 估算提取帧数
|
||||
if duration > 0:
|
||||
frame_interval = max(1, int(duration / self.max_frames * fps))
|
||||
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
|
||||
else:
|
||||
estimated_frames = self.max_frames
|
||||
|
||||
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
|
||||
|
||||
# 根据配置选择处理方式
|
||||
if self.use_multiprocessing:
|
||||
return await self._extract_frames_multiprocess(video_path)
|
||||
else:
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
"""线程池版本的帧提取"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
logger.info("🔄 启动线程池帧提取...")
|
||||
# 使用线程池,避免进程间的导入问题
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
frames = await loop.run_in_executor(
|
||||
executor,
|
||||
_extract_frames_worker,
|
||||
video_path,
|
||||
self.max_frames,
|
||||
self.frame_quality,
|
||||
self.max_image_size
|
||||
)
|
||||
|
||||
# 检查是否有错误
|
||||
if frames and frames[0][0] == "ERROR":
|
||||
logger.error(f"线程池帧提取失败: {frames[0][1]}")
|
||||
# 降级到单线程模式
|
||||
logger.info("🔄 降级到单线程模式...")
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
|
||||
return frames
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"线程池帧提取失败: {e}")
|
||||
# 降级到原始方法
|
||||
logger.info("🔄 降级到单线程模式...")
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
"""帧提取的降级方法 - 原始异步版本"""
|
||||
frames = []
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
@@ -191,45 +321,61 @@ class VideoAnalyzer:
|
||||
|
||||
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}秒")
|
||||
|
||||
# 动态计算帧间隔
|
||||
# 使用numpy优化帧间隔计算
|
||||
if duration > 0:
|
||||
frame_interval = max(1, int(duration / self.max_frames * fps))
|
||||
else:
|
||||
frame_interval = 30 # 默认间隔
|
||||
|
||||
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{min(self.max_frames, total_frames // frame_interval + 1)}帧)")
|
||||
|
||||
# 使用numpy计算目标帧位置
|
||||
target_frames = np.arange(0, min(self.max_frames, total_frames // frame_interval + 1)) * frame_interval
|
||||
target_frames = target_frames[target_frames < total_frames].astype(int)
|
||||
|
||||
frame_count = 0
|
||||
extracted_count = 0
|
||||
|
||||
while cap.isOpened() and extracted_count < self.max_frames:
|
||||
for target_frame in target_frames:
|
||||
# 跳转到目标帧
|
||||
cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
|
||||
ret, frame = cap.read()
|
||||
if not ret:
|
||||
break
|
||||
continue
|
||||
|
||||
if frame_count % frame_interval == 0:
|
||||
# 转换为PIL图像并压缩
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > self.max_image_size:
|
||||
ratio = self.max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format='JPEG', quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
# 计算时间戳
|
||||
timestamp = frame_count / fps if fps > 0 else 0
|
||||
frames.append((frame_base64, timestamp))
|
||||
extracted_count += 1
|
||||
|
||||
logger.debug(f"提取第{extracted_count}帧 (时间: {timestamp:.2f}s)")
|
||||
# 使用numpy优化图像处理
|
||||
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 转换为PIL图像并使用numpy进行尺寸计算
|
||||
height, width = frame_rgb.shape[:2]
|
||||
max_dim = max(height, width)
|
||||
|
||||
if max_dim > self.max_image_size:
|
||||
# 使用numpy计算缩放比例
|
||||
ratio = self.max_image_size / max_dim
|
||||
new_width = int(width * ratio)
|
||||
new_height = int(height * ratio)
|
||||
|
||||
# 使用opencv进行高效缩放
|
||||
frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
|
||||
pil_image = Image.fromarray(frame_resized)
|
||||
else:
|
||||
pil_image = Image.fromarray(frame_rgb)
|
||||
|
||||
# 转换为base64
|
||||
buffer = io.BytesIO()
|
||||
pil_image.save(buffer, format='JPEG', quality=self.frame_quality)
|
||||
frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
|
||||
# 计算时间戳
|
||||
timestamp = target_frame / fps if fps > 0 else 0
|
||||
frames.append((frame_base64, timestamp))
|
||||
extracted_count += 1
|
||||
|
||||
logger.debug(f"提取第{extracted_count}帧 (时间: {timestamp:.2f}s, 帧号: {target_frame})")
|
||||
|
||||
# 每提取一帧让步一次
|
||||
await asyncio.sleep(0.001)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
cap.release()
|
||||
logger.info(f"✅ 成功提取{len(frames)}帧")
|
||||
return frames
|
||||
|
||||
Reference in New Issue
Block a user