Refactor replyer retrieval to async and add memory formatter

Changed get_replyer and related calls to async across multiple modules for proper coroutine handling. Added format_memories_bracket_style utility for memory formatting. Improved video analysis caching logic and type annotations. Updated error logging for message processing.
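The get_replyer conversion itself happens in modules not included in the hunk shown below. As a purely hypothetical sketch (function names and signatures are assumed, not taken from this commit), the change means call sites can no longer call the lookup synchronously and must await it:

# Hypothetical sketch only: get_replyer and handle_message are assumed shapes,
# not code from this commit; it just illustrates the sync -> async conversion.
import asyncio

async def get_replyer(chat_id: str) -> dict:
    # The real implementation presumably awaits chat/DB lookups; faked here.
    await asyncio.sleep(0)
    return {"chat_id": chat_id, "name": "replyer"}

async def handle_message(chat_id: str) -> str:
    # Before the refactor this was a plain call; without the await the caller
    # would receive an un-awaited coroutine object instead of the replyer.
    replyer = await get_replyer(chat_id)
    return f"reply via {replyer['name']}"

if __name__ == "__main__":
    print(asyncio.run(handle_message("chat-1")))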
雅诺狐
2025-10-05 17:48:28 +08:00
parent f4404e09ef
commit 86377d983f
8 changed files with 215 additions and 112 deletions
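The new format_memories_bracket_style helper is likewise added in a file outside the hunk shown below; the following is only a guess at what a bracket-style memory formatter could look like, with the signature and field names invented for illustration:

# Hypothetical sketch: the real format_memories_bracket_style is not shown in
# this diff; "topic" and "content" are made-up field names.
from typing import Dict, List

def format_memories_bracket_style(memories: List[Dict[str, str]]) -> str:
    # Render one bracketed entry per line, e.g. "[topic] content".
    lines = []
    for memory in memories:
        topic = memory.get("topic", "memory")
        content = memory.get("content", "")
        lines.append(f"[{topic}] {content}")
    return "\n".join(lines)

if __name__ == "__main__":
    print(format_memories_bracket_style([
        {"topic": "preference", "content": "likes cats"},
        {"topic": "schedule", "content": "meeting at 5 pm"},
    ]))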


@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""纯 inkfox 视频关键帧分析工具
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
@@ -13,27 +14,25 @@
from __future__ import annotations
import os
import io
import asyncio
import base64
import hashlib
import io
import os
import tempfile
import time
from pathlib import Path
from typing import Any
from typing import List, Tuple, Optional, Dict, Any
import hashlib
import time
from PIL import Image
from sqlalchemy import exc as sa_exc # type: ignore
from sqlalchemy import insert, select, update # type: ignore
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
# Simple concurrency control: process each hash only once
_video_locks: dict[str, asyncio.Lock] = {}
_video_locks: Dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()
logger = get_logger("utils_video")
@@ -91,7 +90,7 @@ class VideoAnalyzer:
logger.debug(f"获取系统信息失败: {e}")
# ---- Keyframe extraction ----
async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
with tempfile.TemporaryDirectory() as tmp:
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
@@ -106,7 +105,7 @@ class VideoAnalyzer:
)
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
total_ms = getattr(result, "total_time_ms", 0)
frames: list[tuple[str, float]] = []
frames: List[Tuple[str, float]] = []
for i, f in enumerate(files):
img = Image.open(f).convert("RGB")
if max(img.size) > self.max_image_size:
@@ -120,41 +119,38 @@ class VideoAnalyzer:
return frames
# ---- Batch analysis ----
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
from src.llm_models.payload_content.message import MessageBuilder
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
prompt = self.batch_analysis_prompt.format(
personality_core=self.personality_core, personality_side=self.personality_side
)
if question:
prompt += f"\n用户关注: {question}"
desc = [
(f"{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"{i+1}")
for i, (_b, ts) in enumerate(frames)
]
prompt += "\n帧列表: " + ", ".join(desc)
message_builder = MessageBuilder().add_text_content(prompt)
mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
for b64, _ in frames:
message_builder.add_image_content(image_format="jpeg", image_base64=b64)
messages = [message_builder.build()]
# Use the wrapped high-level strategy to execute the request instead of calling the internal method directly
response, _ = await self.video_llm._strategy.execute_with_failover(
RequestType.RESPONSE,
raise_when_empty=False,  # Return a default value even on failure to avoid crashing the program
message_list=messages,
temperature=self.video_llm.model_for_task.temperature,
max_tokens=self.video_llm.model_for_task.max_tokens,
mb.add_image_content("jpeg", b64)
message = mb.build()
model_info, api_provider, client = self.video_llm._select_model()
resp = await self.video_llm._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=[message],
temperature=None,
max_tokens=None,
)
return response.content or "❌ 未获得响应"
return resp.content or "❌ 未获得响应"
# ---- Frame-by-frame analysis ----
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
results: list[str] = []
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
results: List[str] = []
for i, (b64, ts) in enumerate(frames):
prompt = f"分析第{i+1}" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
if question:
@@ -178,7 +174,7 @@ class VideoAnalyzer:
return "\n".join(results)
# ---- Main entry point ----
async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
if not os.path.exists(video_path):
return False, "❌ 文件不存在"
frames = await self.extract_keyframes(video_path)
@@ -193,10 +189,10 @@ class VideoAnalyzer:
async def analyze_video_from_bytes(
self,
video_bytes: bytes,
filename: str | None = None,
prompt: str | None = None,
question: str | None = None,
) -> dict[str, str]:
filename: Optional[str] = None,
prompt: Optional[str] = None,
question: Optional[str] = None,
) -> Dict[str, str]:
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
if not video_bytes:
return {"summary": "❌ 空视频数据"}
@@ -204,11 +200,17 @@ class VideoAnalyzer:
q = prompt if prompt is not None else question
video_hash = hashlib.sha256(video_bytes).hexdigest()
# Check cache (first pass, before taking the lock)
cached = await self._get_cached(video_hash)
if cached:
logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}")
return {"summary": cached}
# Check cache
try:
async with get_db_session() as session: # type: ignore
row = await session.execute(
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
)
existing = row.first()
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
return {"summary": existing[Videos.description]} # type: ignore
except Exception: # pragma: no cover
pass
# Acquire the lock to avoid duplicate processing
async with _locks_guard:
@@ -217,11 +219,17 @@ class VideoAnalyzer:
lock = asyncio.Lock()
_video_locks[video_hash] = lock
async with lock:
# Double-checked cache lookup
cached2 = await self._get_cached(video_hash)
if cached2:
logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}")
return {"summary": cached2}
# Double check: query again after acquiring the lock to avoid duplicate processing
try:
async with get_db_session() as session: # type: ignore
row = await session.execute(
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
)
existing = row.first()
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
return {"summary": existing[Videos.description]} # type: ignore
except Exception: # pragma: no cover
pass
try:
with tempfile.NamedTemporaryFile(delete=False) as fp:
@@ -231,7 +239,26 @@ class VideoAnalyzer:
ok, summary = await self.analyze_video(temp_path, q)
# Write to cache (only on success)
if ok:
await self._save_cache(video_hash, summary, len(video_bytes))
try:
async with get_db_session() as session: # type: ignore
await session.execute(
Videos.__table__.insert().values(
video_id="",
video_hash=video_hash,
description=summary,
count=1,
timestamp=time.time(),
vlm_processed=True,
duration=None,
frame_count=None,
fps=None,
resolution=None,
file_size=len(video_bytes),
)
)
await session.commit()
except Exception: # pragma: no cover
pass
return {"summary": summary}
finally:
if os.path.exists(temp_path):
@@ -242,57 +269,9 @@ class VideoAnalyzer:
except Exception as e: # pragma: no cover
return {"summary": f"❌ 处理失败: {e}"}
# ---- Cache helpers ----
async def _get_cached(self, video_hash: str) -> str | None:
try:
async with get_db_session() as session: # type: ignore
result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) # type: ignore
obj: Videos | None = result.scalar_one_or_none() # type: ignore
if obj and obj.vlm_processed and obj.description:
# Update the usage count
try:
await session.execute(
update(Videos)
.where(Videos.id == obj.id) # type: ignore
.values(count=obj.count + 1 if obj.count is not None else 1)
)
await session.commit()
except Exception: # pragma: no cover
await session.rollback()
return obj.description
except Exception: # pragma: no cover
pass
return None
async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
try:
async with get_db_session() as session: # type: ignore
stmt = insert(Videos).values( # type: ignore
video_id="",
video_hash=video_hash,
description=summary,
count=1,
timestamp=time.time(),
vlm_processed=True,
duration=None,
frame_count=None,
fps=None,
resolution=None,
file_size=file_size,
)
try:
await session.execute(stmt)
await session.commit()
logger.debug(f"视频缓存写入 success hash={video_hash}")
except sa_exc.IntegrityError:  # may have been written concurrently
await session.rollback()
logger.debug(f"视频缓存已存在 hash={video_hash}")
except Exception: # pragma: no cover
logger.debug("视频缓存写入失败")
# ---- External interface ----
_INSTANCE: VideoAnalyzer | None = None
_INSTANCE: Optional[VideoAnalyzer] = None
def get_video_analyzer() -> VideoAnalyzer:
@@ -306,7 +285,7 @@ def is_video_analysis_available() -> bool:
return True
def get_video_analysis_status() -> dict[str, Any]:
def get_video_analysis_status() -> Dict[str, Any]:
try:
info = video.get_system_info() # type: ignore[attr-defined]
except Exception as e: # pragma: no cover
@@ -318,4 +297,4 @@ def get_video_analysis_status() -> dict[str, Any]:
"modes": ["auto", "batch", "sequential"],
"max_frames_default": inst.max_frames,
"implementation": "inkfox",
}
}
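The caching flow in analyze_video_from_bytes above boils down to a lookup, a per-hash lock, and a second lookup once the lock is held. A condensed, self-contained sketch of that pattern, with an in-memory dict standing in for the Videos table, might look like this:

# Condensed sketch of the check -> per-hash lock -> double-check flow used above.
# A plain dict stands in for the Videos table; the real code queries SQLAlchemy
# and only caches successful analyses.
import asyncio
import hashlib
from typing import Dict

_cache: Dict[str, str] = {}
_locks: Dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()

async def analyze(video_bytes: bytes) -> str:
    video_hash = hashlib.sha256(video_bytes).hexdigest()
    if video_hash in _cache:  # pre-check without the lock
        return _cache[video_hash]
    async with _locks_guard:  # one lock object per hash
        lock = _locks.setdefault(video_hash, asyncio.Lock())
    async with lock:
        if video_hash in _cache:  # double check after acquiring the lock
            return _cache[video_hash]
        summary = f"analyzed {len(video_bytes)} bytes"  # stand-in for the VLM call
        _cache[video_hash] = summary
        return summary

if __name__ == "__main__":
    print(asyncio.run(analyze(b"\x00\x01\x02")))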