refactor(database): 将同步数据库操作迁移为异步操作
将整个项目的数据库操作从同步模式迁移为异步模式,主要涉及以下修改: - 将 `with get_db_session()` 改为 `async with get_db_session()` - 将同步的 SQLAlchemy 查询方法改为异步执行 - 更新相关的方法签名,添加 async/await 关键字 - 修复由于异步化导致的并发问题和性能问题 这些修改提高了数据库操作的并发性能,避免了阻塞主线程,提升了系统的整体响应能力。涉及修改的模块包括表情包管理、反提示注入统计、用户封禁管理、记忆系统、消息存储等多个核心组件。 BREAKING CHANGE: 所有涉及数据库操作的方法现在都需要使用异步调用,同步调用将不再工作
This commit is contained in:
@@ -829,7 +829,8 @@ async def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
description = "[图片内容未知]" # 默认描述
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
image = session.execute(select(Images).where(Images.image_id == pic_id)).scalar_one_or_none()
|
||||
result = session.execute(select(Images).where(Images.image_id == pic_id))
|
||||
image = result.scalar_one_or_none()
|
||||
if image and image.description: # type: ignore
|
||||
description = image.description
|
||||
except Exception:
|
||||
|
||||
@@ -308,7 +308,8 @@ class ImageManager:
|
||||
|
||||
async with get_db_session() as session:
|
||||
# 优先检查Images表中是否已有完整的描述
|
||||
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||
result.scalar()
|
||||
if existing_image:
|
||||
# 更新计数
|
||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
||||
@@ -527,7 +528,8 @@ class ImageManager:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
async with get_db_session() as session:
|
||||
existing_image = (await session.execute(select(Images).where(Images.emoji_hash == image_hash))).scalar()
|
||||
existing_image = result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||
result.scalar()
|
||||
if existing_image:
|
||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
||||
if (
|
||||
|
||||
@@ -27,19 +27,32 @@ import time
|
||||
from PIL import Image
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
|
||||
from sqlalchemy import select, update, insert # type: ignore
|
||||
from sqlalchemy import exc as sa_exc # type: ignore
|
||||
|
||||
# 简易并发控制:同一 hash 只处理一次
|
||||
_video_locks: Dict[str, asyncio.Lock] = {}
|
||||
_locks_guard = asyncio.Lock()
|
||||
from src.common.database.sqlalchemy_models import get_db_session, Videos
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = get_logger("utils_video")
|
||||
|
||||
from inkfox import video
|
||||
# Rust模块可用性检测
|
||||
RUST_VIDEO_AVAILABLE = False
|
||||
try:
|
||||
import rust_video # pyright: ignore[reportMissingImports]
|
||||
|
||||
RUST_VIDEO_AVAILABLE = True
|
||||
logger.info("✅ Rust 视频处理模块加载成功")
|
||||
except ImportError as e:
|
||||
logger.warning(f"⚠️ Rust 视频处理模块加载失败: {e}")
|
||||
logger.warning("⚠️ 视频识别功能将自动禁用")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 加载Rust模块时发生错误: {e}")
|
||||
RUST_VIDEO_AVAILABLE = False
|
||||
|
||||
# 全局正在处理的视频哈希集合,用于防止重复处理
|
||||
processing_videos = set()
|
||||
processing_lock = asyncio.Lock()
|
||||
# 为每个视频hash创建独立的锁和事件
|
||||
video_locks = {}
|
||||
video_events = {}
|
||||
video_lock_manager = asyncio.Lock()
|
||||
|
||||
|
||||
class VideoAnalyzer:
|
||||
@@ -192,7 +205,99 @@ class VideoAnalyzer:
|
||||
hash_obj.update(video_data)
|
||||
return hash_obj.hexdigest()
|
||||
|
||||
self._log_system()
|
||||
async def _check_video_exists(self, video_hash: str) -> Optional[Videos]:
|
||||
"""检查视频是否已经分析过"""
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 明确刷新会话以确保看到其他事务的最新提交
|
||||
await session.expire_all()
|
||||
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
except Exception as e:
|
||||
logger.warning(f"检查视频是否存在时出错: {e}")
|
||||
return None
|
||||
|
||||
async def _store_video_result(
|
||||
self, video_hash: str, description: str, metadata: Optional[Dict] = None
|
||||
) -> Optional[Videos]:
|
||||
"""存储视频分析结果到数据库"""
|
||||
# 检查描述是否为错误信息,如果是则不保存
|
||||
if description.startswith("❌"):
|
||||
logger.warning(f"⚠️ 检测到错误信息,不保存到数据库: {description[:50]}...")
|
||||
return None
|
||||
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 只根据video_hash查找
|
||||
stmt = select(Videos).where(Videos.video_hash == video_hash)
|
||||
result = await session.execute(stmt)
|
||||
existing_video = result.scalar_one_or_none()
|
||||
|
||||
if existing_video:
|
||||
# 如果已存在,更新描述和计数
|
||||
existing_video.description = description
|
||||
existing_video.count += 1
|
||||
existing_video.timestamp = time.time()
|
||||
if metadata:
|
||||
existing_video.duration = metadata.get("duration")
|
||||
existing_video.frame_count = metadata.get("frame_count")
|
||||
existing_video.fps = metadata.get("fps")
|
||||
existing_video.resolution = metadata.get("resolution")
|
||||
existing_video.file_size = metadata.get("file_size")
|
||||
await session.commit()
|
||||
await session.refresh(existing_video)
|
||||
logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}")
|
||||
return existing_video
|
||||
else:
|
||||
video_record = Videos(
|
||||
video_hash=video_hash, description=description, timestamp=time.time(), count=1
|
||||
)
|
||||
if metadata:
|
||||
video_record.duration = metadata.get("duration")
|
||||
video_record.frame_count = metadata.get("frame_count")
|
||||
video_record.fps = metadata.get("fps")
|
||||
video_record.resolution = metadata.get("resolution")
|
||||
video_record.file_size = metadata.get("file_size")
|
||||
|
||||
session.add(video_record)
|
||||
await session.commit()
|
||||
await session.refresh(video_record)
|
||||
logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...")
|
||||
return video_record
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 存储视频分析结果时出错: {e}")
|
||||
return None
|
||||
|
||||
def set_analysis_mode(self, mode: str):
|
||||
"""设置分析模式"""
|
||||
if mode in ["batch", "sequential", "auto"]:
|
||||
self.analysis_mode = mode
|
||||
# logger.info(f"分析模式已设置为: {mode}")
|
||||
else:
|
||||
logger.warning(f"无效的分析模式: {mode}")
|
||||
|
||||
async def extract_frames(self, video_path: str) -> List[Tuple[str, float]]:
|
||||
"""提取视频帧 - 智能选择最佳实现"""
|
||||
# 检查是否应该使用Rust实现
|
||||
if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe":
|
||||
# 优先尝试Rust关键帧提取
|
||||
try:
|
||||
return await self._extract_frames_rust_advanced(video_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Rust高级接口失败: {e},尝试基础接口")
|
||||
try:
|
||||
return await self._extract_frames_rust(video_path)
|
||||
except Exception as e2:
|
||||
logger.warning(f"Rust基础接口也失败: {e2},降级到Python实现")
|
||||
return await self._extract_frames_python_fallback(video_path)
|
||||
else:
|
||||
# 使用Python实现(支持time_interval和fixed_number模式)
|
||||
if not RUST_VIDEO_AVAILABLE:
|
||||
logger.info("🔄 Rust模块不可用,使用Python抽帧实现")
|
||||
else:
|
||||
logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现")
|
||||
return await self._extract_frames_python_fallback(video_path)
|
||||
|
||||
# ---- 系统信息 ----
|
||||
def _log_system(self) -> None:
|
||||
@@ -308,31 +413,82 @@ class VideoAnalyzer:
|
||||
prompt: Optional[str] = None,
|
||||
question: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
|
||||
if not video_bytes:
|
||||
return {"summary": "❌ 空视频数据"}
|
||||
# 兼容参数:prompt 优先,其次 question
|
||||
q = prompt if prompt is not None else question
|
||||
video_hash = hashlib.sha256(video_bytes).hexdigest()
|
||||
"""从字节数据分析视频
|
||||
|
||||
# 查缓存(第一次,未加锁)
|
||||
cached = await self._get_cached(video_hash)
|
||||
if cached:
|
||||
logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}")
|
||||
return {"summary": cached}
|
||||
Args:
|
||||
video_bytes: 视频字节数据
|
||||
filename: 文件名(可选,仅用于日志)
|
||||
user_question: 用户问题(旧参数名,保持兼容性)
|
||||
prompt: 提示词(新参数名,与系统调用保持一致)
|
||||
|
||||
# 获取锁避免重复处理
|
||||
async with _locks_guard:
|
||||
lock = _video_locks.get(video_hash)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
_video_locks[video_hash] = lock
|
||||
async with lock:
|
||||
# 双检缓存
|
||||
cached2 = await self._get_cached(video_hash)
|
||||
if cached2:
|
||||
logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}")
|
||||
return {"summary": cached2}
|
||||
Returns:
|
||||
Dict[str, str]: 包含分析结果的字典,格式为 {"summary": "分析结果"}
|
||||
"""
|
||||
if self.disabled:
|
||||
return {"summary": "❌ 视频分析功能已禁用:没有可用的视频处理实现"}
|
||||
|
||||
video_hash = None
|
||||
video_event = None
|
||||
|
||||
try:
|
||||
logger.info("开始从字节数据分析视频")
|
||||
|
||||
# 兼容性处理:如果传入了prompt参数,使用prompt;否则使用user_question
|
||||
question = prompt if prompt is not None else user_question
|
||||
|
||||
# 检查视频数据是否有效
|
||||
if not video_bytes:
|
||||
return {"summary": "❌ 视频数据为空"}
|
||||
|
||||
# 计算视频hash值
|
||||
video_hash = self._calculate_video_hash(video_bytes)
|
||||
logger.info(f"视频hash: {video_hash}")
|
||||
|
||||
# 改进的并发控制:使用每个视频独立的锁和事件
|
||||
async with video_lock_manager:
|
||||
if video_hash not in video_locks:
|
||||
video_locks[video_hash] = asyncio.Lock()
|
||||
video_events[video_hash] = asyncio.Event()
|
||||
|
||||
video_lock = video_locks[video_hash]
|
||||
video_event = video_events[video_hash]
|
||||
|
||||
# 尝试获取该视频的专用锁
|
||||
if video_lock.locked():
|
||||
logger.info(f"⏳ 相同视频正在处理中,等待处理完成... (hash: {video_hash[:16]}...)")
|
||||
try:
|
||||
# 等待处理完成的事件信号,最多等待60秒
|
||||
await asyncio.wait_for(video_event.wait(), timeout=60.0)
|
||||
logger.info("✅ 等待结束,检查是否有处理结果")
|
||||
|
||||
# 检查是否有结果了
|
||||
existing_video = await self._check_video_exists(video_hash)
|
||||
if existing_video:
|
||||
logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
|
||||
return {"summary": existing_video.description}
|
||||
else:
|
||||
logger.warning("⚠️ 等待完成但未找到结果,可能处理失败")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("⚠️ 等待超时(60秒),放弃等待")
|
||||
|
||||
# 获取锁开始处理
|
||||
async with video_lock:
|
||||
logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)")
|
||||
|
||||
# 再次检查数据库(可能在等待期间已经有结果了)
|
||||
existing_video = await self._check_video_exists(video_hash)
|
||||
if existing_video:
|
||||
logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})")
|
||||
video_event.set() # 通知其他等待者
|
||||
return {"summary": existing_video.description}
|
||||
|
||||
# 未找到已存在记录,开始新的分析
|
||||
logger.info("未找到已存在的视频记录,开始新的分析")
|
||||
|
||||
# 创建临时文件进行分析
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
|
||||
temp_file.write(video_bytes)
|
||||
temp_path = temp_file.name
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False) as fp:
|
||||
@@ -351,7 +507,7 @@ class VideoAnalyzer:
|
||||
# 保存分析结果到数据库(仅保存成功的结果)
|
||||
if success and not result.startswith("❌"):
|
||||
metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()}
|
||||
self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
|
||||
await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
|
||||
logger.info("✅ 分析结果已保存到数据库")
|
||||
else:
|
||||
logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试")
|
||||
|
||||
Reference in New Issue
Block a user