Refactor replyer retrieval to async and add memory formatter
Changed get_replyer and related calls to async across multiple modules for proper coroutine handling. Added format_memories_bracket_style utility for memory formatting. Improved video analysis caching logic and type annotations. Updated error logging for message processing.
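The first point — replyer retrieval becoming async — mainly changes call sites: a plain call now returns a coroutine object, so callers must await it from async code. A minimal, self-contained sketch of that kind of migration; the stubbed get_replyer signature and return value are illustrative only and not taken from this commit:

import asyncio
from typing import Optional

# Stub standing in for the project's now-async get_replyer (real signature not shown in this diff).
async def get_replyer(chat_id: str) -> Optional[object]:
    await asyncio.sleep(0)  # e.g. an awaited registry or DB lookup
    return object()

async def build_reply(chat_id: str) -> str:
    # Before the refactor a call site could read `replyer = get_replyer(chat_id)`;
    # now the call must be awaited, otherwise it yields an un-run coroutine.
    replyer = await get_replyer(chat_id)
    return "" if replyer is None else f"replyer ready for {chat_id}"

# asyncio.run(build_reply("chat-1"))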
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
 # -*- coding: utf-8 -*-
 """纯 inkfox 视频关键帧分析工具
+
 仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
@@ -13,27 +14,25 @@
 from __future__ import annotations

-import os
-import io
 import asyncio
 import base64
+import hashlib
+import io
+import os
 import tempfile
+import time
 from pathlib import Path
-from typing import Any
-import hashlib
-import time
+from typing import List, Tuple, Optional, Dict, Any

 from PIL import Image
-from sqlalchemy import exc as sa_exc  # type: ignore
-from sqlalchemy import insert, select, update  # type: ignore

+from src.common.database.sqlalchemy_models import Videos, get_db_session  # type: ignore
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
-from src.common.database.sqlalchemy_models import Videos, get_db_session  # type: ignore

 # 简易并发控制:同一 hash 只处理一次
-_video_locks: dict[str, asyncio.Lock] = {}
+_video_locks: Dict[str, asyncio.Lock] = {}
 _locks_guard = asyncio.Lock()

 logger = get_logger("utils_video")
@@ -91,7 +90,7 @@ class VideoAnalyzer:
             logger.debug(f"获取系统信息失败: {e}")

     # ---- 关键帧提取 ----
-    async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
+    async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
         """提取关键帧并返回 (base64, timestamp_seconds) 列表"""
         with tempfile.TemporaryDirectory() as tmp:
             result = video.extract_keyframes_from_video(  # type: ignore[attr-defined]
@@ -106,7 +105,7 @@ class VideoAnalyzer:
             )
             files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
             total_ms = getattr(result, "total_time_ms", 0)
-            frames: list[tuple[str, float]] = []
+            frames: List[Tuple[str, float]] = []
             for i, f in enumerate(files):
                 img = Image.open(f).convert("RGB")
                 if max(img.size) > self.max_image_size:
@@ -120,41 +119,38 @@
             return frames

     # ---- 批量分析 ----
-    async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
-        from src.llm_models.payload_content.message import MessageBuilder
+    async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
+        from src.llm_models.payload_content.message import MessageBuilder, RoleType
         from src.llm_models.utils_model import RequestType

         prompt = self.batch_analysis_prompt.format(
             personality_core=self.personality_core, personality_side=self.personality_side
         )
         if question:
             prompt += f"\n用户关注: {question}"

         desc = [
             (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
             for i, (_b, ts) in enumerate(frames)
         ]
         prompt += "\n帧列表: " + ", ".join(desc)

-        message_builder = MessageBuilder().add_text_content(prompt)
+        mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
         for b64, _ in frames:
-            message_builder.add_image_content(image_format="jpeg", image_base64=b64)
-        messages = [message_builder.build()]
-
-        # 使用封装好的高级策略执行请求,而不是直接调用内部方法
-        response, _ = await self.video_llm._strategy.execute_with_failover(
-            RequestType.RESPONSE,
-            raise_when_empty=False,  # 即使失败也返回默认值,避免程序崩溃
-            message_list=messages,
-            temperature=self.video_llm.model_for_task.temperature,
-            max_tokens=self.video_llm.model_for_task.max_tokens,
+            mb.add_image_content("jpeg", b64)
+        message = mb.build()
+        model_info, api_provider, client = self.video_llm._select_model()
+        resp = await self.video_llm._execute_request(
+            api_provider=api_provider,
+            client=client,
+            request_type=RequestType.RESPONSE,
+            model_info=model_info,
+            message_list=[message],
+            temperature=None,
+            max_tokens=None,
         )

-        return response.content or "❌ 未获得响应"
+        return resp.content or "❌ 未获得响应"

     # ---- 逐帧分析 ----
-    async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
-        results: list[str] = []
+    async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
+        results: List[str] = []
         for i, (b64, ts) in enumerate(frames):
             prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
             if question:
@@ -178,7 +174,7 @@ class VideoAnalyzer:
         return "\n".join(results)

     # ---- 主入口 ----
-    async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
+    async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
         if not os.path.exists(video_path):
             return False, "❌ 文件不存在"
         frames = await self.extract_keyframes(video_path)
@@ -193,10 +189,10 @@ class VideoAnalyzer:
     async def analyze_video_from_bytes(
         self,
         video_bytes: bytes,
-        filename: str | None = None,
-        prompt: str | None = None,
-        question: str | None = None,
-    ) -> dict[str, str]:
+        filename: Optional[str] = None,
+        prompt: Optional[str] = None,
+        question: Optional[str] = None,
+    ) -> Dict[str, str]:
         """从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
         if not video_bytes:
             return {"summary": "❌ 空视频数据"}
@@ -204,11 +200,17 @@
         q = prompt if prompt is not None else question
         video_hash = hashlib.sha256(video_bytes).hexdigest()

-        # 查缓存(第一次,未加锁)
-        cached = await self._get_cached(video_hash)
-        if cached:
-            logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}")
-            return {"summary": cached}
+        # 查缓存
+        try:
+            async with get_db_session() as session:  # type: ignore
+                row = await session.execute(
+                    Videos.__table__.select().where(Videos.video_hash == video_hash)  # type: ignore
+                )
+                existing = row.first()
+                if existing and existing[Videos.description] and existing[Videos.vlm_processed]:  # type: ignore
+                    return {"summary": existing[Videos.description]}  # type: ignore
+        except Exception:  # pragma: no cover
+            pass

         # 获取锁避免重复处理
         async with _locks_guard:
@@ -217,11 +219,17 @@
                 lock = asyncio.Lock()
                 _video_locks[video_hash] = lock
         async with lock:
-            # 双检缓存
-            cached2 = await self._get_cached(video_hash)
-            if cached2:
-                logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}")
-                return {"summary": cached2}
+            # 双检:进入锁后再查一次,避免重复处理
+            try:
+                async with get_db_session() as session:  # type: ignore
+                    row = await session.execute(
+                        Videos.__table__.select().where(Videos.video_hash == video_hash)  # type: ignore
+                    )
+                    existing = row.first()
+                    if existing and existing[Videos.description] and existing[Videos.vlm_processed]:  # type: ignore
+                        return {"summary": existing[Videos.description]}  # type: ignore
+            except Exception:  # pragma: no cover
+                pass

             try:
                 with tempfile.NamedTemporaryFile(delete=False) as fp:
@@ -231,7 +239,26 @@
                     ok, summary = await self.analyze_video(temp_path, q)
                     # 写入缓存(仅成功)
                     if ok:
-                        await self._save_cache(video_hash, summary, len(video_bytes))
+                        try:
+                            async with get_db_session() as session:  # type: ignore
+                                await session.execute(
+                                    Videos.__table__.insert().values(
+                                        video_id="",
+                                        video_hash=video_hash,
+                                        description=summary,
+                                        count=1,
+                                        timestamp=time.time(),
+                                        vlm_processed=True,
+                                        duration=None,
+                                        frame_count=None,
+                                        fps=None,
+                                        resolution=None,
+                                        file_size=len(video_bytes),
+                                    )
+                                )
+                                await session.commit()
+                        except Exception:  # pragma: no cover
+                            pass
                     return {"summary": summary}
                 finally:
                     if os.path.exists(temp_path):
@@ -242,57 +269,9 @@ class VideoAnalyzer:
             except Exception as e:  # pragma: no cover
                 return {"summary": f"❌ 处理失败: {e}"}
-
-    # ---- 缓存辅助 ----
-    async def _get_cached(self, video_hash: str) -> str | None:
-        try:
-            async with get_db_session() as session:  # type: ignore
-                result = await session.execute(select(Videos).where(Videos.video_hash == video_hash))  # type: ignore
-                obj: Videos | None = result.scalar_one_or_none()  # type: ignore
-                if obj and obj.vlm_processed and obj.description:
-                    # 更新使用次数
-                    try:
-                        await session.execute(
-                            update(Videos)
-                            .where(Videos.id == obj.id)  # type: ignore
-                            .values(count=obj.count + 1 if obj.count is not None else 1)
-                        )
-                        await session.commit()
-                    except Exception:  # pragma: no cover
-                        await session.rollback()
-                    return obj.description
-        except Exception:  # pragma: no cover
-            pass
-        return None
-
-    async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
-        try:
-            async with get_db_session() as session:  # type: ignore
-                stmt = insert(Videos).values(  # type: ignore
-                    video_id="",
-                    video_hash=video_hash,
-                    description=summary,
-                    count=1,
-                    timestamp=time.time(),
-                    vlm_processed=True,
-                    duration=None,
-                    frame_count=None,
-                    fps=None,
-                    resolution=None,
-                    file_size=file_size,
-                )
-                try:
-                    await session.execute(stmt)
-                    await session.commit()
-                    logger.debug(f"视频缓存写入 success hash={video_hash}")
-                except sa_exc.IntegrityError:  # 可能并发已写入
-                    await session.rollback()
-                    logger.debug(f"视频缓存已存在 hash={video_hash}")
-        except Exception:  # pragma: no cover
-            logger.debug("视频缓存写入失败")


 # ---- 外部接口 ----
-_INSTANCE: VideoAnalyzer | None = None
+_INSTANCE: Optional[VideoAnalyzer] = None


 def get_video_analyzer() -> VideoAnalyzer:
@@ -306,7 +285,7 @@ def is_video_analysis_available() -> bool:
     return True


-def get_video_analysis_status() -> dict[str, Any]:
+def get_video_analysis_status() -> Dict[str, Any]:
     try:
         info = video.get_system_info()  # type: ignore[attr-defined]
     except Exception as e:  # pragma: no cover
@@ -318,4 +297,4 @@ def get_video_analysis_status() -> dict[str, Any]:
         "modes": ["auto", "batch", "sequential"],
         "max_frames_default": inst.max_frames,
         "implementation": "inkfox",
-    }
+    }
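For reference, the caching flow this diff inlines into analyze_video_from_bytes is a double-checked pattern: query the Videos table without a lock, take a per-hash asyncio.Lock, query again under the lock, and only then run the expensive analysis and insert the result. A condensed, self-contained sketch of that control flow; the _fake_* helpers and _analyze stand in for the real table queries and the keyframe/VLM step:

import asyncio
import hashlib
from typing import Dict, Optional

# In-memory stand-in for the Videos table, just to keep the sketch runnable.
_fake_cache: Dict[str, str] = {}

async def _fetch_summary(video_hash: str) -> Optional[str]:
    return _fake_cache.get(video_hash)

async def _store_summary(video_hash: str, summary: str) -> None:
    _fake_cache[video_hash] = summary

async def _analyze(video_bytes: bytes) -> str:
    await asyncio.sleep(0.1)  # placeholder for keyframe extraction + VLM analysis
    return f"summary of {len(video_bytes)} bytes"

_video_locks: Dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()

async def summarize(video_bytes: bytes) -> str:
    video_hash = hashlib.sha256(video_bytes).hexdigest()
    if (hit := await _fetch_summary(video_hash)) is not None:      # first check, no lock held
        return hit
    async with _locks_guard:                                       # one Lock object per hash
        lock = _video_locks.setdefault(video_hash, asyncio.Lock())
    async with lock:
        if (hit := await _fetch_summary(video_hash)) is not None:  # second check, under the lock
            return hit                                             # another task finished first
        summary = await _analyze(video_bytes)
        await _store_summary(video_hash, summary)
        return summary

async def _demo() -> None:
    # concurrent callers with identical bytes trigger only one analysis
    print(await asyncio.gather(*[summarize(b"\x00" * 1024) for _ in range(5)]))

# asyncio.run(_demo())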