diff --git a/pyproject.toml b/pyproject.toml
index 2ad3c5433..f5de2e25f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,7 +74,7 @@ dependencies = [
     "websockets>=15.0.1",
     "aiomysql>=0.2.0",
     "aiosqlite>=0.21.0",
-    "inkfox>=0.1.0",
+    "inkfox>=0.1.1",
     "rrjieba>=0.1.13",
     "mcp>=0.9.0",
     "sse-starlette>=2.2.1",
diff --git a/src/chat/memory_system/__init__.py b/src/chat/memory_system/__init__.py
index d3c5feea4..962389b15 100644
--- a/src/chat/memory_system/__init__.py
+++ b/src/chat/memory_system/__init__.py
@@ -30,6 +30,7 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system,
 
 # Vector DB存储系统
 from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
+from .memory_formatter import format_memories_bracket_style
 
 __all__ = [
     # 核心数据结构
@@ -62,6 +63,8 @@ __all__ = [
     "MemoryActivator",
     "memory_activator",
     "enhanced_memory_activator",  # 兼容性别名
+    # 格式化工具
+    "format_memories_bracket_style",
 ]
 
 # 版本信息
diff --git a/src/chat/memory_system/memory_formatter.py b/src/chat/memory_system/memory_formatter.py
new file mode 100644
index 000000000..5e5f100f7
--- /dev/null
+++ b/src/chat/memory_system/memory_formatter.py
@@ -0,0 +1,118 @@
+"""记忆格式化工具
+
+提供统一的记忆块格式化函数,供构建 Prompt 时使用。
+
+当前使用的函数: format_memories_bracket_style
+输入: list[dict] 其中每个元素包含:
+    - display: str 记忆可读内容
+    - memory_type: str 记忆类型 (personal_fact/opinion/preference/event 等)
+    - metadata: dict 可选,包括
+        - confidence: 置信度 (str|float)
+        - importance: 重要度 (str|float)
+        - timestamp: 时间戳 (float|str)
+        - source: 来源 (str)
+        - relevance_score: 相关度 (float)
+
+返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
+"""
+from __future__ import annotations
+
+import time
+from typing import Any, Iterable
+
+
+def _format_timestamp(ts: Any) -> str:
+    try:
+        if ts in (None, ""):
+            return ""
+        if isinstance(ts, (int, float)) and ts > 0:
+            return time.strftime("%Y-%m-%d %H:%M", time.localtime(float(ts)))
+        return str(ts)
+    except Exception:
+        return ""
+
+
+def _coerce_str(v: Any) -> str:
+    if v is None:
+        return ""
+    return str(v)
+
+
+def format_memories_bracket_style(
+    memories: Iterable[dict[str, Any]] | None,
+    query_context: str | None = None,
+    max_items: int = 15,
+) -> str:
+    """以方括号 + 标注字段的方式格式化记忆列表。
+
+    例子输出:
+    ## 相关记忆回顾
+    - [类型:personal_fact|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (chat, 2025-10-05 09:30)
+
+    Args:
+        memories: 记忆字典迭代器
+        query_context: 当前查询/用户的消息,用于在首行提示(可选)
+        max_items: 最多输出的记忆条数
+    Returns:
+        str: 格式化文本;若无内容返回空串
+    """
+    if not memories:
+        return ""
+
+    lines: list[str] = ["## 相关记忆回顾"]
+    if query_context:
+        lines.append(f"(与当前消息相关:{query_context[:60]}{'...' if len(query_context) > 60 else ''})")
+    lines.append("")
+
+    count = 0
+    for mem in memories:
+        if count >= max_items:
+            break
+        if not isinstance(mem, dict):
+            continue
+        display = _coerce_str(mem.get("display", "")).strip()
+        if not display:
+            continue
+
+        mtype = _coerce_str(mem.get("memory_type", "fact")) or "fact"
+        meta = mem.get("metadata", {}) if isinstance(mem.get("metadata"), dict) else {}
+        confidence = _coerce_str(meta.get("confidence", ""))
+        importance = _coerce_str(meta.get("importance", ""))
+        source = _coerce_str(meta.get("source", ""))
+        rel = meta.get("relevance_score")
+        try:
+            rel_str = f"{float(rel):.2f}" if rel is not None else ""
+        except Exception:
+            rel_str = ""
+        ts = _format_timestamp(meta.get("timestamp"))
+
+        # 构建标签段
+        tags: list[str] = [f"类型:{mtype}"]
+        if importance:
+            tags.append(f"重要:{importance}")
+        if confidence:
+            tags.append(f"置信:{confidence}")
+        if rel_str:
+            tags.append(f"相关:{rel_str}")
+
+        tag_block = "|".join(tags)
+        suffix_parts: list[str] = []
+        if source:
+            suffix_parts.append(source)
+        if ts:
+            suffix_parts.append(ts)
+        suffix = (" (" + ", ".join(suffix_parts) + ")") if suffix_parts else ""
+
+        lines.append(f"- [{tag_block}] {display}{suffix}")
+        count += 1
+
+    if count == 0:
+        return ""
+
+    if count >= max_items:
+        lines.append(f"\n(已截断,仅显示前 {max_items} 条相关记忆)")
+
+    return "\n".join(lines)
+
+
+__all__ = ["format_memories_bracket_style"]
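A quick usage sketch of the new formatter before moving on. The field names follow the module docstring above; the sample values are illustrative, and the timestamp rendering goes through `time.localtime`, so the printed date depends on the local timezone:

```python
from src.chat.memory_system.memory_formatter import format_memories_bracket_style

# Illustrative input following the documented shape: display + memory_type + metadata.
memories = [
    {
        "display": "他喜欢黑咖啡",
        "memory_type": "preference",
        "metadata": {
            "confidence": 0.83,
            "importance": "高",
            "relevance_score": 0.72,
            "timestamp": 1759628000.0,  # rendered via time.localtime, so timezone-dependent
            "source": "chat",
        },
    },
]

print(format_memories_bracket_style(memories, query_context="他早上一般喝什么?"))
# ## 相关记忆回顾
# (与当前消息相关:他早上一般喝什么?)
#
# - [类型:preference|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (chat, 2025-10-05 09:30)
```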
diff --git a/src/chat/memory_system/memory_metadata_index.py b/src/chat/memory_system/memory_metadata_index.py
index 4b405aad6..eff666b2c 100644
--- a/src/chat/memory_system/memory_metadata_index.py
+++ b/src/chat/memory_system/memory_metadata_index.py
@@ -1,193 +1,61 @@
 """
-记忆元数据索引管理器
-使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤
+记忆元数据索引(Rust 加速实现)。
 """
-import threading
-from dataclasses import asdict, dataclass
-from datetime import datetime
-from pathlib import Path
+from dataclasses import dataclass, asdict
 from typing import Any
-
-import orjson
+
+from inkfox.memory import PyMetadataIndex as _RustIndex  # type: ignore
 
 from src.common.logger import get_logger
 
 logger = get_logger(__name__)
 
 
 @dataclass
 class MemoryMetadataIndexEntry:
-    """元数据索引条目(轻量级,只用于快速过滤)"""
-
     memory_id: str
     user_id: str
-
-    # 分类信息
-    memory_type: str  # MemoryType.value
-    subjects: list[str]  # 主语列表
-    objects: list[str]  # 宾语列表
-    keywords: list[str]  # 关键词列表
-    tags: list[str]  # 标签列表
-
-    # 数值字段(用于范围过滤)
-    importance: int  # ImportanceLevel.value (1-4)
-    confidence: int  # ConfidenceLevel.value (1-4)
-    created_at: float  # 创建时间戳
-    access_count: int  # 访问次数
-
-    # 可选字段
+    memory_type: str
+    subjects: list[str]
+    objects: list[str]
+    keywords: list[str]
+    tags: list[str]
+    importance: int
+    confidence: int
+    created_at: float
+    access_count: int
     chat_id: str | None = None
-    content_preview: str | None = None  # 内容预览(前100字符)
 
 
 class MemoryMetadataIndex:
-    """记忆元数据索引管理器"""
+    content_preview: str | None = None
+
+
+class MemoryMetadataIndex:
+    """Rust 加速版本唯一实现。"""
 
     def __init__(self, index_file: str = "data/memory_metadata_index.json"):
-        self.index_file = Path(index_file)
-        self.index: dict[str, MemoryMetadataIndexEntry] = {}  # memory_id -> entry
+        self._rust = _RustIndex(index_file)
+        # 仅为向量层和调试提供最小缓存(长度判断、get_entry 返回)
+        self.index: dict[str, MemoryMetadataIndexEntry] = {}
+        logger.info("✅ MemoryMetadataIndex (Rust) 初始化完成,仅支持加速实现")
 
-        # 倒排索引(用于快速查找)
-        self.type_index: dict[str, set[str]] = {}  # type -> {memory_ids}
-        self.subject_index: dict[str, set[str]] = {}  # subject -> {memory_ids}
-        self.keyword_index: dict[str, set[str]] = {}  # keyword -> {memory_ids}
-        self.tag_index: dict[str, set[str]] = {}  # 
tag -> {memory_ids} - - self.lock = threading.RLock() - - # 加载已有索引 - self._load_index() - - def _load_index(self): - """从文件加载索引""" - if not self.index_file.exists(): - logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}") + # 向后代码仍调用的接口:batch_add_or_update / add_or_update + def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]): + if not entries: return - - try: - with open(self.index_file, "rb") as f: - data = orjson.loads(f.read()) - - # 重建内存索引 - for entry_data in data.get("entries", []): - entry = MemoryMetadataIndexEntry(**entry_data) - self.index[entry.memory_id] = entry - self._update_inverted_indices(entry) - - logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录") - - except Exception as e: - logger.error(f"加载元数据索引失败: {e}", exc_info=True) - - def _save_index(self): - """保存索引到文件""" - try: - # 确保目录存在 - self.index_file.parent.mkdir(parents=True, exist_ok=True) - - # 序列化所有条目 - entries = [asdict(entry) for entry in self.index.values()] - data = { - "version": "1.0", - "count": len(entries), - "last_updated": datetime.now().isoformat(), - "entries": entries, - } - - # 写入文件(使用临时文件 + 原子重命名) - temp_file = self.index_file.with_suffix(".tmp") - with open(temp_file, "wb") as f: - f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2)) - - temp_file.replace(self.index_file) - logger.debug(f"元数据索引已保存: {len(entries)} 条记录") - - except Exception as e: - logger.error(f"保存元数据索引失败: {e}", exc_info=True) - - def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry): - """更新倒排索引""" - memory_id = entry.memory_id - - # 类型索引 - self.type_index.setdefault(entry.memory_type, set()).add(memory_id) - - # 主语索引 - for subject in entry.subjects: - subject_norm = subject.strip().lower() - if subject_norm: - self.subject_index.setdefault(subject_norm, set()).add(memory_id) - - # 关键词索引 - for keyword in entry.keywords: - keyword_norm = keyword.strip().lower() - if keyword_norm: - self.keyword_index.setdefault(keyword_norm, set()).add(memory_id) - - # 标签索引 - for tag in entry.tags: - tag_norm = tag.strip().lower() - if tag_norm: - self.tag_index.setdefault(tag_norm, set()).add(memory_id) + payload = [] + for e in entries: + if not e.memory_id: + continue + self.index[e.memory_id] = e + payload.append(asdict(e)) + if payload: + try: + self._rust.batch_add(payload) + except Exception as ex: # noqa: BLE001 + logger.error(f"Rust 元数据批量添加失败: {ex}") def add_or_update(self, entry: MemoryMetadataIndexEntry): - """添加或更新索引条目""" - with self.lock: - # 如果已存在,先从倒排索引中移除旧记录 - if entry.memory_id in self.index: - self._remove_from_inverted_indices(entry.memory_id) - - # 添加新记录 - self.index[entry.memory_id] = entry - self._update_inverted_indices(entry) - - def _remove_from_inverted_indices(self, memory_id: str): - """从倒排索引中移除记录""" - if memory_id not in self.index: - return - - entry = self.index[memory_id] - - # 从类型索引移除 - if entry.memory_type in self.type_index: - self.type_index[entry.memory_type].discard(memory_id) - - # 从主语索引移除 - for subject in entry.subjects: - subject_norm = subject.strip().lower() - if subject_norm in self.subject_index: - self.subject_index[subject_norm].discard(memory_id) - - # 从关键词索引移除 - for keyword in entry.keywords: - keyword_norm = keyword.strip().lower() - if keyword_norm in self.keyword_index: - self.keyword_index[keyword_norm].discard(memory_id) - - # 从标签索引移除 - for tag in entry.tags: - tag_norm = tag.strip().lower() - if tag_norm in self.tag_index: - self.tag_index[tag_norm].discard(memory_id) - - def remove(self, memory_id: str): - """移除索引条目""" - with self.lock: - if memory_id in 
self.index: - self._remove_from_inverted_indices(memory_id) - del self.index[memory_id] - - def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]): - """批量添加或更新""" - with self.lock: - for entry in entries: - self.add_or_update(entry) - - def save(self): - """保存索引到磁盘""" - with self.lock: - self._save_index() + self.batch_add_or_update([entry]) def search( self, @@ -201,287 +69,54 @@ class MemoryMetadataIndex: created_before: float | None = None, user_id: str | None = None, limit: int | None = None, - flexible_mode: bool = True, # 新增:灵活匹配模式 + flexible_mode: bool = True, ) -> list[str]: - """ - 搜索符合条件的记忆ID列表(支持模糊匹配) - - Returns: - List[str]: 符合条件的 memory_id 列表 - """ - with self.lock: + params: dict[str, Any] = { + "user_id": user_id, + "memory_types": memory_types, + "subjects": subjects, + "keywords": keywords, + "tags": tags, + "importance_min": importance_min, + "importance_max": importance_max, + "created_after": created_after, + "created_before": created_before, + "limit": limit, + } + params = {k: v for k, v in params.items() if v is not None} + try: if flexible_mode: - return self._search_flexible( - memory_types=memory_types, - subjects=subjects, - keywords=keywords, # 保留用于兼容性 - tags=tags, # 保留用于兼容性 - created_after=created_after, - created_before=created_before, - user_id=user_id, - limit=limit, - ) - else: - return self._search_strict( - memory_types=memory_types, - subjects=subjects, - keywords=keywords, - tags=tags, - importance_min=importance_min, - importance_max=importance_max, - created_after=created_after, - created_before=created_before, - user_id=user_id, - limit=limit, - ) - - def _search_flexible( - self, - memory_types: list[str] | None = None, - subjects: list[str] | None = None, - created_after: float | None = None, - created_before: float | None = None, - user_id: str | None = None, - limit: int | None = None, - **kwargs, # 接受但不使用的参数 - ) -> list[str]: - """ - 灵活搜索模式:2/4项匹配即可,支持部分匹配 - - 评分维度: - 1. 记忆类型匹配 (0-1分) - 2. 主语匹配 (0-1分) - 3. 宾语匹配 (0-1分) - 4. 时间范围匹配 (0-1分) - - 总分 >= 2分即视为有效 - """ - # 用户过滤(必选) - if user_id: - base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id} - else: - base_candidates = set(self.index.keys()) - - scored_candidates = [] - - for memory_id in base_candidates: - entry = self.index[memory_id] - score = 0 - match_details = [] - - # 1. 记忆类型匹配 - if memory_types: - type_score = 0 - for mtype in memory_types: - if entry.memory_type == mtype: - type_score = 1 - break - # 部分匹配:类型名称包含 - if mtype.lower() in entry.memory_type.lower() or entry.memory_type.lower() in mtype.lower(): - type_score = 0.5 - break - score += type_score - if type_score > 0: - match_details.append(f"类型:{entry.memory_type}") - else: - match_details.append("类型:未指定") - - # 2. 主语匹配(支持部分匹配) - if subjects: - subject_score = 0 - for subject in subjects: - subject_norm = subject.strip().lower() - for entry_subject in entry.subjects: - entry_subject_norm = entry_subject.strip().lower() - # 精确匹配 - if subject_norm == entry_subject_norm: - subject_score = 1 - break - # 部分匹配:包含关系 - if subject_norm in entry_subject_norm or entry_subject_norm in subject_norm: - subject_score = 0.6 - break - if subject_score == 1: - break - score += subject_score - if subject_score > 0: - match_details.append("主语:匹配") - else: - match_details.append("主语:未指定") - - # 3. 
宾语匹配(支持部分匹配) - object_score = 0 - if entry.objects: - for entry_object in entry.objects: - entry_object_norm = str(entry_object).strip().lower() - # 检查是否与主语相关(主宾关联) - for subject in subjects or []: - subject_norm = subject.strip().lower() - if subject_norm in entry_object_norm or entry_object_norm in subject_norm: - object_score = 0.8 - match_details.append("宾语:主宾关联") - break - if object_score > 0: - break - - score += object_score - if object_score > 0: - match_details.append("宾语:匹配") - elif not entry.objects: - match_details.append("宾语:无") - - # 4. 时间范围匹配 - time_score = 0 - if created_after is not None or created_before is not None: - time_match = True - if created_after is not None and entry.created_at < created_after: - time_match = False - if created_before is not None and entry.created_at > created_before: - time_match = False - if time_match: - time_score = 1 - match_details.append("时间:匹配") - else: - match_details.append("时间:不匹配") - else: - match_details.append("时间:未指定") - - score += time_score - - # 只有总分 >= 2 的记忆才会被返回 - if score >= 2: - scored_candidates.append((memory_id, score, match_details)) - - # 按分数和时间排序 - scored_candidates.sort(key=lambda x: (x[1], self.index[x[0]].created_at), reverse=True) - - if limit: - result_ids = [mid for mid, _, _ in scored_candidates[:limit]] - else: - result_ids = [mid for mid, _, _ in scored_candidates] - - logger.debug( - f"[灵活搜索] 过滤条件: types={memory_types}, subjects={subjects}, " - f"time_range=[{created_after}, {created_before}], 返回={len(result_ids)}条" - ) - - # 记录匹配统计 - if scored_candidates and len(scored_candidates) > 0: - avg_score = sum(score for _, score, _ in scored_candidates) / len(scored_candidates) - logger.debug(f"[灵活搜索] 平均匹配分数: {avg_score:.2f}, 最高分: {scored_candidates[0][1]:.2f}") - - return result_ids - - def _search_strict( - self, - memory_types: list[str] | None = None, - subjects: list[str] | None = None, - keywords: list[str] | None = None, - tags: list[str] | None = None, - importance_min: int | None = None, - importance_max: int | None = None, - created_after: float | None = None, - created_before: float | None = None, - user_id: str | None = None, - limit: int | None = None, - ) -> list[str]: - """严格搜索模式(原有逻辑)""" - # 初始候选集(所有记忆) - candidate_ids: set[str] | None = None - - # 用户过滤(必选) - if user_id: - candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id} - else: - candidate_ids = set(self.index.keys()) - - # 类型过滤(OR关系) - if memory_types: - type_ids = set() - for mtype in memory_types: - type_ids.update(self.type_index.get(mtype, set())) - candidate_ids &= type_ids - - # 主语过滤(OR关系,支持模糊匹配) - if subjects: - subject_ids = set() - for subject in subjects: - subject_norm = subject.strip().lower() - # 精确匹配 - if subject_norm in self.subject_index: - subject_ids.update(self.subject_index[subject_norm]) - # 模糊匹配(包含) - for indexed_subject, ids in self.subject_index.items(): - if subject_norm in indexed_subject or indexed_subject in subject_norm: - subject_ids.update(ids) - candidate_ids &= subject_ids - - # 关键词过滤(OR关系,支持模糊匹配) - if keywords: - keyword_ids = set() - for keyword in keywords: - keyword_norm = keyword.strip().lower() - # 精确匹配 - if keyword_norm in self.keyword_index: - keyword_ids.update(self.keyword_index[keyword_norm]) - # 模糊匹配(包含) - for indexed_keyword, ids in self.keyword_index.items(): - if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm: - keyword_ids.update(ids) - candidate_ids &= keyword_ids - - # 标签过滤(OR关系) - if tags: - tag_ids = set() - for tag in tags: - tag_norm = 
tag.strip().lower()
-                tag_ids.update(self.tag_index.get(tag_norm, set()))
-            candidate_ids &= tag_ids
-
-        # 重要性过滤
-        if importance_min is not None or importance_max is not None:
-            importance_ids = {
-                mid
-                for mid in candidate_ids
-                if (importance_min is None or self.index[mid].importance >= importance_min)
-                and (importance_max is None or self.index[mid].importance <= importance_max)
-            }
-            candidate_ids &= importance_ids
-
-        # 时间范围过滤
-        if created_after is not None or created_before is not None:
-            time_ids = {
-                mid
-                for mid in candidate_ids
-                if (created_after is None or self.index[mid].created_at >= created_after)
-                and (created_before is None or self.index[mid].created_at <= created_before)
-            }
-            candidate_ids &= time_ids
-
-        # 转换为列表并排序(按创建时间倒序)
-        result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True)
-
-        # 限制数量
-        if limit:
-            result_ids = result_ids[:limit]
-
-        logger.debug(
-            f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}条"
-        )
-
-        return result_ids
+                return list(self._rust.search_flexible(params))
+            return list(self._rust.search_strict(params))
+        except Exception as ex:  # noqa: BLE001
+            logger.error(f"Rust 搜索失败返回空: {ex}")
+            return []
 
     def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
-        """获取单个索引条目"""
         return self.index.get(memory_id)
 
     def get_stats(self) -> dict[str, Any]:
-        """获取索引统计信息"""
-        with self.lock:
+        try:
+            raw = self._rust.stats()
             return {
-                "total_memories": len(self.index),
-                "types": {mtype: len(ids) for mtype, ids in self.type_index.items()},
-                "subjects_count": len(self.subject_index),
-                "keywords_count": len(self.keyword_index),
-                "tags_count": len(self.tag_index),
+                "total_memories": raw.get("total", 0),
+                "types": raw.get("types_dist", {}),
+                "subjects_count": raw.get("subjects_indexed", 0),
+                "keywords_count": raw.get("keywords_indexed", 0),
+                "tags_count": raw.get("tags_indexed", 0),
             }
+        except Exception as ex:  # noqa: BLE001
+            logger.warning(f"读取 Rust stats 失败: {ex}")
+            return {"total_memories": 0}
+
+    def save(self):  # 仅调用 rust save
+        try:
+            self._rust.save()
+        except Exception as ex:  # noqa: BLE001
+            logger.warning(f"Rust save 失败: {ex}")
+
+
+__all__ = [
+    "MemoryMetadataIndexEntry",
+    "MemoryMetadataIndex",
+]
diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py
index 7953ff862..86c32ea94 100644
--- a/src/chat/message_receive/message.py
+++ b/src/chat/message_receive/message.py
@@ -263,7 +263,7 @@ class MessageRecv(Message):
             logger.warning("视频消息中没有base64数据")
             return "[收到视频消息,但数据异常]"
         except Exception as e:
-            logger.error(f"视频处理失败: {e!s}")
+            logger.error(f"视频处理失败: {str(e)}")
             import traceback
 
             logger.error(f"错误详情: {traceback.format_exc()}")
@@ -277,7 +277,7 @@ class MessageRecv(Message):
                 logger.info("未启用视频识别")
                 return "[视频]"
         except Exception as e:
-            logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
+            logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
             return f"[处理失败的{segment.type}消息]"
 
 
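The `MessageRecvS4U` call site keeps `analyze_video_from_bytes`: it is the bytes-based entry point, while `analyze_video` in `utils_video.py` takes a file path and returns a `(bool, str)` tuple, so swapping the call would break the signature. A hedged consumption sketch, using only what this patch shows (the helper name and sample file are illustrative):

```python
import asyncio

from src.chat.utils.utils_video import get_video_analyzer


async def describe_video(video_bytes: bytes, filename: str | None = None) -> str:
    # analyze_video_from_bytes caches by sha256 and returns {"summary": str};
    # per the diff, prompt takes precedence over question when both are given.
    analyzer = get_video_analyzer()
    result = await analyzer.analyze_video_from_bytes(video_bytes, filename, prompt="概述这段视频")
    return result["summary"]


if __name__ == "__main__":
    with open("sample.mp4", "rb") as f:  # hypothetical local file
        print(asyncio.run(describe_video(f.read(), "sample.mp4")))
```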
diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py
index 4f3f4f428..e1aef6a4f 100644
--- a/src/chat/replyer/replyer_manager.py
+++ b/src/chat/replyer/replyer_manager.py
@@ -9,6 +9,6 @@ class ReplyerManager:
     def __init__(self):
         self._repliers: dict[str, DefaultReplyer] = {}
 
-    def get_replyer(
+    async def get_replyer(
         self,
         chat_stream: ChatStream | None = None,
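With `get_replyer` now a coroutine, callers that previously used the return value directly must await it. A minimal migration sketch (the direct `ReplyerManager()` instantiation and the `None` check are assumptions for illustration, not code from this patch):

```python
from src.chat.replyer.replyer_manager import ReplyerManager

manager = ReplyerManager()


async def fetch_replyer(chat_stream):
    # Was: replyer = manager.get_replyer(chat_stream=chat_stream)
    replyer = await manager.get_replyer(chat_stream=chat_stream)
    if replyer is None:  # assumption: no replyer could be resolved for this stream
        raise RuntimeError("no replyer available")
    return replyer
```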
diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py
index 78ea3a11c..fe14e54c5 100644
--- a/src/chat/utils/utils_video.py
+++ b/src/chat/utils/utils_video.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+# -*- coding: utf-8 -*-
 """纯 inkfox 视频关键帧分析工具
 
 仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
@@ -13,27 +14,25 @@
 
 from __future__ import annotations
 
+import os
+import io
 import asyncio
 import base64
-import hashlib
-import io
-import os
 import tempfile
-import time
 from pathlib import Path
-from typing import Any
+from typing import List, Tuple, Optional, Dict, Any
+import hashlib
+import time
 
 from PIL import Image
-from sqlalchemy import exc as sa_exc  # type: ignore
-from sqlalchemy import insert, select, update  # type: ignore
 
-from src.common.database.sqlalchemy_models import Videos, get_db_session  # type: ignore
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
+from src.common.database.sqlalchemy_models import Videos, get_db_session  # type: ignore
 
 # 简易并发控制:同一 hash 只处理一次
-_video_locks: dict[str, asyncio.Lock] = {}
+_video_locks: Dict[str, asyncio.Lock] = {}
 _locks_guard = asyncio.Lock()
 
 logger = get_logger("utils_video")
@@ -91,7 +90,7 @@ class VideoAnalyzer:
             logger.debug(f"获取系统信息失败: {e}")
 
     # ---- 关键帧提取 ----
-    async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
+    async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
         """提取关键帧并返回 (base64, timestamp_seconds) 列表"""
         with tempfile.TemporaryDirectory() as tmp:
             result = video.extract_keyframes_from_video(  # type: ignore[attr-defined]
@@ -106,7 +105,7 @@
             )
             files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
             total_ms = getattr(result, "total_time_ms", 0)
-            frames: list[tuple[str, float]] = []
+            frames: List[Tuple[str, float]] = []
             for i, f in enumerate(files):
                 img = Image.open(f).convert("RGB")
                 if max(img.size) > self.max_image_size:
@@ -120,41 +119,38 @@
             return frames
 
     # ---- 批量分析 ----
-    async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
-        from src.llm_models.payload_content.message import MessageBuilder
+    async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
+        from src.llm_models.payload_content.message import MessageBuilder, RoleType
         from src.llm_models.utils_model import RequestType
-
         prompt = self.batch_analysis_prompt.format(
             personality_core=self.personality_core, personality_side=self.personality_side
         )
         if question:
             prompt += f"\n用户关注: {question}"
-
         desc = [
             (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
             for i, (_b, ts) in enumerate(frames)
         ]
         prompt += "\n帧列表: " + ", ".join(desc)
-
-        message_builder = MessageBuilder().add_text_content(prompt)
+        mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
         for b64, _ in frames:
-            message_builder.add_image_content(image_format="jpeg", image_base64=b64)
-        messages = [message_builder.build()]
-
-        # 使用封装好的高级策略执行请求,而不是直接调用内部方法
-        response, _ = await self.video_llm._strategy.execute_with_failover(
-            RequestType.RESPONSE,
-            raise_when_empty=False,  # 即使失败也返回默认值,避免程序崩溃
-            message_list=messages,
-            temperature=self.video_llm.model_for_task.temperature,
-            max_tokens=self.video_llm.model_for_task.max_tokens,
+            mb.add_image_content("jpeg", b64)
+        message = mb.build()
+        model_info, api_provider, client = self.video_llm._select_model()
+        resp = await self.video_llm._execute_request(
+            api_provider=api_provider,
+            client=client,
+            request_type=RequestType.RESPONSE,
+            model_info=model_info,
+            message_list=[message],
+            temperature=None,
+            max_tokens=None,
         )
-
-        return response.content or "❌ 未获得响应"
+        return resp.content or "❌ 未获得响应"
 
     # ---- 逐帧分析 ----
-    async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
-        results: list[str] = []
+    async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
+        results: List[str] = []
         for i, (b64, ts) in enumerate(frames):
             prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
             if question:
@@ -178,7 +174,7 @@
         return "\n".join(results)
 
     # ---- 主入口 ----
-    async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
+    async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
         if not os.path.exists(video_path):
             return False, "❌ 文件不存在"
         frames = await self.extract_keyframes(video_path)
@@ -193,10 +189,10 @@
     async def analyze_video_from_bytes(
         self,
         video_bytes: bytes,
-        filename: str | None = None,
-        prompt: str | None = None,
-        question: str | None = None,
-    ) -> dict[str, str]:
+        filename: Optional[str] = None,
+        prompt: Optional[str] = None,
+        question: Optional[str] = None,
+    ) -> Dict[str, str]:
         """从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
         if not video_bytes:
             return {"summary": "❌ 空视频数据"}
 
         q = prompt if prompt is not None else question
         video_hash = hashlib.sha256(video_bytes).hexdigest()
 
-        # 查缓存(第一次,未加锁)
-        cached = await self._get_cached(video_hash)
-        if cached:
-            logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}")
-            return {"summary": cached}
+        # 查缓存
+        try:
+            async with get_db_session() as session:  # type: ignore
+                row = await session.execute(
+                    Videos.__table__.select().where(Videos.video_hash == video_hash)  # type: ignore
+                )
+                existing = row.first()
+                if existing and existing.description and existing.vlm_processed:  # type: ignore
+                    return {"summary": existing.description}  # type: ignore
+        except Exception:  # pragma: no cover
+            pass
 
         # 获取锁避免重复处理
         async with _locks_guard:
@@ -217,11 +219,17 @@
                 lock = asyncio.Lock()
                 _video_locks[video_hash] = lock
         async with lock:
-            # 双检缓存
-            cached2 = await self._get_cached(video_hash)
-            if cached2:
-                logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}")
-                return {"summary": cached2}
+            # 双检:进入锁后再查一次,避免重复处理
+            try:
+                async with get_db_session() as session:  # type: ignore
+                    row = await session.execute(
+                        Videos.__table__.select().where(Videos.video_hash == video_hash)  # type: ignore
+                    )
+                    existing = row.first()
+                    if existing and existing.description and existing.vlm_processed:  # type: ignore
+                        return {"summary": existing.description}  # type: ignore
+            except Exception:  # pragma: no cover
+                pass
 
             try:
                 with tempfile.NamedTemporaryFile(delete=False) as fp:
@@ -231,7 +239,26 @@
                 ok, summary = await self.analyze_video(temp_path, q)
                 # 写入缓存(仅成功)
                 if ok:
-                    await self._save_cache(video_hash, summary, len(video_bytes))
+                    try:
+                        async with get_db_session() as session:  # type: ignore
+                            await session.execute(
+                                Videos.__table__.insert().values(
+                                    video_id="",
+                                    video_hash=video_hash,
+                                    description=summary,
+                                    count=1,
+                                    timestamp=time.time(),
+                                    vlm_processed=True,
+                                    duration=None,
+                                    frame_count=None,
+                                    fps=None,
+                                    resolution=None,
+                                    file_size=len(video_bytes),
+                                )
+                            )
+                            await session.commit()
+                    except Exception:  # pragma: no cover
+                        pass
                 return {"summary": summary}
             finally:
                 if os.path.exists(temp_path):
@@ -242,57 +269,9 @@
         except Exception as e:  # pragma: no cover
             return {"summary": f"❌ 处理失败: {e}"}
 
-    # ---- 缓存辅助 ----
-    async def _get_cached(self, video_hash: str) -> str | None:
-        try:
-            async with get_db_session() as session:  # type: ignore
-                result = await session.execute(select(Videos).where(Videos.video_hash == video_hash))  # type: ignore
-                obj: Videos | None = result.scalar_one_or_none()  # type: ignore
-                if obj and obj.vlm_processed and obj.description:
-                    # 更新使用次数
-                    try:
-                        await session.execute(
-                            update(Videos)
-                            .where(Videos.id == obj.id)  # type: ignore
-                            .values(count=obj.count + 1 if obj.count is not None else 1)
-                        )
-                        await session.commit()
-                    except Exception:  # pragma: no cover
-                        await session.rollback()
-                    return obj.description
-        except Exception:  # pragma: no cover
-            pass
-        return None
-
-    async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
-        try:
-            async with get_db_session() as session:  # type: ignore
-                stmt = insert(Videos).values(  # type: ignore
-                    video_id="",
-                    video_hash=video_hash,
-                    description=summary,
-                    count=1,
-                    timestamp=time.time(),
-                    vlm_processed=True,
-                    duration=None,
-                    frame_count=None,
-                    fps=None,
-                    resolution=None,
-                    file_size=file_size,
-                )
-                try:
-                    await session.execute(stmt)
-                    await session.commit()
-                    logger.debug(f"视频缓存写入 success hash={video_hash}")
-                except sa_exc.IntegrityError:  # 可能并发已写入
-                    await session.rollback()
-                    logger.debug(f"视频缓存已存在 hash={video_hash}")
-        except Exception:  # pragma: no cover
-            logger.debug("视频缓存写入失败")
-
 
 # ---- 外部接口 ----
-_INSTANCE: VideoAnalyzer | None = None
+_INSTANCE: Optional[VideoAnalyzer] = None
 
 
 def get_video_analyzer() -> VideoAnalyzer:
@@ -306,7 +285,7 @@
 def is_video_analysis_available() -> bool:
     return True
 
 
-def get_video_analysis_status() -> dict[str, Any]:
+def get_video_analysis_status() -> Dict[str, Any]:
     try:
         info = video.get_system_info()  # type: ignore[attr-defined]
     except Exception as e:  # pragma: no cover
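The bytes entry point above relies on per-hash double-checked locking: an unlocked cache probe, a lock registry guarded by `_locks_guard`, and a second probe once the per-hash lock is held. A self-contained sketch of that pattern (generic stand-ins, not the project's DB-backed cache):

```python
import asyncio
import hashlib

_results: dict[str, str] = {}
_locks: dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()


async def analyze_once(data: bytes) -> str:
    key = hashlib.sha256(data).hexdigest()
    if key in _results:  # fast path: cache hit without taking any lock
        return _results[key]
    async with _locks_guard:  # serialize creation of the per-key lock
        lock = _locks.setdefault(key, asyncio.Lock())
    async with lock:
        if key in _results:  # double check: another task may have finished first
            return _results[key]
        result = f"analyzed {len(data)} bytes"  # stand-in for the expensive VLM call
        _results[key] = result
        return result
```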
diff --git a/src/common/data_models/llm_data_model.py b/src/common/data_models/llm_data_model.py
index 147c2b22b..95fd41520 100644
--- a/src/common/data_models/llm_data_model.py
+++ b/src/common/data_models/llm_data_model.py
@@ -1,6 +1,8 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any
 
+from src.llm_models.payload_content.tool_option import ToolCall
+
 from . import BaseDataModel
 
 if TYPE_CHECKING:
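`ToolCall` lands at module level rather than inside the `if TYPE_CHECKING:` block, presumably because the name is needed at runtime: without `from __future__ import annotations`, dataclass field annotations are evaluated when the class body executes. A minimal illustration of that constraint, with a hypothetical model shape (not the real data model in this file):

```python
from dataclasses import dataclass, field

from src.llm_models.payload_content.tool_option import ToolCall  # runtime import


@dataclass
class ToolAwareResponse:  # hypothetical shape for illustration only
    content: str | None = None
    # list[ToolCall] is evaluated eagerly here, so ToolCall must exist at runtime;
    # a TYPE_CHECKING-only import would raise NameError at class creation.
    tool_calls: list[ToolCall] = field(default_factory=list)
```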
diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py
index 05f388c2a..458ab1572 100644
--- a/src/plugin_system/apis/generator_api.py
+++ b/src/plugin_system/apis/generator_api.py
@@ -4,7 +4,7 @@
 提供回复器相关功能,采用标准Python包设计模式
 使用方式:
     from src.plugin_system.apis import generator_api
-    replyer = generator_api.get_replyer(chat_stream)
+    replyer = await generator_api.get_replyer(chat_stream)
     success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
 """
 
@@ -32,6 +32,6 @@ logger = get_logger("generator_api")
 # =============================================================================
 
 
-def get_replyer(
+async def get_replyer(
     chat_stream: ChatStream | None = None,
diff --git a/src/plugin_system/base/base_event.py b/src/plugin_system/base/base_event.py
index f8c45e54d..582731aa0 100644
--- a/src/plugin_system/base/base_event.py
+++ b/src/plugin_system/base/base_event.py
@@ -2,6 +2,7 @@
 import asyncio
 from typing import Any
 
 from src.common.logger import get_logger
+from src.plugin_system.base.base_events_handler import BaseEventHandler
 
 logger = get_logger("base_event")
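Taken together, the generator_api changes make both documented entry points awaitable, so the flow from its module docstring now reads end to end like this (a sketch; the `None` check is an assumption, and `action_data`/`reasoning` are whatever the caller already passes):

```python
from src.plugin_system.apis import generator_api


async def respond(chat_stream, action_data: dict, reasoning: str):
    # Both calls are awaited now; shapes follow the module docstring above.
    replyer = await generator_api.get_replyer(chat_stream)
    if replyer is None:  # assumption: resolution may fail and return None
        return None
    success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
    return reply_set if success else None
```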