Merge branch 'master' into migrate-windpicker-changes

Windpicker-owo committed 2025-08-25 17:47:50 +08:00
9 changed files with 512 additions and 103 deletions

View File

@@ -20,8 +20,6 @@ from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
# Import the anti-prompt-injection system
from src.chat.antipromptinjector import initialize_anti_injector
# Logging configuration
# Resolve the project root (this file is assumed to live under src/chat/message_receive/, so the root is three levels up)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -242,6 +240,20 @@ class ChatBot:
        - Performance timing
        """
        try:
            # First handle possible reassembly of chunked messages
            from src.utils.message_chunker import reassembler

            # Try to reassemble a chunked message
            reassembled_message = await reassembler.process_chunk(message_data)
            if reassembled_message is None:
                # This is a chunk but the message is not complete yet; wait for more chunks
                logger.debug("Waiting for more chunks, skipping this round of processing")
                return
            elif reassembled_message != message_data:
                # The message was reassembled; process the complete version
                logger.info("Processing the reassembled complete message")
                message_data = reassembled_message

            # Make sure all tasks have been started
            await self._ensure_started()
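Note: the handler above relies on a specific contract from process_chunk: it returns None while chunks are still outstanding, the merged message once every chunk has arrived, and the input unchanged for non-chunked messages. A minimal sketch of a reassembler honoring that contract; the chunk_info, chunk_id, chunk_index, and chunk_total field names are hypothetical, not taken from src.utils.message_chunker:

from typing import Optional, Dict

class ChunkReassembler:
    """Minimal sketch: buffers message chunks and merges them once complete."""

    def __init__(self) -> None:
        self._buffers: Dict[str, Dict[int, dict]] = {}

    async def process_chunk(self, message_data: dict) -> Optional[dict]:
        meta = message_data.get("chunk_info")  # hypothetical metadata field
        if meta is None:
            return message_data  # not a chunk: pass the message through unchanged
        chunks = self._buffers.setdefault(meta["chunk_id"], {})
        chunks[meta["chunk_index"]] = message_data
        if len(chunks) < meta["chunk_total"]:
            return None  # still waiting for the remaining chunks
        # All chunks have arrived: merge the payloads in index order
        merged = dict(chunks[min(chunks)])
        merged["content"] = "".join(chunks[i]["content"] for i in sorted(chunks))
        merged.pop("chunk_info", None)
        del self._buffers[meta["chunk_id"]]
        return merged

reassembler = ChunkReassembler()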

View File

@@ -12,10 +12,14 @@ import asyncio
import base64
import hashlib
import time
from PIL import Image
from pathlib import Path
from typing import List, Tuple, Optional, Dict
import io
from concurrent.futures import ThreadPoolExecutor
from functools import partial
import numpy as np
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@@ -33,6 +37,110 @@ video_events = {}
video_lock_manager = asyncio.Lock()
def _extract_frames_worker(video_path: str,
                           max_frames: int,
                           frame_quality: int,
                           max_image_size: int,
                           frame_extraction_mode: str,
                           frame_interval_seconds: Optional[float]) -> List[Tuple[str, float]]:
    """Worker function that extracts video frames inside a thread pool"""
    frames = []
    try:
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps if fps > 0 else 0

        if frame_extraction_mode == "time_interval":
            # New mode: extract frames at a fixed time interval
            time_interval = frame_interval_seconds
            next_frame_time = 0.0
            extracted_count = 0  # counter for extracted frames
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                current_time = cap.get(cv2.CAP_PROP_POS_MSEC) / 1000.0
                if current_time >= next_frame_time:
                    # Convert to a PIL image and compress
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    pil_image = Image.fromarray(frame_rgb)
                    # Resize the image if needed
                    if max(pil_image.size) > max_image_size:
                        ratio = max_image_size / max(pil_image.size)
                        new_size = tuple(int(dim * ratio) for dim in pil_image.size)
                        pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
                    # Encode as base64
                    buffer = io.BytesIO()
                    pil_image.save(buffer, format='JPEG', quality=frame_quality)
                    frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                    frames.append((frame_base64, current_time))
                    extracted_count += 1
                    # Note: logger cannot be used here because this runs in a worker thread
                    # logger.debug(f"Extracted frame {extracted_count} (time: {current_time:.2f}s)")
                    next_frame_time += time_interval
        else:
            # Use numpy to compute the frame interval
            if duration > 0:
                frame_interval = max(1, int(duration / max_frames * fps))
            else:
                frame_interval = 30  # default interval
            # Compute target frame positions with numpy
            target_frames = np.arange(0, min(max_frames, total_frames // frame_interval + 1)) * frame_interval
            target_frames = target_frames[target_frames < total_frames].astype(int)
            for target_frame in target_frames:
                # Seek to the target frame
                cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
                ret, frame = cap.read()
                if not ret:
                    continue
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Convert to a PIL image, computing the new size with numpy
                height, width = frame_rgb.shape[:2]
                max_dim = max(height, width)
                if max_dim > max_image_size:
                    # Compute the scaling ratio
                    ratio = max_image_size / max_dim
                    new_width = int(width * ratio)
                    new_height = int(height * ratio)
                    # Resize efficiently with OpenCV
                    frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
                    pil_image = Image.fromarray(frame_resized)
                else:
                    pil_image = Image.fromarray(frame_rgb)
                # Encode as base64
                buffer = io.BytesIO()
                pil_image.save(buffer, format='JPEG', quality=frame_quality)
                frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
                # Compute the timestamp
                timestamp = target_frame / fps if fps > 0 else 0
                frames.append((frame_base64, timestamp))
        cap.release()
        return frames
    except Exception as e:
        # Return the error message via a sentinel tuple
        return [("ERROR", str(e))]
class VideoAnalyzer:
    """Optimized video analyzer"""
@@ -54,49 +162,14 @@ class VideoAnalyzer:
logger.warning(f"video_analysis配置不可用({e})回退使用vlm配置")
# 从配置文件读取参数,如果配置不存在则使用默认值
try:
config = global_config.video_analysis
self.max_frames = config.max_frames
self.frame_quality = config.frame_quality
self.max_image_size = config.max_image_size
self.enable_frame_timing = config.enable_frame_timing
self.batch_analysis_prompt = config.batch_analysis_prompt
self.frame_extraction_mode = config.frame_extraction_mode
self.frame_interval_seconds = config.frame_interval_seconds
# 将配置文件中的模式映射到内部使用的模式名称
config_mode = config.analysis_mode
if config_mode == "batch_frames":
self.analysis_mode = "batch"
elif config_mode == "frame_by_frame":
self.analysis_mode = "sequential"
elif config_mode == "auto":
self.analysis_mode = "auto"
else:
logger.warning(f"无效的分析模式: {config_mode}使用默认的auto模式")
self.analysis_mode = "auto"
self.frame_analysis_delay = 0.3 # API调用间隔
self.frame_interval = 1.0 # 抽帧时间间隔(秒)
self.batch_size = 3 # 批处理时每批处理的帧数
self.timeout = 60.0 # 分析超时时间(秒)
logger.info("✅ 从配置文件读取视频分析参数")
except AttributeError as e:
# 如果配置不存在,使用代码中的默认值
logger.warning(f"配置文件中缺少video_analysis配置({e}),使用默认值")
self.max_frames = 6
self.frame_quality = 85
self.max_image_size = 600
self.analysis_mode = "auto"
self.frame_analysis_delay = 0.3
self.frame_interval = 1.0 # 抽帧时间间隔(秒)
self.batch_size = 3 # 批处理时每批处理的帧数
self.timeout = 60.0 # 分析超时时间(秒)
self.enable_frame_timing = True
self.frame_extraction_mode = "fixed_number"
self.frame_interval_seconds = 2.0
self.batch_analysis_prompt = """请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。
config = global_config.video_analysis
# 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值
self.max_frames = getattr(config, 'max_frames', 6)
self.frame_quality = getattr(config, 'frame_quality', 85)
self.max_image_size = getattr(config, 'max_image_size', 600)
self.enable_frame_timing = getattr(config, 'enable_frame_timing', True)
self.batch_analysis_prompt = getattr(config, 'batch_analysis_prompt', """请分析这个视频的内容。这些图片是从视频中按时间顺序提取的关键帧。
请提供详细的分析,包括:
1. 视频的整体内容和主题
@@ -106,12 +179,40 @@ class VideoAnalyzer:
5. 整体氛围和情感表达
6. 任何特殊的视觉效果或文字内容
请用中文回答,分析要详细准确。"""
请用中文回答,分析要详细准确。""")
# 新增的线程池配置
self.use_multiprocessing = getattr(config, 'use_multiprocessing', True)
self.max_workers = getattr(config, 'max_workers', 2)
self.frame_extraction_mode = getattr(config, 'frame_extraction_mode', 'fixed_number')
self.frame_interval_seconds = getattr(config, 'frame_interval_seconds', 2.0)
# 将配置文件中的模式映射到内部使用的模式名称
config_mode = getattr(config, 'analysis_mode', 'auto')
if config_mode == "batch_frames":
self.analysis_mode = "batch"
elif config_mode == "frame_by_frame":
self.analysis_mode = "sequential"
elif config_mode == "auto":
self.analysis_mode = "auto"
else:
logger.warning(f"无效的分析模式: {config_mode}使用默认的auto模式")
self.analysis_mode = "auto"
self.frame_analysis_delay = 0.3 # API调用间隔
self.frame_interval = 1.0 # 抽帧时间间隔(秒)
self.batch_size = 3 # 批处理时每批处理的帧数
self.timeout = 60.0 # 分析超时时间(秒)
if config:
logger.info("✅ 从配置文件读取视频分析参数")
else:
logger.warning("配置文件中缺少video_analysis配置使用默认值")
# 系统提示词
self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。"
logger.info(f"✅ 视频分析器初始化完成,分析模式: {self.analysis_mode}")
logger.info(f"✅ 视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}")
    def _calculate_video_hash(self, video_data: bytes) -> str:
        """Compute a hash of the raw video bytes"""
@@ -186,8 +287,70 @@ class VideoAnalyzer:
logger.warning(f"无效的分析模式: {mode}")
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
"""提取视频帧"""
"""提取视频帧 - 支持多进程和单线程模式"""
# 先获取视频信息
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
duration = total_frames / fps if fps > 0 else 0
cap.release()
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}")
# 估算提取帧数
if duration > 0:
frame_interval = max(1, int(duration / self.max_frames * fps))
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
else:
estimated_frames = self.max_frames
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
# 根据配置选择处理方式
if self.use_multiprocessing:
return await self._extract_frames_multiprocess(video_path)
else:
return await self._extract_frames_fallback(video_path)
    async def _extract_frames_multiprocess(self, video_path: str) -> List[Tuple[str, float]]:
        """Thread-pool variant of frame extraction"""
        loop = asyncio.get_event_loop()
        try:
            logger.info("🔄 Starting thread-pool frame extraction...")
            # Use a thread pool to avoid cross-process import issues
            with ThreadPoolExecutor(max_workers=1) as executor:
                frames = await loop.run_in_executor(
                    executor,
                    _extract_frames_worker,
                    video_path,
                    self.max_frames,
                    self.frame_quality,
                    self.max_image_size,
                    self.frame_extraction_mode,
                    self.frame_interval_seconds
                )

            # Check for the error sentinel
            if frames and frames[0][0] == "ERROR":
                logger.error(f"Thread-pool frame extraction failed: {frames[0][1]}")
                # Fall back to single-threaded mode
                logger.info("🔄 Falling back to single-threaded mode...")
                return await self._extract_frames_fallback(video_path)

            logger.info(f"✅ Extracted {len(frames)} frames (thread-pool mode)")
            return frames
        except Exception as e:
            logger.error(f"Thread-pool frame extraction failed: {e}")
            # Fall back to the original method
            logger.info("🔄 Falling back to single-threaded mode...")
            return await self._extract_frames_fallback(video_path)
    async def _extract_frames_fallback(self, video_path: str) -> List[Tuple[str, float]]:
        """Fallback frame extraction - the original async implementation"""
        frames = []
        extracted_count = 0
        cap = cv2.VideoCapture(video_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -195,9 +358,7 @@ class VideoAnalyzer:
logger.info(f"视频信息: {total_frames}帧, {fps:.2f}FPS, {duration:.2f}")
frame_count = 0
extracted_count = 0
if self.frame_extraction_mode == "time_interval":
# 新模式:按时间间隔抽帧
time_interval = self.frame_interval_seconds
@@ -233,42 +394,61 @@ class VideoAnalyzer:
                    next_frame_time += time_interval
        else:
            # Legacy mode: extract a fixed total number of frames
            # Use numpy to compute the frame interval
            if duration > 0:
                frame_interval = max(1, int(duration / self.max_frames * fps))
            else:
                frame_interval = 30  # default interval

            logger.info(f"Computed frame interval: {frame_interval} (about {min(self.max_frames, total_frames // frame_interval + 1)} frames will be extracted)")

            # Compute target frame positions with numpy
            target_frames = np.arange(0, min(self.max_frames, total_frames // frame_interval + 1)) * frame_interval
            target_frames = target_frames[target_frames < total_frames].astype(int)

            extracted_count = 0
            for target_frame in target_frames:
                # Seek to the target frame
                cap.set(cv2.CAP_PROP_POS_FRAMES, target_frame)
                ret, frame = cap.read()
                if not ret:
                    continue

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Convert to a PIL image, computing the new size with numpy
                height, width = frame_rgb.shape[:2]
                max_dim = max(height, width)
                if max_dim > self.max_image_size:
                    # Compute the scaling ratio
                    ratio = self.max_image_size / max_dim
                    new_width = int(width * ratio)
                    new_height = int(height * ratio)
                    # Resize efficiently with OpenCV
                    frame_resized = cv2.resize(frame_rgb, (new_width, new_height), interpolation=cv2.INTER_LANCZOS4)
                    pil_image = Image.fromarray(frame_resized)
                else:
                    pil_image = Image.fromarray(frame_rgb)

                # Encode as base64
                buffer = io.BytesIO()
                pil_image.save(buffer, format='JPEG', quality=self.frame_quality)
                frame_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')

                # Compute the timestamp
                timestamp = target_frame / fps if fps > 0 else 0
                frames.append((frame_base64, timestamp))
                extracted_count += 1
                logger.debug(f"Extracted frame {extracted_count} (time: {timestamp:.2f}s, frame index: {target_frame})")

                # Yield to the event loop after each extracted frame
                await asyncio.sleep(0.001)

        cap.release()
        logger.info(f"✅ Extracted {len(frames)} frames")
        return frames