From 86377d983f7c980bc74883277a9fa2a8ace60d94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sun, 5 Oct 2025 17:48:28 +0800 Subject: [PATCH] Refactor replyer retrieval to async and add memory formatter Changed get_replyer and related calls to async across multiple modules for proper coroutine handling. Added format_memories_bracket_style utility for memory formatting. Improved video analysis caching logic and type annotations. Updated error logging for message processing. --- src/chat/memory_system/__init__.py | 3 + src/chat/memory_system/memory_formatter.py | 118 ++++++++++++++ src/chat/memory_system/memory_system.py | 6 +- src/chat/message_receive/message.py | 6 +- src/chat/replyer/replyer_manager.py | 5 +- src/chat/utils/prompt.py | 2 +- src/chat/utils/utils_video.py | 175 +++++++++------------ src/plugin_system/apis/generator_api.py | 12 +- 8 files changed, 215 insertions(+), 112 deletions(-) create mode 100644 src/chat/memory_system/memory_formatter.py diff --git a/src/chat/memory_system/__init__.py b/src/chat/memory_system/__init__.py index d3c5feea4..962389b15 100644 --- a/src/chat/memory_system/__init__.py +++ b/src/chat/memory_system/__init__.py @@ -30,6 +30,7 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system, # Vector DB存储系统 from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage +from .memory_formatter import format_memories_bracket_style __all__ = [ # 核心数据结构 @@ -62,6 +63,8 @@ __all__ = [ "MemoryActivator", "memory_activator", "enhanced_memory_activator", # 兼容性别名 + # 格式化工具 + "format_memories_bracket_style", ] # 版本信息 diff --git a/src/chat/memory_system/memory_formatter.py b/src/chat/memory_system/memory_formatter.py new file mode 100644 index 000000000..5e5f100f7 --- /dev/null +++ b/src/chat/memory_system/memory_formatter.py @@ -0,0 +1,118 @@ +"""记忆格式化工具 + +提供统一的记忆块格式化函数,供构建 Prompt 时使用。 + +当前使用的函数: format_memories_bracket_style +输入: list[dict] 其中每个元素包含: + - display: str 记忆可读内容 + - memory_type: str 记忆类型 (personal_fact/opinion/preference/event 等) + - metadata: dict 可选,包括 + - confidence: 置信度 (str|float) + - importance: 重要度 (str|float) + - timestamp: 时间戳 (float|str) + - source: 来源 (str) + - relevance_score: 相关度 (float) + +返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。 +""" +from __future__ import annotations + +from typing import Any, Iterable +import time + + +def _format_timestamp(ts: Any) -> str: + try: + if ts in (None, ""): + return "" + if isinstance(ts, (int, float)) and ts > 0: + return time.strftime("%Y-%m-%d %H:%M", time.localtime(float(ts))) + return str(ts) + except Exception: + return "" + + +def _coerce_str(v: Any) -> str: + if v is None: + return "" + return str(v) + + +def format_memories_bracket_style( + memories: Iterable[dict[str, Any]] | None, + query_context: str | None = None, + max_items: int = 15, +) -> str: + """以方括号 + 标注字段的方式格式化记忆列表。 + + 例子输出: + ## 相关记忆回顾 + - [类型:personal_fact|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (来源: chat, 2025-10-05 09:30) + + Args: + memories: 记忆字典迭代器 + query_context: 当前查询/用户的消息,用于在首行提示(可选) + max_items: 最多输出的记忆条数 + Returns: + str: 格式化文本;若无内容返回空串 + """ + if not memories: + return "" + + lines: list[str] = ["## 相关记忆回顾"] + if query_context: + lines.append(f"(与当前消息相关:{query_context[:60]}{'...' if len(query_context) > 60 else ''})") + lines.append("") + + count = 0 + for mem in memories: + if count >= max_items: + break + if not isinstance(mem, dict): + continue + display = _coerce_str(mem.get("display", "")).strip() + if not display: + continue + + mtype = _coerce_str(mem.get("memory_type", "fact")) or "fact" + meta = mem.get("metadata", {}) if isinstance(mem.get("metadata"), dict) else {} + confidence = _coerce_str(meta.get("confidence", "")) + importance = _coerce_str(meta.get("importance", "")) + source = _coerce_str(meta.get("source", "")) + rel = meta.get("relevance_score") + try: + rel_str = f"{float(rel):.2f}" if rel is not None else "" + except Exception: + rel_str = "" + ts = _format_timestamp(meta.get("timestamp")) + + # 构建标签段 + tags: list[str] = [f"类型:{mtype}"] + if importance: + tags.append(f"重要:{importance}") + if confidence: + tags.append(f"置信:{confidence}") + if rel_str: + tags.append(f"相关:{rel_str}") + + tag_block = "|".join(tags) + suffix_parts = [] + if source: + suffix_parts.append(source) + if ts: + suffix_parts.append(ts) + suffix = (" (" + ", ".join(suffix_parts) + ")") if suffix_parts else "" + + lines.append(f"- [{tag_block}] {display}{suffix}") + count += 1 + + if count == 0: + return "" + + if count >= max_items: + lines.append(f"\n(已截断,仅显示前 {max_items} 条相关记忆)") + + return "\n".join(lines) + + +__all__ = ["format_memories_bracket_style"] diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index a2c0a0e83..1032e93f5 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -997,7 +997,8 @@ class MemorySystem: from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() - chat_stream = chat_manager.get_stream(stream_id) + # get_stream 为异步方法,需要 await + chat_stream = await chat_manager.get_stream(stream_id) if not chat_stream or not hasattr(chat_stream, "context_manager"): logger.debug(f"未找到stream_id={stream_id}的聊天流或上下文管理器") @@ -1111,7 +1112,8 @@ class MemorySystem: from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() - chat_stream = chat_manager.get_stream(stream_id) + # ChatManager.get_stream 是异步方法,需要 await,否则会产生 "coroutine was never awaited" 警告 + chat_stream = await chat_manager.get_stream(stream_id) if chat_stream and hasattr(chat_stream, "context_manager"): history_limit = self._determine_history_limit(context) messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 7953ff862..86c32ea94 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -263,7 +263,7 @@ class MessageRecv(Message): logger.warning("视频消息中没有base64数据") return "[收到视频消息,但数据异常]" except Exception as e: - logger.error(f"视频处理失败: {e!s}") + logger.error(f"视频处理失败: {str(e)}") import traceback logger.error(f"错误详情: {traceback.format_exc()}") @@ -277,7 +277,7 @@ class MessageRecv(Message): logger.info("未启用视频识别") return "[视频]" except Exception as e: - logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}") + logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") return f"[处理失败的{segment.type}消息]" @@ -427,7 +427,7 @@ class MessageRecvS4U(MessageRecv): # 使用video analyzer分析视频 video_analyzer = get_video_analyzer() - result = await video_analyzer.analyze_video_from_bytes( + result = await video_analyzer.analyze_video( video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt ) diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index 55a422c1b..3e11c8a2f 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -9,7 +9,7 @@ class ReplyerManager: def __init__(self): self._repliers: dict[str, DefaultReplyer] = {} - def get_replyer( + async def get_replyer( self, chat_stream: ChatStream | None = None, chat_id: str | None = None, @@ -37,7 +37,8 @@ class ReplyerManager: target_stream = chat_stream if not target_stream: if chat_manager := get_chat_manager(): - target_stream = chat_manager.get_stream(stream_id) + # get_stream 为异步,需要等待 + target_stream = await chat_manager.get_stream(stream_id) if not target_stream: logger.warning(f"[ReplyerManager] 未找到 stream_id='{stream_id}' 的聊天流,无法创建回复器。") diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 53f11f500..72bb926b8 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -492,7 +492,7 @@ class Prompt: from src.plugin_system.apis.generator_api import get_replyer # 创建临时生成器实例来使用其方法 - temp_generator = get_replyer(None, chat_id, request_type="prompt_building") + temp_generator = await get_replyer(None, chat_id, request_type="prompt_building") return await temp_generator.build_s4u_chat_history_prompts( message_list_before_now, target_user_id, sender, chat_id ) diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 78ea3a11c..fe14e54c5 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +# -*- coding: utf-8 -*- """纯 inkfox 视频关键帧分析工具 仅依赖 `inkfox.video` 提供的 Rust 扩展能力: @@ -13,27 +14,25 @@ from __future__ import annotations +import os +import io import asyncio import base64 -import hashlib -import io -import os import tempfile -import time from pathlib import Path -from typing import Any +from typing import List, Tuple, Optional, Dict, Any +import hashlib +import time from PIL import Image -from sqlalchemy import exc as sa_exc # type: ignore -from sqlalchemy import insert, select, update # type: ignore -from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest +from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore # 简易并发控制:同一 hash 只处理一次 -_video_locks: dict[str, asyncio.Lock] = {} +_video_locks: Dict[str, asyncio.Lock] = {} _locks_guard = asyncio.Lock() logger = get_logger("utils_video") @@ -91,7 +90,7 @@ class VideoAnalyzer: logger.debug(f"获取系统信息失败: {e}") # ---- 关键帧提取 ---- - async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]: + async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]: """提取关键帧并返回 (base64, timestamp_seconds) 列表""" with tempfile.TemporaryDirectory() as tmp: result = video.extract_keyframes_from_video( # type: ignore[attr-defined] @@ -106,7 +105,7 @@ class VideoAnalyzer: ) files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames] total_ms = getattr(result, "total_time_ms", 0) - frames: list[tuple[str, float]] = [] + frames: List[Tuple[str, float]] = [] for i, f in enumerate(files): img = Image.open(f).convert("RGB") if max(img.size) > self.max_image_size: @@ -120,41 +119,38 @@ class VideoAnalyzer: return frames # ---- 批量分析 ---- - async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str: - from src.llm_models.payload_content.message import MessageBuilder + async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str: + from src.llm_models.payload_content.message import MessageBuilder, RoleType from src.llm_models.utils_model import RequestType - prompt = self.batch_analysis_prompt.format( personality_core=self.personality_core, personality_side=self.personality_side ) if question: prompt += f"\n用户关注: {question}" - desc = [ (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧") for i, (_b, ts) in enumerate(frames) ] prompt += "\n帧列表: " + ", ".join(desc) - - message_builder = MessageBuilder().add_text_content(prompt) + mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt) for b64, _ in frames: - message_builder.add_image_content(image_format="jpeg", image_base64=b64) - messages = [message_builder.build()] - - # 使用封装好的高级策略执行请求,而不是直接调用内部方法 - response, _ = await self.video_llm._strategy.execute_with_failover( - RequestType.RESPONSE, - raise_when_empty=False, # 即使失败也返回默认值,避免程序崩溃 - message_list=messages, - temperature=self.video_llm.model_for_task.temperature, - max_tokens=self.video_llm.model_for_task.max_tokens, + mb.add_image_content("jpeg", b64) + message = mb.build() + model_info, api_provider, client = self.video_llm._select_model() + resp = await self.video_llm._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=[message], + temperature=None, + max_tokens=None, ) - - return response.content or "❌ 未获得响应" + return resp.content or "❌ 未获得响应" # ---- 逐帧分析 ---- - async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str: - results: list[str] = [] + async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str: + results: List[str] = [] for i, (b64, ts) in enumerate(frames): prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "") if question: @@ -178,7 +174,7 @@ class VideoAnalyzer: return "\n".join(results) # ---- 主入口 ---- - async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]: + async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]: if not os.path.exists(video_path): return False, "❌ 文件不存在" frames = await self.extract_keyframes(video_path) @@ -193,10 +189,10 @@ class VideoAnalyzer: async def analyze_video_from_bytes( self, video_bytes: bytes, - filename: str | None = None, - prompt: str | None = None, - question: str | None = None, - ) -> dict[str, str]: + filename: Optional[str] = None, + prompt: Optional[str] = None, + question: Optional[str] = None, + ) -> Dict[str, str]: """从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}.""" if not video_bytes: return {"summary": "❌ 空视频数据"} @@ -204,11 +200,17 @@ class VideoAnalyzer: q = prompt if prompt is not None else question video_hash = hashlib.sha256(video_bytes).hexdigest() - # 查缓存(第一次,未加锁) - cached = await self._get_cached(video_hash) - if cached: - logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}") - return {"summary": cached} + # 查缓存 + try: + async with get_db_session() as session: # type: ignore + row = await session.execute( + Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore + ) + existing = row.first() + if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore + return {"summary": existing[Videos.description]} # type: ignore + except Exception: # pragma: no cover + pass # 获取锁避免重复处理 async with _locks_guard: @@ -217,11 +219,17 @@ class VideoAnalyzer: lock = asyncio.Lock() _video_locks[video_hash] = lock async with lock: - # 双检缓存 - cached2 = await self._get_cached(video_hash) - if cached2: - logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}") - return {"summary": cached2} + # 双检:进入锁后再查一次,避免重复处理 + try: + async with get_db_session() as session: # type: ignore + row = await session.execute( + Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore + ) + existing = row.first() + if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore + return {"summary": existing[Videos.description]} # type: ignore + except Exception: # pragma: no cover + pass try: with tempfile.NamedTemporaryFile(delete=False) as fp: @@ -231,7 +239,26 @@ class VideoAnalyzer: ok, summary = await self.analyze_video(temp_path, q) # 写入缓存(仅成功) if ok: - await self._save_cache(video_hash, summary, len(video_bytes)) + try: + async with get_db_session() as session: # type: ignore + await session.execute( + Videos.__table__.insert().values( + video_id="", + video_hash=video_hash, + description=summary, + count=1, + timestamp=time.time(), + vlm_processed=True, + duration=None, + frame_count=None, + fps=None, + resolution=None, + file_size=len(video_bytes), + ) + ) + await session.commit() + except Exception: # pragma: no cover + pass return {"summary": summary} finally: if os.path.exists(temp_path): @@ -242,57 +269,9 @@ class VideoAnalyzer: except Exception as e: # pragma: no cover return {"summary": f"❌ 处理失败: {e}"} - # ---- 缓存辅助 ---- - async def _get_cached(self, video_hash: str) -> str | None: - try: - async with get_db_session() as session: # type: ignore - result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) # type: ignore - obj: Videos | None = result.scalar_one_or_none() # type: ignore - if obj and obj.vlm_processed and obj.description: - # 更新使用次数 - try: - await session.execute( - update(Videos) - .where(Videos.id == obj.id) # type: ignore - .values(count=obj.count + 1 if obj.count is not None else 1) - ) - await session.commit() - except Exception: # pragma: no cover - await session.rollback() - return obj.description - except Exception: # pragma: no cover - pass - return None - - async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None: - try: - async with get_db_session() as session: # type: ignore - stmt = insert(Videos).values( # type: ignore - video_id="", - video_hash=video_hash, - description=summary, - count=1, - timestamp=time.time(), - vlm_processed=True, - duration=None, - frame_count=None, - fps=None, - resolution=None, - file_size=file_size, - ) - try: - await session.execute(stmt) - await session.commit() - logger.debug(f"视频缓存写入 success hash={video_hash}") - except sa_exc.IntegrityError: # 可能并发已写入 - await session.rollback() - logger.debug(f"视频缓存已存在 hash={video_hash}") - except Exception: # pragma: no cover - logger.debug("视频缓存写入失败") - # ---- 外部接口 ---- -_INSTANCE: VideoAnalyzer | None = None +_INSTANCE: Optional[VideoAnalyzer] = None def get_video_analyzer() -> VideoAnalyzer: @@ -306,7 +285,7 @@ def is_video_analysis_available() -> bool: return True -def get_video_analysis_status() -> dict[str, Any]: +def get_video_analysis_status() -> Dict[str, Any]: try: info = video.get_system_info() # type: ignore[attr-defined] except Exception as e: # pragma: no cover @@ -318,4 +297,4 @@ def get_video_analysis_status() -> dict[str, Any]: "modes": ["auto", "batch", "sequential"], "max_frames_default": inst.max_frames, "implementation": "inkfox", - } + } \ No newline at end of file diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 21bc6fdde..c3eefe6ca 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -4,7 +4,7 @@ 提供回复器相关功能,采用标准Python包设计模式 使用方式: from src.plugin_system.apis import generator_api - replyer = generator_api.get_replyer(chat_stream) + replyer = await generator_api.get_replyer(chat_stream) success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning) """ @@ -31,7 +31,7 @@ logger = get_logger("generator_api") # ============================================================================= -def get_replyer( +async def get_replyer( chat_stream: ChatStream | None = None, chat_id: str | None = None, request_type: str = "replyer", @@ -56,7 +56,7 @@ def get_replyer( raise ValueError("chat_stream 和 chat_id 不可均为空") try: logger.debug(f"[GeneratorAPI] 正在获取回复器,chat_id: {chat_id}, chat_stream: {'有' if chat_stream else '无'}") - return replyer_manager.get_replyer( + return await replyer_manager.get_replyer( chat_stream=chat_stream, chat_id=chat_id, request_type=request_type, @@ -110,7 +110,7 @@ async def generate_reply( """ try: # 获取回复器 - replyer = get_replyer(chat_stream, chat_id, request_type=request_type) + replyer = await get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") return False, [], None @@ -199,7 +199,7 @@ async def rewrite_reply( """ try: # 获取回复器 - replyer = get_replyer(chat_stream, chat_id, request_type=request_type) + replyer = await get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") return False, [], None @@ -285,7 +285,7 @@ async def generate_response_custom( Returns: Optional[str]: 生成的回复内容 """ - replyer = get_replyer(chat_stream, chat_id, request_type=request_type) + replyer = await get_replyer(chat_stream, chat_id, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") return None