Windpicker-owo
2025-10-05 18:31:56 +08:00
10 changed files with 284 additions and 544 deletions

View File

@@ -74,7 +74,7 @@ dependencies = [
"websockets>=15.0.1",
"aiomysql>=0.2.0",
"aiosqlite>=0.21.0",
"inkfox>=0.1.0",
"inkfox>=0.1.1",
"rrjieba>=0.1.13",
"mcp>=0.9.0",
"sse-starlette>=2.2.1",

View File

@@ -30,6 +30,7 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system,
# Vector DB storage system
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
from .memory_formatter import format_memories_bracket_style
__all__ = [
# Core data structures
@@ -62,6 +63,8 @@ __all__ = [
"MemoryActivator",
"memory_activator",
"enhanced_memory_activator", # 兼容性别名
# 格式化工具
"format_memories_bracket_style",
]
# Version info

View File

@@ -0,0 +1,118 @@
"""记忆格式化工具
提供统一的记忆块格式化函数,供构建 Prompt 时使用。
当前使用的函数: format_memories_bracket_style
输入: list[dict] 其中每个元素包含:
- display: str 记忆可读内容
- memory_type: str 记忆类型 (personal_fact/opinion/preference/event 等)
- metadata: dict 可选,包括
- confidence: 置信度 (str|float)
- importance: 重要度 (str|float)
- timestamp: 时间戳 (float|str)
- source: 来源 (str)
- relevance_score: 相关度 (float)
返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
"""
from __future__ import annotations
from typing import Any, Iterable
import time
def _format_timestamp(ts: Any) -> str:
try:
if ts in (None, ""):
return ""
if isinstance(ts, (int, float)) and ts > 0:
return time.strftime("%Y-%m-%d %H:%M", time.localtime(float(ts)))
return str(ts)
except Exception:
return ""
def _coerce_str(v: Any) -> str:
if v is None:
return ""
return str(v)
def format_memories_bracket_style(
memories: Iterable[dict[str, Any]] | None,
query_context: str | None = None,
max_items: int = 15,
) -> str:
"""以方括号 + 标注字段的方式格式化记忆列表。
例子输出:
## 相关记忆回顾
- [类型:personal_fact|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (来源: chat, 2025-10-05 09:30)
Args:
memories: 记忆字典迭代器
query_context: 当前查询/用户的消息,用于在首行提示(可选)
max_items: 最多输出的记忆条数
Returns:
str: 格式化文本;若无内容返回空串
"""
if not memories:
return ""
lines: list[str] = ["## Relevant memory recap"]
if query_context:
lines.append(f"(relevant to the current message: {query_context[:60]}{'...' if len(query_context) > 60 else ''})")
lines.append("")
count = 0
for mem in memories:
if count >= max_items:
break
if not isinstance(mem, dict):
continue
display = _coerce_str(mem.get("display", "")).strip()
if not display:
continue
mtype = _coerce_str(mem.get("memory_type", "fact")) or "fact"
meta = mem.get("metadata", {}) if isinstance(mem.get("metadata"), dict) else {}
confidence = _coerce_str(meta.get("confidence", ""))
importance = _coerce_str(meta.get("importance", ""))
source = _coerce_str(meta.get("source", ""))
rel = meta.get("relevance_score")
try:
rel_str = f"{float(rel):.2f}" if rel is not None else ""
except Exception:
rel_str = ""
ts = _format_timestamp(meta.get("timestamp"))
# Build the tag block
tags: list[str] = [f"type:{mtype}"]
if importance:
tags.append(f"importance:{importance}")
if confidence:
tags.append(f"confidence:{confidence}")
if rel_str:
tags.append(f"relevance:{rel_str}")
tag_block = "|".join(tags)
suffix_parts = []
if source:
suffix_parts.append(source)
if ts:
suffix_parts.append(ts)
suffix = (" (" + ", ".join(suffix_parts) + ")") if suffix_parts else ""
lines.append(f"- [{tag_block}] {display}{suffix}")
count += 1
if count == 0:
return ""
if count >= max_items:
lines.append(f"\n(已截断,仅显示前 {max_items} 条相关记忆)")
return "\n".join(lines)
__all__ = ["format_memories_bracket_style"]

View File

@@ -1,193 +1,61 @@
"""
Memory metadata index manager
Stores memory metadata in a JSON file; supports fast fuzzy search and filtering
Memory metadata index
"""
import threading
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Any
import orjson
from time import time
from src.common.logger import get_logger
logger = get_logger(__name__)
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
@dataclass
class MemoryMetadataIndexEntry:
"""元数据索引条目(轻量级,只用于快速过滤)"""
memory_id: str
user_id: str
# Classification info
memory_type: str # MemoryType.value
subjects: list[str] # subject list
objects: list[str] # object list
keywords: list[str] # keyword list
tags: list[str] # tag list
# Numeric fields (for range filtering)
importance: int # ImportanceLevel.value (1-4)
confidence: int # ConfidenceLevel.value (1-4)
created_at: float # creation timestamp
access_count: int # access count
# Optional fields
memory_type: str
subjects: list[str]
objects: list[str]
keywords: list[str]
tags: list[str]
importance: int
confidence: int
created_at: float
access_count: int
chat_id: str | None = None
content_preview: str | None = None # content preview: first 100 characters
content_preview: str | None = None
class MemoryMetadataIndex:
"""记忆元数据索引管理器"""
"""Rust 加速版本唯一实现。"""
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self.index_file = Path(index_file)
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
self._rust = _RustIndex(index_file)
# Minimal cache kept only for the vector layer and debugging (length checks, get_entry returns)
self.index: dict[str, MemoryMetadataIndexEntry] = {}
logger.info("✅ MemoryMetadataIndex (Rust) 初始化完成,仅支持加速实现")
# 倒排索引(用于快速查找)
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
self.lock = threading.RLock()
# 加载已有索引
self._load_index()
def _load_index(self):
"""从文件加载索引"""
if not self.index_file.exists():
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
# 向后代码仍调用的接口batch_add_or_update / add_or_update
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
if not entries:
return
try:
with open(self.index_file, "rb") as f:
data = orjson.loads(f.read())
# Rebuild the in-memory index
for entry_data in data.get("entries", []):
entry = MemoryMetadataIndexEntry(**entry_data)
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)
logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")
except Exception as e:
logger.error(f"加载元数据索引失败: {e}", exc_info=True)
def _save_index(self):
"""保存索引到文件"""
try:
# 确保目录存在
self.index_file.parent.mkdir(parents=True, exist_ok=True)
# 序列化所有条目
entries = [asdict(entry) for entry in self.index.values()]
data = {
"version": "1.0",
"count": len(entries),
"last_updated": datetime.now().isoformat(),
"entries": entries,
}
# Write to file (temp file + atomic rename)
temp_file = self.index_file.with_suffix(".tmp")
with open(temp_file, "wb") as f:
f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
temp_file.replace(self.index_file)
logger.debug(f"元数据索引已保存: {len(entries)} 条记录")
except Exception as e:
logger.error(f"保存元数据索引失败: {e}", exc_info=True)
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
"""更新倒排索引"""
memory_id = entry.memory_id
# 类型索引
self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
# 主语索引
for subject in entry.subjects:
subject_norm = subject.strip().lower()
if subject_norm:
self.subject_index.setdefault(subject_norm, set()).add(memory_id)
# Keyword index
for keyword in entry.keywords:
keyword_norm = keyword.strip().lower()
if keyword_norm:
self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
# Tag index
for tag in entry.tags:
tag_norm = tag.strip().lower()
if tag_norm:
self.tag_index.setdefault(tag_norm, set()).add(memory_id)
payload = []
for e in entries:
if not e.memory_id:
continue
self.index[e.memory_id] = e
payload.append(asdict(e))
if payload:
try:
self._rust.batch_add(payload)
except Exception as ex: # noqa: BLE001
logger.error(f"Rust 元数据批量添加失败: {ex}")
def add_or_update(self, entry: MemoryMetadataIndexEntry):
"""添加或更新索引条目"""
with self.lock:
# 如果已存在,先从倒排索引中移除旧记录
if entry.memory_id in self.index:
self._remove_from_inverted_indices(entry.memory_id)
# Add the new record
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)
def _remove_from_inverted_indices(self, memory_id: str):
"""从倒排索引中移除记录"""
if memory_id not in self.index:
return
entry = self.index[memory_id]
# 从类型索引移除
if entry.memory_type in self.type_index:
self.type_index[entry.memory_type].discard(memory_id)
# Remove from the subject index
for subject in entry.subjects:
subject_norm = subject.strip().lower()
if subject_norm in self.subject_index:
self.subject_index[subject_norm].discard(memory_id)
# Remove from the keyword index
for keyword in entry.keywords:
keyword_norm = keyword.strip().lower()
if keyword_norm in self.keyword_index:
self.keyword_index[keyword_norm].discard(memory_id)
# Remove from the tag index
for tag in entry.tags:
tag_norm = tag.strip().lower()
if tag_norm in self.tag_index:
self.tag_index[tag_norm].discard(memory_id)
def remove(self, memory_id: str):
"""移除索引条目"""
with self.lock:
if memory_id in self.index:
self._remove_from_inverted_indices(memory_id)
del self.index[memory_id]
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
"""批量添加或更新"""
with self.lock:
for entry in entries:
self.add_or_update(entry)
def save(self):
"""保存索引到磁盘"""
with self.lock:
self._save_index()
self.batch_add_or_update([entry])
def search(
self,
@@ -201,287 +69,54 @@ class MemoryMetadataIndex:
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
flexible_mode: bool = True, # new: flexible matching mode
flexible_mode: bool = True,
) -> list[str]:
"""
Search for the list of memory IDs matching the given criteria; supports fuzzy matching
Returns:
List[str]: list of matching memory_ids
"""
with self.lock:
params: dict[str, Any] = {
"user_id": user_id,
"memory_types": memory_types,
"subjects": subjects,
"keywords": keywords,
"tags": tags,
"importance_min": importance_min,
"importance_max": importance_max,
"created_after": created_after,
"created_before": created_before,
"limit": limit,
}
params = {k: v for k, v in params.items() if v is not None}
try:
if flexible_mode:
return self._search_flexible(
memory_types=memory_types,
subjects=subjects,
keywords=keywords, # kept for compatibility
tags=tags, # kept for compatibility
created_after=created_after,
created_before=created_before,
user_id=user_id,
limit=limit,
)
else:
return self._search_strict(
memory_types=memory_types,
subjects=subjects,
keywords=keywords,
tags=tags,
importance_min=importance_min,
importance_max=importance_max,
created_after=created_after,
created_before=created_before,
user_id=user_id,
limit=limit,
)
def _search_flexible(
self,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
**kwargs, # accepted but unused parameters
) -> list[str]:
"""
Flexible search mode: matching 2 of 4 criteria is enough; supports partial matches
Scoring dimensions:
1. Memory type match (0-1 points)
2. Subject match (0-1 points)
3. Object match (0-1 points)
4. Time range match (0-1 points)
A total score >= 2 counts as a valid match
"""
# User filter (required)
if user_id:
base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
else:
base_candidates = set(self.index.keys())
scored_candidates = []
for memory_id in base_candidates:
entry = self.index[memory_id]
score = 0
match_details = []
# 1. Memory type match
if memory_types:
type_score = 0
for mtype in memory_types:
if entry.memory_type == mtype:
type_score = 1
break
# Partial match: type name containment
if mtype.lower() in entry.memory_type.lower() or entry.memory_type.lower() in mtype.lower():
type_score = 0.5
break
score += type_score
if type_score > 0:
match_details.append(f"类型:{entry.memory_type}")
else:
match_details.append("类型:未指定")
# 2. 主语匹配(支持部分匹配)
if subjects:
subject_score = 0
for subject in subjects:
subject_norm = subject.strip().lower()
for entry_subject in entry.subjects:
entry_subject_norm = entry_subject.strip().lower()
# Exact match
if subject_norm == entry_subject_norm:
subject_score = 1
break
# Partial match: containment
if subject_norm in entry_subject_norm or entry_subject_norm in subject_norm:
subject_score = 0.6
break
if subject_score == 1:
break
score += subject_score
if subject_score > 0:
match_details.append("主语:匹配")
else:
match_details.append("主语:未指定")
# 3. 宾语匹配(支持部分匹配)
object_score = 0
if entry.objects:
for entry_object in entry.objects:
entry_object_norm = str(entry_object).strip().lower()
# Check whether it relates to a subject (subject-object association)
for subject in subjects or []:
subject_norm = subject.strip().lower()
if subject_norm in entry_object_norm or entry_object_norm in subject_norm:
object_score = 0.8
match_details.append("宾语:主宾关联")
break
if object_score > 0:
break
score += object_score
if object_score > 0:
match_details.append("宾语:匹配")
elif not entry.objects:
match_details.append("宾语:无")
# 4. 时间范围匹配
time_score = 0
if created_after is not None or created_before is not None:
time_match = True
if created_after is not None and entry.created_at < created_after:
time_match = False
if created_before is not None and entry.created_at > created_before:
time_match = False
if time_match:
time_score = 1
match_details.append("时间:匹配")
else:
match_details.append("时间:不匹配")
else:
match_details.append("时间:未指定")
score += time_score
# 只有总分 >= 2 的记忆才会被返回
if score >= 2:
scored_candidates.append((memory_id, score, match_details))
# Sort by score and time
scored_candidates.sort(key=lambda x: (x[1], self.index[x[0]].created_at), reverse=True)
if limit:
result_ids = [mid for mid, _, _ in scored_candidates[:limit]]
else:
result_ids = [mid for mid, _, _ in scored_candidates]
logger.debug(
f"[flexible search] filters: types={memory_types}, subjects={subjects}, "
f"time_range=[{created_after}, {created_before}], returned={len(result_ids)}"
)
# Log match statistics
if scored_candidates:
avg_score = sum(score for _, score, _ in scored_candidates) / len(scored_candidates)
logger.debug(f"[flexible search] average match score: {avg_score:.2f}, top score: {scored_candidates[0][1]:.2f}")
return result_ids
def _search_strict(
self,
memory_types: list[str] | None = None,
subjects: list[str] | None = None,
keywords: list[str] | None = None,
tags: list[str] | None = None,
importance_min: int | None = None,
importance_max: int | None = None,
created_after: float | None = None,
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[str]:
"""严格搜索模式(原有逻辑)"""
# 初始候选集(所有记忆)
candidate_ids: set[str] | None = None
# 用户过滤(必选)
if user_id:
candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
else:
candidate_ids = set(self.index.keys())
# Type filter (OR semantics)
if memory_types:
type_ids = set()
for mtype in memory_types:
type_ids.update(self.type_index.get(mtype, set()))
candidate_ids &= type_ids
# Subject filter (OR semantics), supports fuzzy matching
if subjects:
subject_ids = set()
for subject in subjects:
subject_norm = subject.strip().lower()
# Exact match
if subject_norm in self.subject_index:
subject_ids.update(self.subject_index[subject_norm])
# Fuzzy match (containment)
for indexed_subject, ids in self.subject_index.items():
if subject_norm in indexed_subject or indexed_subject in subject_norm:
subject_ids.update(ids)
candidate_ids &= subject_ids
# Keyword filter (OR semantics), supports fuzzy matching
if keywords:
keyword_ids = set()
for keyword in keywords:
keyword_norm = keyword.strip().lower()
# Exact match
if keyword_norm in self.keyword_index:
keyword_ids.update(self.keyword_index[keyword_norm])
# Fuzzy match (containment)
for indexed_keyword, ids in self.keyword_index.items():
if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm:
keyword_ids.update(ids)
candidate_ids &= keyword_ids
# Tag filter (OR semantics)
if tags:
tag_ids = set()
for tag in tags:
tag_norm = tag.strip().lower()
tag_ids.update(self.tag_index.get(tag_norm, set()))
candidate_ids &= tag_ids
# Importance filter
if importance_min is not None or importance_max is not None:
importance_ids = {
mid
for mid in candidate_ids
if (importance_min is None or self.index[mid].importance >= importance_min)
and (importance_max is None or self.index[mid].importance <= importance_max)
}
candidate_ids &= importance_ids
# Time range filter
if created_after is not None or created_before is not None:
time_ids = {
mid
for mid in candidate_ids
if (created_after is None or self.index[mid].created_at >= created_after)
and (created_before is None or self.index[mid].created_at <= created_before)
}
candidate_ids &= time_ids
# Convert to a list and sort (newest first by creation time)
result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True)
# Limit the count
if limit:
result_ids = result_ids[:limit]
logger.debug(
f"[strict search] types={memory_types}, subjects={subjects}, keywords={keywords}, returned={len(result_ids)}"
)
return result_ids
return list(self._rust.search_flexible(params))
return list(self._rust.search_strict(params))
except Exception as ex: # noqa: BLE001
logger.error(f"Rust 搜索失败返回空: {ex}")
return []
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
"""获取单个索引条目"""
return self.index.get(memory_id)
def get_stats(self) -> dict[str, Any]:
"""获取索引统计信息"""
with self.lock:
try:
raw = self._rust.stats()
return {
"total_memories": len(self.index),
"types": {mtype: len(ids) for mtype, ids in self.type_index.items()},
"subjects_count": len(self.subject_index),
"keywords_count": len(self.keyword_index),
"tags_count": len(self.tag_index),
"total_memories": raw.get("total", 0),
"types": raw.get("types_dist", {}),
"subjects_count": raw.get("subjects_indexed", 0),
"keywords_count": raw.get("keywords_indexed", 0),
"tags_count": raw.get("tags_indexed", 0),
}
except Exception as ex: # noqa: BLE001
logger.warning(f"读取 Rust stats 失败: {ex}")
return {"total_memories": 0}
def save(self): # delegates to the Rust-side save only
try:
self._rust.save()
except Exception as ex: # noqa: BLE001
logger.warning(f"Rust save 失败: {ex}")
__all__ = [
"MemoryMetadataIndexEntry",
"MemoryMetadataIndex",
]
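A minimal usage sketch of the Rust-backed index (the import path and field values below are hypothetical):

from time import time
from memory_metadata_index import MemoryMetadataIndex, MemoryMetadataIndexEntry  # adjust to the actual package path

idx = MemoryMetadataIndex("data/memory_metadata_index.json")
entry = MemoryMetadataIndexEntry(
    memory_id="mem-001",
    user_id="u42",
    memory_type="personal_fact",
    subjects=["he"],
    objects=["black coffee"],
    keywords=["coffee"],
    tags=["preference"],
    importance=3,
    confidence=3,
    created_at=time(),
    access_count=0,
)
idx.batch_add_or_update([entry])  # mirrors entries into the in-memory cache and the Rust index
ids = idx.search(user_id="u42", memory_types=["personal_fact"], subjects=["he"], limit=10)  # flexible_mode=True by default
idx.save()  # persistence is delegated to the Rust side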

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {e!s}")
logger.error(f"视频处理失败: {str(e)}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
@@ -277,7 +277,7 @@ class MessageRecv(Message):
logger.info("未启用视频识别")
return "[视频]"
except Exception as e:
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@@ -427,7 +427,7 @@ class MessageRecvS4U(MessageRecv):
# Analyze the video with the video analyzer
video_analyzer = get_video_analyzer()
result = await video_analyzer.analyze_video_from_bytes(
result = await video_analyzer.analyze_video(
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
)

View File

@@ -9,6 +9,7 @@ class ReplyerManager:
def __init__(self):
self._repliers: dict[str, DefaultReplyer] = {}
def get_replyer(
async def get_replyer(
self,
chat_stream: ChatStream | None = None,

View File

@@ -1,4 +1,5 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""纯 inkfox 视频关键帧分析工具
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
@@ -13,27 +14,25 @@
from __future__ import annotations
import os
import io
import asyncio
import base64
import hashlib
import io
import os
import tempfile
import time
from pathlib import Path
from typing import Any
from typing import List, Tuple, Optional, Dict, Any
import hashlib
import time
from PIL import Image
from sqlalchemy import exc as sa_exc # type: ignore
from sqlalchemy import insert, select, update # type: ignore
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
# Simple concurrency control: process each hash only once
_video_locks: dict[str, asyncio.Lock] = {}
_video_locks: Dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()
logger = get_logger("utils_video")
@@ -91,7 +90,7 @@ class VideoAnalyzer:
logger.debug(f"获取系统信息失败: {e}")
# ---- 关键帧提取 ----
async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
with tempfile.TemporaryDirectory() as tmp:
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
@@ -106,7 +105,7 @@ class VideoAnalyzer:
)
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
total_ms = getattr(result, "total_time_ms", 0)
frames: list[tuple[str, float]] = []
frames: List[Tuple[str, float]] = []
for i, f in enumerate(files):
img = Image.open(f).convert("RGB")
if max(img.size) > self.max_image_size:
@@ -120,41 +119,38 @@ class VideoAnalyzer:
return frames
# ---- Batch analysis ----
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
from src.llm_models.payload_content.message import MessageBuilder
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
prompt = self.batch_analysis_prompt.format(
personality_core=self.personality_core, personality_side=self.personality_side
)
if question:
prompt += f"\n用户关注: {question}"
desc = [
(f"{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"{i+1}")
for i, (_b, ts) in enumerate(frames)
]
prompt += "\n帧列表: " + ", ".join(desc)
message_builder = MessageBuilder().add_text_content(prompt)
mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
for b64, _ in frames:
message_builder.add_image_content(image_format="jpeg", image_base64=b64)
messages = [message_builder.build()]
# Execute the request through the wrapped high-level strategy instead of calling internal methods directly
response, _ = await self.video_llm._strategy.execute_with_failover(
RequestType.RESPONSE,
raise_when_empty=False, # return a default value even on failure so the program does not crash
message_list=messages,
temperature=self.video_llm.model_for_task.temperature,
max_tokens=self.video_llm.model_for_task.max_tokens,
mb.add_image_content("jpeg", b64)
message = mb.build()
model_info, api_provider, client = self.video_llm._select_model()
resp = await self.video_llm._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=[message],
temperature=None,
max_tokens=None,
)
return response.content or "❌ No response received"
return resp.content or "❌ No response received"
# ---- Sequential frame-by-frame analysis ----
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
results: list[str] = []
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
results: List[str] = []
for i, (b64, ts) in enumerate(frames):
prompt = f"分析第{i+1}" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
if question:
@@ -178,7 +174,7 @@ class VideoAnalyzer:
return "\n".join(results)
# ---- Main entry point ----
async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
if not os.path.exists(video_path):
return False, "❌ File does not exist"
frames = await self.extract_keyframes(video_path)
@@ -193,10 +189,10 @@ class VideoAnalyzer:
async def analyze_video_from_bytes(
self,
video_bytes: bytes,
filename: str | None = None,
prompt: str | None = None,
question: str | None = None,
) -> dict[str, str]:
filename: Optional[str] = None,
prompt: Optional[str] = None,
question: Optional[str] = None,
) -> Dict[str, str]:
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
if not video_bytes:
return {"summary": "❌ 空视频数据"}
@@ -204,11 +200,17 @@ class VideoAnalyzer:
q = prompt if prompt is not None else question
video_hash = hashlib.sha256(video_bytes).hexdigest()
# Check the cache (first pass, before locking)
cached = await self._get_cached(video_hash)
if cached:
logger.info(f"Video cache hit (pre-check) hash={video_hash[:16]}")
return {"summary": cached}
# Check the cache
try:
async with get_db_session() as session: # type: ignore
row = await session.execute(
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
)
existing = row.first()
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
return {"summary": existing[Videos.description]} # type: ignore
except Exception: # pragma: no cover
pass
# Acquire the lock to avoid duplicate processing
async with _locks_guard:
@@ -217,11 +219,17 @@ class VideoAnalyzer:
lock = asyncio.Lock()
_video_locks[video_hash] = lock
async with lock:
# Double-checked cache
cached2 = await self._get_cached(video_hash)
if cached2:
logger.info(f"Video cache hit (after locking) hash={video_hash[:16]}")
return {"summary": cached2}
# Double check: query again after acquiring the lock to avoid duplicate processing
try:
async with get_db_session() as session: # type: ignore
row = await session.execute(
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
)
existing = row.first()
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
return {"summary": existing[Videos.description]} # type: ignore
except Exception: # pragma: no cover
pass
try:
with tempfile.NamedTemporaryFile(delete=False) as fp:
@@ -231,7 +239,26 @@ class VideoAnalyzer:
ok, summary = await self.analyze_video(temp_path, q)
# Write to cache (only on success)
if ok:
await self._save_cache(video_hash, summary, len(video_bytes))
try:
async with get_db_session() as session: # type: ignore
await session.execute(
Videos.__table__.insert().values(
video_id="",
video_hash=video_hash,
description=summary,
count=1,
timestamp=time.time(),
vlm_processed=True,
duration=None,
frame_count=None,
fps=None,
resolution=None,
file_size=len(video_bytes),
)
)
await session.commit()
except Exception: # pragma: no cover
pass
return {"summary": summary}
finally:
if os.path.exists(temp_path):
@@ -242,57 +269,9 @@ class VideoAnalyzer:
except Exception as e: # pragma: no cover
return {"summary": f"❌ 处理失败: {e}"}
# ---- 缓存辅助 ----
async def _get_cached(self, video_hash: str) -> str | None:
try:
async with get_db_session() as session: # type: ignore
result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) # type: ignore
obj: Videos | None = result.scalar_one_or_none() # type: ignore
if obj and obj.vlm_processed and obj.description:
# Update the usage count
try:
await session.execute(
update(Videos)
.where(Videos.id == obj.id) # type: ignore
.values(count=obj.count + 1 if obj.count is not None else 1)
)
await session.commit()
except Exception: # pragma: no cover
await session.rollback()
return obj.description
except Exception: # pragma: no cover
pass
return None
async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
try:
async with get_db_session() as session: # type: ignore
stmt = insert(Videos).values( # type: ignore
video_id="",
video_hash=video_hash,
description=summary,
count=1,
timestamp=time.time(),
vlm_processed=True,
duration=None,
frame_count=None,
fps=None,
resolution=None,
file_size=file_size,
)
try:
await session.execute(stmt)
await session.commit()
logger.debug(f"视频缓存写入 success hash={video_hash}")
except sa_exc.IntegrityError: # 可能并发已写入
await session.rollback()
logger.debug(f"视频缓存已存在 hash={video_hash}")
except Exception: # pragma: no cover
logger.debug("视频缓存写入失败")
# ---- Public interface ----
_INSTANCE: VideoAnalyzer | None = None
_INSTANCE: Optional[VideoAnalyzer] = None
def get_video_analyzer() -> VideoAnalyzer:
@@ -306,7 +285,7 @@ def is_video_analysis_available() -> bool:
return True
def get_video_analysis_status() -> dict[str, Any]:
def get_video_analysis_status() -> Dict[str, Any]:
try:
info = video.get_system_info() # type: ignore[attr-defined]
except Exception as e: # pragma: no cover
@@ -318,4 +297,4 @@ def get_video_analysis_status() -> dict[str, Any]:
"modes": ["auto", "batch", "sequential"],
"max_frames_default": inst.max_frames,
"implementation": "inkfox",
}
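A minimal usage sketch of the byte-based entry point (the import path and file name below are hypothetical):

import asyncio
from utils_video import get_video_analyzer  # adjust to the actual package path

async def main() -> None:
    analyzer = get_video_analyzer()
    with open("clip.mp4", "rb") as f:  # hypothetical local file
        video_bytes = f.read()
    result = await analyzer.analyze_video_from_bytes(video_bytes, filename="clip.mp4")
    # Results are cached in the Videos table keyed by sha256, so repeated calls with the same bytes hit the cache
    print(result["summary"])

asyncio.run(main())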

View File

@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from src.llm_models.payload_content.tool_option import ToolCall
from . import BaseDataModel
if TYPE_CHECKING:

View File

@@ -4,7 +4,7 @@
Provides replyer-related functionality, following the standard Python package design pattern
Usage:
from src.plugin_system.apis import generator_api
replyer = generator_api.get_replyer(chat_stream)
replyer = await generator_api.get_replyer(chat_stream)
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
"""
@@ -32,6 +32,7 @@ logger = get_logger("generator_api")
# =============================================================================
def get_replyer(
async def get_replyer(
chat_stream: ChatStream | None = None,
chat_id: str | None = None,

View File

@@ -2,6 +2,7 @@ import asyncio
from typing import Any
from src.common.logger import get_logger
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("base_event")