Refactor replyer retrieval to async and add memory formatter

Changed get_replyer and related calls to async across multiple modules for proper coroutine handling. Added format_memories_bracket_style utility for memory formatting. Improved video analysis caching logic and type annotations. Updated error logging for message processing.
This commit is contained in:
雅诺狐
2025-10-05 17:48:28 +08:00
parent f4404e09ef
commit 86377d983f
8 changed files with 215 additions and 112 deletions

View File

@@ -30,6 +30,7 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system,
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
from .memory_formatter import format_memories_bracket_style
__all__ = [
# 核心数据结构
@@ -62,6 +63,8 @@ __all__ = [
"MemoryActivator",
"memory_activator",
"enhanced_memory_activator", # 兼容性别名
# 格式化工具
"format_memories_bracket_style",
]
# 版本信息

View File

@@ -0,0 +1,118 @@
"""记忆格式化工具
提供统一的记忆块格式化函数,供构建 Prompt 时使用。
当前使用的函数: format_memories_bracket_style
输入: list[dict] 其中每个元素包含:
- display: str 记忆可读内容
- memory_type: str 记忆类型 (personal_fact/opinion/preference/event 等)
- metadata: dict 可选,包括
- confidence: 置信度 (str|float)
- importance: 重要度 (str|float)
- timestamp: 时间戳 (float|str)
- source: 来源 (str)
- relevance_score: 相关度 (float)
返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
"""
from __future__ import annotations
from typing import Any, Iterable
import time
def _format_timestamp(ts: Any) -> str:
    """Turn a timestamp-like value into a short display string.

    Args:
        ts: The raw timestamp taken from memory metadata.

    Returns:
        * ``""`` when ``ts`` is ``None`` or an empty string, or when any
          error occurs while formatting.
        * ``"YYYY-MM-DD HH:MM"`` in local time when ``ts`` is a positive
          ``int``/``float`` (interpreted as seconds since the epoch).
        * ``str(ts)`` for everything else (strings, non-positive numbers, ...).
    """
    try:
        if ts in (None, ""):
            return ""
        is_positive_number = isinstance(ts, (int, float)) and ts > 0
        if not is_positive_number:
            # Not an epoch value: pass it through as text unchanged.
            return str(ts)
        moment = time.localtime(float(ts))
        return time.strftime("%Y-%m-%d %H:%M", moment)
    except Exception:
        # Formatting is best-effort; a bad timestamp must never break the prompt.
        return ""
def _coerce_str(v: Any) -> str:
    """Return ``str(v)``, mapping ``None`` to the empty string."""
    return "" if v is None else str(v)
def format_memories_bracket_style(
    memories: Iterable[dict[str, Any]] | None,
    query_context: str | None = None,
    max_items: int = 15,
) -> str:
    """Render memories as a bracket-tagged bullet list for prompt embedding.

    Example output::

        ## 相关记忆回顾
        - [类型:personal_fact|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (chat, 2025-10-05 09:30)

    Each memory is a dict; the keys read here are ``display`` (the text),
    ``memory_type`` (defaults to ``"fact"``) and an optional ``metadata`` dict
    with ``confidence``, ``importance``, ``source``, ``relevance_score`` and
    ``timestamp``. Entries that are not dicts, or whose ``display`` is empty
    after stripping, are skipped and do not count towards ``max_items``.

    Args:
        memories: Iterable of memory dicts; ``None`` or an empty container
            yields an empty string.
        query_context: Optional current query / user message. When given, its
            first 60 characters are shown in a parenthesised header line
            (followed by ``...`` if it was longer).
        max_items: Maximum number of memory lines to emit.

    Returns:
        The formatted text, or ``""`` when no memory produced a line. A
        trailing truncation notice is appended only when at least one further
        valid memory was dropped because of ``max_items``.
    """
    if not memories:
        return ""

    lines: list[str] = ["## 相关记忆回顾"]
    if query_context:
        ellipsis = "..." if len(query_context) > 60 else ""
        # The closing parenthesis was missing in the original header line.
        lines.append(f"(与当前消息相关:{query_context[:60]}{ellipsis})")
    lines.append("")

    count = 0
    truncated = False
    for mem in memories:
        if not isinstance(mem, dict):
            continue
        display = _coerce_str(mem.get("display", "")).strip()
        if not display:
            continue
        # Check the limit only after validating the entry, so the truncation
        # notice is shown solely when a real memory is being left out.
        if count >= max_items:
            truncated = True
            break

        mtype = _coerce_str(mem.get("memory_type", "fact")) or "fact"
        raw_meta = mem.get("metadata")
        meta = raw_meta if isinstance(raw_meta, dict) else {}
        confidence = _coerce_str(meta.get("confidence", ""))
        importance = _coerce_str(meta.get("importance", ""))
        source = _coerce_str(meta.get("source", ""))
        rel = meta.get("relevance_score")
        try:
            rel_str = f"{float(rel):.2f}" if rel is not None else ""
        except Exception:
            # Best-effort: an unparsable relevance score is simply omitted.
            rel_str = ""
        ts = _format_timestamp(meta.get("timestamp"))

        # Tag segment: the type is always present, the others only when set.
        tags: list[str] = [f"类型:{mtype}"]
        if importance:
            tags.append(f"重要:{importance}")
        if confidence:
            tags.append(f"置信:{confidence}")
        if rel_str:
            tags.append(f"相关:{rel_str}")
        tag_block = "|".join(tags)

        # Optional "(source, time)" suffix built from whichever parts exist.
        suffix_parts = [part for part in (source, ts) if part]
        suffix = (" (" + ", ".join(suffix_parts) + ")") if suffix_parts else ""

        lines.append(f"- [{tag_block}] {display}{suffix}")
        count += 1

    if count == 0:
        return ""
    if truncated:
        lines.append(f"\n(已截断,仅显示前 {max_items} 条相关记忆)")
    return "\n".join(lines)
__all__ = ["format_memories_bracket_style"]

View File

@@ -997,7 +997,8 @@ class MemorySystem:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
# get_stream 为异步方法,需要 await
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream or not hasattr(chat_stream, "context_manager"):
logger.debug(f"未找到stream_id={stream_id}的聊天流或上下文管理器")
@@ -1111,7 +1112,8 @@ class MemorySystem:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
# ChatManager.get_stream 是异步方法,需要 await否则会产生 "coroutine was never awaited" 警告
chat_stream = await chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream, "context_manager"):
history_limit = self._determine_history_limit(context)
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
logger.error(f"视频处理失败: {str(e)}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
@@ -277,7 +277,7 @@ class MessageRecv(Message):
logger.info("未启用视频识别")
return "[视频]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@@ -427,7 +427,7 @@ class MessageRecvS4U(MessageRecv):
# 使用video analyzer分析视频
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
result = await video_analyzer.analyze_video(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)

View File

@@ -9,7 +9,7 @@ class ReplyerManager:
def __init__(self):
self._repliers: dict[str, DefaultReplyer] = {}
def get_replyer(
async def get_replyer(
self,
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
@@ -37,7 +37,8 @@ class ReplyerManager:
target_stream = chat_stream
if not target_stream:
if chat_manager := get_chat_manager():
target_stream = chat_manager.get_stream(stream_id)
# get_stream 为异步,需要等待
target_stream = await chat_manager.get_stream(stream_id)
if not target_stream:
logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。")

View File

@@ -492,7 +492,7 @@ class Prompt:
from src.plugin_system.apis.generator_api import get_replyer
# 创建临时生成器实例来使用其方法
temp_generator = get_replyer(None, chat_id, request_type="prompt_building")
temp_generator = await get_replyer(None, chat_id, request_type="prompt_building")
return await temp_generator.build_s4u_chat_history_prompts(
message_list_before_now, target_user_id, sender, chat_id
)

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""纯 inkfox 视频关键帧分析工具
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
@@ -13,27 +14,25 @@
from __future__ import annotations
import os
import io
import asyncio
import base64
import hashlib
import io
import os
import tempfile
import time
from pathlib import Path
from typing import Any
from typing import List, Tuple, Optional, Dict, Any
import hashlib
import time
from PIL import Image
from sqlalchemy import exc as sa_exc # type: ignore
from sqlalchemy import insert, select, update # type: ignore
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
# 简易并发控制:同一 hash 只处理一次
_video_locks: dict[str, asyncio.Lock] = {}
_video_locks: Dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()
logger = get_logger("utils_video")
@@ -91,7 +90,7 @@ class VideoAnalyzer:
logger.debug(f"获取系统信息失败: {e}")
# ---- 关键帧提取 ----
async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
with tempfile.TemporaryDirectory() as tmp:
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
@@ -106,7 +105,7 @@ class VideoAnalyzer:
)
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
total_ms = getattr(result, "total_time_ms", 0)
frames: list[tuple[str, float]] = []
frames: List[Tuple[str, float]] = []
for i, f in enumerate(files):
img = Image.open(f).convert("RGB")
if max(img.size) > self.max_image_size:
@@ -120,41 +119,38 @@ class VideoAnalyzer:
return frames
# ---- 批量分析 ----
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
from src.llm_models.payload_content.message import MessageBuilder
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
prompt = self.batch_analysis_prompt.format(
personality_core=self.personality_core, personality_side=self.personality_side
)
if question:
prompt += f"\n用户关注: {question}"
desc = [
(f"{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"{i+1}")
for i, (_b, ts) in enumerate(frames)
]
prompt += "\n帧列表: " + ", ".join(desc)
message_builder = MessageBuilder().add_text_content(prompt)
mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
for b64, _ in frames:
message_builder.add_image_content(image_format="jpeg", image_base64=b64)
messages = [message_builder.build()]
# 使用封装好的高级策略执行请求,而不是直接调用内部方法
response, _ = await self.video_llm._strategy.execute_with_failover(
RequestType.RESPONSE,
raise_when_empty=False, # 即使失败也返回默认值,避免程序崩溃
message_list=messages,
temperature=self.video_llm.model_for_task.temperature,
max_tokens=self.video_llm.model_for_task.max_tokens,
mb.add_image_content("jpeg", b64)
message = mb.build()
model_info, api_provider, client = self.video_llm._select_model()
resp = await self.video_llm._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=[message],
temperature=None,
max_tokens=None,
)
return response.content or "❌ 未获得响应"
return resp.content or "❌ 未获得响应"
# ---- 逐帧分析 ----
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
results: list[str] = []
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
results: List[str] = []
for i, (b64, ts) in enumerate(frames):
prompt = f"分析第{i+1}" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
if question:
@@ -178,7 +174,7 @@ class VideoAnalyzer:
return "\n".join(results)
# ---- 主入口 ----
async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
if not os.path.exists(video_path):
return False, "❌ 文件不存在"
frames = await self.extract_keyframes(video_path)
@@ -193,10 +189,10 @@ class VideoAnalyzer:
async def analyze_video_from_bytes(
self,
video_bytes: bytes,
filename: str | None = None,
prompt: str | None = None,
question: str | None = None,
) -> dict[str, str]:
filename: Optional[str] = None,
prompt: Optional[str] = None,
question: Optional[str] = None,
) -> Dict[str, str]:
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
if not video_bytes:
return {"summary": "❌ 空视频数据"}
@@ -204,11 +200,17 @@ class VideoAnalyzer:
q = prompt if prompt is not None else question
video_hash = hashlib.sha256(video_bytes).hexdigest()
# 查缓存(第一次,未加锁)
cached = await self._get_cached(video_hash)
if cached:
logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}")
return {"summary": cached}
# 查缓存
try:
async with get_db_session() as session: # type: ignore
row = await session.execute(
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
)
existing = row.first()
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
return {"summary": existing[Videos.description]} # type: ignore
except Exception: # pragma: no cover
pass
# 获取锁避免重复处理
async with _locks_guard:
@@ -217,11 +219,17 @@ class VideoAnalyzer:
lock = asyncio.Lock()
_video_locks[video_hash] = lock
async with lock:
# 双检缓存
cached2 = await self._get_cached(video_hash)
if cached2:
logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}")
return {"summary": cached2}
# 双检:进入锁后再查一次,避免重复处理
try:
async with get_db_session() as session: # type: ignore
row = await session.execute(
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
)
existing = row.first()
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
return {"summary": existing[Videos.description]} # type: ignore
except Exception: # pragma: no cover
pass
try:
with tempfile.NamedTemporaryFile(delete=False) as fp:
@@ -231,7 +239,26 @@ class VideoAnalyzer:
ok, summary = await self.analyze_video(temp_path, q)
# 写入缓存(仅成功)
if ok:
await self._save_cache(video_hash, summary, len(video_bytes))
try:
async with get_db_session() as session: # type: ignore
await session.execute(
Videos.__table__.insert().values(
video_id="",
video_hash=video_hash,
description=summary,
count=1,
timestamp=time.time(),
vlm_processed=True,
duration=None,
frame_count=None,
fps=None,
resolution=None,
file_size=len(video_bytes),
)
)
await session.commit()
except Exception: # pragma: no cover
pass
return {"summary": summary}
finally:
if os.path.exists(temp_path):
@@ -242,57 +269,9 @@ class VideoAnalyzer:
except Exception as e: # pragma: no cover
return {"summary": f"❌ 处理失败: {e}"}
# ---- 缓存辅助 ----
async def _get_cached(self, video_hash: str) -> str | None:
try:
async with get_db_session() as session: # type: ignore
result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) # type: ignore
obj: Videos | None = result.scalar_one_or_none() # type: ignore
if obj and obj.vlm_processed and obj.description:
# 更新使用次数
try:
await session.execute(
update(Videos)
.where(Videos.id == obj.id) # type: ignore
.values(count=obj.count + 1 if obj.count is not None else 1)
)
await session.commit()
except Exception: # pragma: no cover
await session.rollback()
return obj.description
except Exception: # pragma: no cover
pass
return None
async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
try:
async with get_db_session() as session: # type: ignore
stmt = insert(Videos).values( # type: ignore
video_id="",
video_hash=video_hash,
description=summary,
count=1,
timestamp=time.time(),
vlm_processed=True,
duration=None,
frame_count=None,
fps=None,
resolution=None,
file_size=file_size,
)
try:
await session.execute(stmt)
await session.commit()
logger.debug(f"视频缓存写入 success hash={video_hash}")
except sa_exc.IntegrityError: # 可能并发已写入
await session.rollback()
logger.debug(f"视频缓存已存在 hash={video_hash}")
except Exception: # pragma: no cover
logger.debug("视频缓存写入失败")
# ---- 外部接口 ----
_INSTANCE: VideoAnalyzer | None = None
_INSTANCE: Optional[VideoAnalyzer] = None
def get_video_analyzer() -> VideoAnalyzer:
@@ -306,7 +285,7 @@ def is_video_analysis_available() -> bool:
return True
def get_video_analysis_status() -> dict[str, Any]:
def get_video_analysis_status() -> Dict[str, Any]:
try:
info = video.get_system_info() # type: ignore[attr-defined]
except Exception as e: # pragma: no cover
@@ -318,4 +297,4 @@ def get_video_analysis_status() -> dict[str, Any]:
"modes": ["auto", "batch", "sequential"],
"max_frames_default": inst.max_frames,
"implementation": "inkfox",
}
}

View File

@@ -4,7 +4,7 @@
提供回复器相关功能采用标准Python包设计模式
使用方式:
from src.plugin_system.apis import generator_api
replyer = generator_api.get_replyer(chat_stream)
replyer = await generator_api.get_replyer(chat_stream)
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
"""
@@ -31,7 +31,7 @@ logger = get_logger("generator_api")
# =============================================================================
def get_replyer(
async def get_replyer(
chat_stream: ChatStream | None = None,
chat_id: str | None = None,
request_type: str = "replyer",
@@ -56,7 +56,7 @@ def get_replyer(
raise ValueError("chat_stream 和 chat_id 不可均为空")
try:
logger.debug(f"[GeneratorAPI] 正在获取回复器chat_id: {chat_id}, chat_stream: {'' if chat_stream else ''}")
return replyer_manager.get_replyer(
return await replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
request_type=request_type,
@@ -110,7 +110,7 @@ async def generate_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
replyer = await get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -199,7 +199,7 @@ async def rewrite_reply(
"""
try:
# 获取回复器
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
replyer = await get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
@@ -285,7 +285,7 @@ async def generate_response_custom(
Returns:
Optional[str]: 生成的回复内容
"""
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
replyer = await get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return None