Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -74,7 +74,7 @@ dependencies = [
|
|||||||
"websockets>=15.0.1",
|
"websockets>=15.0.1",
|
||||||
"aiomysql>=0.2.0",
|
"aiomysql>=0.2.0",
|
||||||
"aiosqlite>=0.21.0",
|
"aiosqlite>=0.21.0",
|
||||||
"inkfox>=0.1.0",
|
"inkfox>=0.1.1",
|
||||||
"rrjieba>=0.1.13",
|
"rrjieba>=0.1.13",
|
||||||
"mcp>=0.9.0",
|
"mcp>=0.9.0",
|
||||||
"sse-starlette>=2.2.1",
|
"sse-starlette>=2.2.1",
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system,
|
|||||||
|
|
||||||
# Vector DB存储系统
|
# Vector DB存储系统
|
||||||
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
||||||
|
from .memory_formatter import format_memories_bracket_style
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 核心数据结构
|
# 核心数据结构
|
||||||
@@ -62,6 +63,8 @@ __all__ = [
|
|||||||
"MemoryActivator",
|
"MemoryActivator",
|
||||||
"memory_activator",
|
"memory_activator",
|
||||||
"enhanced_memory_activator", # 兼容性别名
|
"enhanced_memory_activator", # 兼容性别名
|
||||||
|
# 格式化工具
|
||||||
|
"format_memories_bracket_style",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 版本信息
|
# 版本信息
|
||||||
|
|||||||
118
src/chat/memory_system/memory_formatter.py
Normal file
118
src/chat/memory_system/memory_formatter.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""记忆格式化工具
|
||||||
|
|
||||||
|
提供统一的记忆块格式化函数,供构建 Prompt 时使用。
|
||||||
|
|
||||||
|
当前使用的函数: format_memories_bracket_style
|
||||||
|
输入: list[dict] 其中每个元素包含:
|
||||||
|
- display: str 记忆可读内容
|
||||||
|
- memory_type: str 记忆类型 (personal_fact/opinion/preference/event 等)
|
||||||
|
- metadata: dict 可选,包括
|
||||||
|
- confidence: 置信度 (str|float)
|
||||||
|
- importance: 重要度 (str|float)
|
||||||
|
- timestamp: 时间戳 (float|str)
|
||||||
|
- source: 来源 (str)
|
||||||
|
- relevance_score: 相关度 (float)
|
||||||
|
|
||||||
|
返回: 适合直接嵌入提示词的大段文本;若无有效记忆返回空串。
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Iterable
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def _format_timestamp(ts: Any) -> str:
|
||||||
|
try:
|
||||||
|
if ts in (None, ""):
|
||||||
|
return ""
|
||||||
|
if isinstance(ts, (int, float)) and ts > 0:
|
||||||
|
return time.strftime("%Y-%m-%d %H:%M", time.localtime(float(ts)))
|
||||||
|
return str(ts)
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_str(v: Any) -> str:
|
||||||
|
if v is None:
|
||||||
|
return ""
|
||||||
|
return str(v)
|
||||||
|
|
||||||
|
|
||||||
|
def format_memories_bracket_style(
|
||||||
|
memories: Iterable[dict[str, Any]] | None,
|
||||||
|
query_context: str | None = None,
|
||||||
|
max_items: int = 15,
|
||||||
|
) -> str:
|
||||||
|
"""以方括号 + 标注字段的方式格式化记忆列表。
|
||||||
|
|
||||||
|
例子输出:
|
||||||
|
## 相关记忆回顾
|
||||||
|
- [类型:personal_fact|重要:高|置信:0.83|相关:0.72] 他喜欢黑咖啡 (来源: chat, 2025-10-05 09:30)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
memories: 记忆字典迭代器
|
||||||
|
query_context: 当前查询/用户的消息,用于在首行提示(可选)
|
||||||
|
max_items: 最多输出的记忆条数
|
||||||
|
Returns:
|
||||||
|
str: 格式化文本;若无内容返回空串
|
||||||
|
"""
|
||||||
|
if not memories:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines: list[str] = ["## 相关记忆回顾"]
|
||||||
|
if query_context:
|
||||||
|
lines.append(f"(与当前消息相关:{query_context[:60]}{'...' if len(query_context) > 60 else ''})")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for mem in memories:
|
||||||
|
if count >= max_items:
|
||||||
|
break
|
||||||
|
if not isinstance(mem, dict):
|
||||||
|
continue
|
||||||
|
display = _coerce_str(mem.get("display", "")).strip()
|
||||||
|
if not display:
|
||||||
|
continue
|
||||||
|
|
||||||
|
mtype = _coerce_str(mem.get("memory_type", "fact")) or "fact"
|
||||||
|
meta = mem.get("metadata", {}) if isinstance(mem.get("metadata"), dict) else {}
|
||||||
|
confidence = _coerce_str(meta.get("confidence", ""))
|
||||||
|
importance = _coerce_str(meta.get("importance", ""))
|
||||||
|
source = _coerce_str(meta.get("source", ""))
|
||||||
|
rel = meta.get("relevance_score")
|
||||||
|
try:
|
||||||
|
rel_str = f"{float(rel):.2f}" if rel is not None else ""
|
||||||
|
except Exception:
|
||||||
|
rel_str = ""
|
||||||
|
ts = _format_timestamp(meta.get("timestamp"))
|
||||||
|
|
||||||
|
# 构建标签段
|
||||||
|
tags: list[str] = [f"类型:{mtype}"]
|
||||||
|
if importance:
|
||||||
|
tags.append(f"重要:{importance}")
|
||||||
|
if confidence:
|
||||||
|
tags.append(f"置信:{confidence}")
|
||||||
|
if rel_str:
|
||||||
|
tags.append(f"相关:{rel_str}")
|
||||||
|
|
||||||
|
tag_block = "|".join(tags)
|
||||||
|
suffix_parts = []
|
||||||
|
if source:
|
||||||
|
suffix_parts.append(source)
|
||||||
|
if ts:
|
||||||
|
suffix_parts.append(ts)
|
||||||
|
suffix = (" (" + ", ".join(suffix_parts) + ")") if suffix_parts else ""
|
||||||
|
|
||||||
|
lines.append(f"- [{tag_block}] {display}{suffix}")
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if count == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if count >= max_items:
|
||||||
|
lines.append(f"\n(已截断,仅显示前 {max_items} 条相关记忆)")
|
||||||
|
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["format_memories_bracket_style"]
|
||||||
@@ -1,193 +1,61 @@
|
|||||||
"""
|
"""
|
||||||
记忆元数据索引管理器
|
记忆元数据索引。
|
||||||
使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import threading
|
from dataclasses import dataclass, asdict
|
||||||
from dataclasses import asdict, dataclass
|
|
||||||
from datetime import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from time import time
|
||||||
import orjson
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryMetadataIndexEntry:
|
class MemoryMetadataIndexEntry:
|
||||||
"""元数据索引条目(轻量级,只用于快速过滤)"""
|
|
||||||
|
|
||||||
memory_id: str
|
memory_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
|
memory_type: str
|
||||||
# 分类信息
|
subjects: list[str]
|
||||||
memory_type: str # MemoryType.value
|
objects: list[str]
|
||||||
subjects: list[str] # 主语列表
|
keywords: list[str]
|
||||||
objects: list[str] # 宾语列表
|
tags: list[str]
|
||||||
keywords: list[str] # 关键词列表
|
importance: int
|
||||||
tags: list[str] # 标签列表
|
confidence: int
|
||||||
|
created_at: float
|
||||||
# 数值字段(用于范围过滤)
|
access_count: int
|
||||||
importance: int # ImportanceLevel.value (1-4)
|
|
||||||
confidence: int # ConfidenceLevel.value (1-4)
|
|
||||||
created_at: float # 创建时间戳
|
|
||||||
access_count: int # 访问次数
|
|
||||||
|
|
||||||
# 可选字段
|
|
||||||
chat_id: str | None = None
|
chat_id: str | None = None
|
||||||
content_preview: str | None = None # 内容预览(前100字符)
|
content_preview: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class MemoryMetadataIndex:
|
class MemoryMetadataIndex:
|
||||||
"""记忆元数据索引管理器"""
|
"""Rust 加速版本唯一实现。"""
|
||||||
|
|
||||||
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
|
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
|
||||||
self.index_file = Path(index_file)
|
self._rust = _RustIndex(index_file)
|
||||||
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
|
# 仅为向量层和调试提供最小缓存(长度判断、get_entry 返回)
|
||||||
|
self.index: dict[str, MemoryMetadataIndexEntry] = {}
|
||||||
|
logger.info("✅ MemoryMetadataIndex (Rust) 初始化完成,仅支持加速实现")
|
||||||
|
|
||||||
# 倒排索引(用于快速查找)
|
# 向后代码仍调用的接口:batch_add_or_update / add_or_update
|
||||||
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
|
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
|
||||||
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
|
if not entries:
|
||||||
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
|
|
||||||
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
|
|
||||||
|
|
||||||
self.lock = threading.RLock()
|
|
||||||
|
|
||||||
# 加载已有索引
|
|
||||||
self._load_index()
|
|
||||||
|
|
||||||
def _load_index(self):
|
|
||||||
"""从文件加载索引"""
|
|
||||||
if not self.index_file.exists():
|
|
||||||
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
|
|
||||||
return
|
return
|
||||||
|
payload = []
|
||||||
try:
|
for e in entries:
|
||||||
with open(self.index_file, "rb") as f:
|
if not e.memory_id:
|
||||||
data = orjson.loads(f.read())
|
continue
|
||||||
|
self.index[e.memory_id] = e
|
||||||
# 重建内存索引
|
payload.append(asdict(e))
|
||||||
for entry_data in data.get("entries", []):
|
if payload:
|
||||||
entry = MemoryMetadataIndexEntry(**entry_data)
|
try:
|
||||||
self.index[entry.memory_id] = entry
|
self._rust.batch_add(payload)
|
||||||
self._update_inverted_indices(entry)
|
except Exception as ex: # noqa: BLE001
|
||||||
|
logger.error(f"Rust 元数据批量添加失败: {ex}")
|
||||||
logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载元数据索引失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
def _save_index(self):
|
|
||||||
"""保存索引到文件"""
|
|
||||||
try:
|
|
||||||
# 确保目录存在
|
|
||||||
self.index_file.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# 序列化所有条目
|
|
||||||
entries = [asdict(entry) for entry in self.index.values()]
|
|
||||||
data = {
|
|
||||||
"version": "1.0",
|
|
||||||
"count": len(entries),
|
|
||||||
"last_updated": datetime.now().isoformat(),
|
|
||||||
"entries": entries,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 写入文件(使用临时文件 + 原子重命名)
|
|
||||||
temp_file = self.index_file.with_suffix(".tmp")
|
|
||||||
with open(temp_file, "wb") as f:
|
|
||||||
f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
|
|
||||||
|
|
||||||
temp_file.replace(self.index_file)
|
|
||||||
logger.debug(f"元数据索引已保存: {len(entries)} 条记录")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"保存元数据索引失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
|
|
||||||
"""更新倒排索引"""
|
|
||||||
memory_id = entry.memory_id
|
|
||||||
|
|
||||||
# 类型索引
|
|
||||||
self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
|
|
||||||
|
|
||||||
# 主语索引
|
|
||||||
for subject in entry.subjects:
|
|
||||||
subject_norm = subject.strip().lower()
|
|
||||||
if subject_norm:
|
|
||||||
self.subject_index.setdefault(subject_norm, set()).add(memory_id)
|
|
||||||
|
|
||||||
# 关键词索引
|
|
||||||
for keyword in entry.keywords:
|
|
||||||
keyword_norm = keyword.strip().lower()
|
|
||||||
if keyword_norm:
|
|
||||||
self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
|
|
||||||
|
|
||||||
# 标签索引
|
|
||||||
for tag in entry.tags:
|
|
||||||
tag_norm = tag.strip().lower()
|
|
||||||
if tag_norm:
|
|
||||||
self.tag_index.setdefault(tag_norm, set()).add(memory_id)
|
|
||||||
|
|
||||||
def add_or_update(self, entry: MemoryMetadataIndexEntry):
|
def add_or_update(self, entry: MemoryMetadataIndexEntry):
|
||||||
"""添加或更新索引条目"""
|
self.batch_add_or_update([entry])
|
||||||
with self.lock:
|
|
||||||
# 如果已存在,先从倒排索引中移除旧记录
|
|
||||||
if entry.memory_id in self.index:
|
|
||||||
self._remove_from_inverted_indices(entry.memory_id)
|
|
||||||
|
|
||||||
# 添加新记录
|
|
||||||
self.index[entry.memory_id] = entry
|
|
||||||
self._update_inverted_indices(entry)
|
|
||||||
|
|
||||||
def _remove_from_inverted_indices(self, memory_id: str):
|
|
||||||
"""从倒排索引中移除记录"""
|
|
||||||
if memory_id not in self.index:
|
|
||||||
return
|
|
||||||
|
|
||||||
entry = self.index[memory_id]
|
|
||||||
|
|
||||||
# 从类型索引移除
|
|
||||||
if entry.memory_type in self.type_index:
|
|
||||||
self.type_index[entry.memory_type].discard(memory_id)
|
|
||||||
|
|
||||||
# 从主语索引移除
|
|
||||||
for subject in entry.subjects:
|
|
||||||
subject_norm = subject.strip().lower()
|
|
||||||
if subject_norm in self.subject_index:
|
|
||||||
self.subject_index[subject_norm].discard(memory_id)
|
|
||||||
|
|
||||||
# 从关键词索引移除
|
|
||||||
for keyword in entry.keywords:
|
|
||||||
keyword_norm = keyword.strip().lower()
|
|
||||||
if keyword_norm in self.keyword_index:
|
|
||||||
self.keyword_index[keyword_norm].discard(memory_id)
|
|
||||||
|
|
||||||
# 从标签索引移除
|
|
||||||
for tag in entry.tags:
|
|
||||||
tag_norm = tag.strip().lower()
|
|
||||||
if tag_norm in self.tag_index:
|
|
||||||
self.tag_index[tag_norm].discard(memory_id)
|
|
||||||
|
|
||||||
def remove(self, memory_id: str):
|
|
||||||
"""移除索引条目"""
|
|
||||||
with self.lock:
|
|
||||||
if memory_id in self.index:
|
|
||||||
self._remove_from_inverted_indices(memory_id)
|
|
||||||
del self.index[memory_id]
|
|
||||||
|
|
||||||
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
|
|
||||||
"""批量添加或更新"""
|
|
||||||
with self.lock:
|
|
||||||
for entry in entries:
|
|
||||||
self.add_or_update(entry)
|
|
||||||
|
|
||||||
def save(self):
|
|
||||||
"""保存索引到磁盘"""
|
|
||||||
with self.lock:
|
|
||||||
self._save_index()
|
|
||||||
|
|
||||||
def search(
|
def search(
|
||||||
self,
|
self,
|
||||||
@@ -201,287 +69,54 @@ class MemoryMetadataIndex:
|
|||||||
created_before: float | None = None,
|
created_before: float | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
limit: int | None = None,
|
limit: int | None = None,
|
||||||
flexible_mode: bool = True, # 新增:灵活匹配模式
|
flexible_mode: bool = True,
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""
|
params: dict[str, Any] = {
|
||||||
搜索符合条件的记忆ID列表(支持模糊匹配)
|
"user_id": user_id,
|
||||||
|
"memory_types": memory_types,
|
||||||
Returns:
|
"subjects": subjects,
|
||||||
List[str]: 符合条件的 memory_id 列表
|
"keywords": keywords,
|
||||||
"""
|
"tags": tags,
|
||||||
with self.lock:
|
"importance_min": importance_min,
|
||||||
|
"importance_max": importance_max,
|
||||||
|
"created_after": created_after,
|
||||||
|
"created_before": created_before,
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
params = {k: v for k, v in params.items() if v is not None}
|
||||||
|
try:
|
||||||
if flexible_mode:
|
if flexible_mode:
|
||||||
return self._search_flexible(
|
return list(self._rust.search_flexible(params))
|
||||||
memory_types=memory_types,
|
return list(self._rust.search_strict(params))
|
||||||
subjects=subjects,
|
except Exception as ex: # noqa: BLE001
|
||||||
keywords=keywords, # 保留用于兼容性
|
logger.error(f"Rust 搜索失败返回空: {ex}")
|
||||||
tags=tags, # 保留用于兼容性
|
return []
|
||||||
created_after=created_after,
|
|
||||||
created_before=created_before,
|
|
||||||
user_id=user_id,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return self._search_strict(
|
|
||||||
memory_types=memory_types,
|
|
||||||
subjects=subjects,
|
|
||||||
keywords=keywords,
|
|
||||||
tags=tags,
|
|
||||||
importance_min=importance_min,
|
|
||||||
importance_max=importance_max,
|
|
||||||
created_after=created_after,
|
|
||||||
created_before=created_before,
|
|
||||||
user_id=user_id,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _search_flexible(
|
|
||||||
self,
|
|
||||||
memory_types: list[str] | None = None,
|
|
||||||
subjects: list[str] | None = None,
|
|
||||||
created_after: float | None = None,
|
|
||||||
created_before: float | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
**kwargs, # 接受但不使用的参数
|
|
||||||
) -> list[str]:
|
|
||||||
"""
|
|
||||||
灵活搜索模式:2/4项匹配即可,支持部分匹配
|
|
||||||
|
|
||||||
评分维度:
|
|
||||||
1. 记忆类型匹配 (0-1分)
|
|
||||||
2. 主语匹配 (0-1分)
|
|
||||||
3. 宾语匹配 (0-1分)
|
|
||||||
4. 时间范围匹配 (0-1分)
|
|
||||||
|
|
||||||
总分 >= 2分即视为有效
|
|
||||||
"""
|
|
||||||
# 用户过滤(必选)
|
|
||||||
if user_id:
|
|
||||||
base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
|
|
||||||
else:
|
|
||||||
base_candidates = set(self.index.keys())
|
|
||||||
|
|
||||||
scored_candidates = []
|
|
||||||
|
|
||||||
for memory_id in base_candidates:
|
|
||||||
entry = self.index[memory_id]
|
|
||||||
score = 0
|
|
||||||
match_details = []
|
|
||||||
|
|
||||||
# 1. 记忆类型匹配
|
|
||||||
if memory_types:
|
|
||||||
type_score = 0
|
|
||||||
for mtype in memory_types:
|
|
||||||
if entry.memory_type == mtype:
|
|
||||||
type_score = 1
|
|
||||||
break
|
|
||||||
# 部分匹配:类型名称包含
|
|
||||||
if mtype.lower() in entry.memory_type.lower() or entry.memory_type.lower() in mtype.lower():
|
|
||||||
type_score = 0.5
|
|
||||||
break
|
|
||||||
score += type_score
|
|
||||||
if type_score > 0:
|
|
||||||
match_details.append(f"类型:{entry.memory_type}")
|
|
||||||
else:
|
|
||||||
match_details.append("类型:未指定")
|
|
||||||
|
|
||||||
# 2. 主语匹配(支持部分匹配)
|
|
||||||
if subjects:
|
|
||||||
subject_score = 0
|
|
||||||
for subject in subjects:
|
|
||||||
subject_norm = subject.strip().lower()
|
|
||||||
for entry_subject in entry.subjects:
|
|
||||||
entry_subject_norm = entry_subject.strip().lower()
|
|
||||||
# 精确匹配
|
|
||||||
if subject_norm == entry_subject_norm:
|
|
||||||
subject_score = 1
|
|
||||||
break
|
|
||||||
# 部分匹配:包含关系
|
|
||||||
if subject_norm in entry_subject_norm or entry_subject_norm in subject_norm:
|
|
||||||
subject_score = 0.6
|
|
||||||
break
|
|
||||||
if subject_score == 1:
|
|
||||||
break
|
|
||||||
score += subject_score
|
|
||||||
if subject_score > 0:
|
|
||||||
match_details.append("主语:匹配")
|
|
||||||
else:
|
|
||||||
match_details.append("主语:未指定")
|
|
||||||
|
|
||||||
# 3. 宾语匹配(支持部分匹配)
|
|
||||||
object_score = 0
|
|
||||||
if entry.objects:
|
|
||||||
for entry_object in entry.objects:
|
|
||||||
entry_object_norm = str(entry_object).strip().lower()
|
|
||||||
# 检查是否与主语相关(主宾关联)
|
|
||||||
for subject in subjects or []:
|
|
||||||
subject_norm = subject.strip().lower()
|
|
||||||
if subject_norm in entry_object_norm or entry_object_norm in subject_norm:
|
|
||||||
object_score = 0.8
|
|
||||||
match_details.append("宾语:主宾关联")
|
|
||||||
break
|
|
||||||
if object_score > 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
score += object_score
|
|
||||||
if object_score > 0:
|
|
||||||
match_details.append("宾语:匹配")
|
|
||||||
elif not entry.objects:
|
|
||||||
match_details.append("宾语:无")
|
|
||||||
|
|
||||||
# 4. 时间范围匹配
|
|
||||||
time_score = 0
|
|
||||||
if created_after is not None or created_before is not None:
|
|
||||||
time_match = True
|
|
||||||
if created_after is not None and entry.created_at < created_after:
|
|
||||||
time_match = False
|
|
||||||
if created_before is not None and entry.created_at > created_before:
|
|
||||||
time_match = False
|
|
||||||
if time_match:
|
|
||||||
time_score = 1
|
|
||||||
match_details.append("时间:匹配")
|
|
||||||
else:
|
|
||||||
match_details.append("时间:不匹配")
|
|
||||||
else:
|
|
||||||
match_details.append("时间:未指定")
|
|
||||||
|
|
||||||
score += time_score
|
|
||||||
|
|
||||||
# 只有总分 >= 2 的记忆才会被返回
|
|
||||||
if score >= 2:
|
|
||||||
scored_candidates.append((memory_id, score, match_details))
|
|
||||||
|
|
||||||
# 按分数和时间排序
|
|
||||||
scored_candidates.sort(key=lambda x: (x[1], self.index[x[0]].created_at), reverse=True)
|
|
||||||
|
|
||||||
if limit:
|
|
||||||
result_ids = [mid for mid, _, _ in scored_candidates[:limit]]
|
|
||||||
else:
|
|
||||||
result_ids = [mid for mid, _, _ in scored_candidates]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"[灵活搜索] 过滤条件: types={memory_types}, subjects={subjects}, "
|
|
||||||
f"time_range=[{created_after}, {created_before}], 返回={len(result_ids)}条"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 记录匹配统计
|
|
||||||
if scored_candidates and len(scored_candidates) > 0:
|
|
||||||
avg_score = sum(score for _, score, _ in scored_candidates) / len(scored_candidates)
|
|
||||||
logger.debug(f"[灵活搜索] 平均匹配分数: {avg_score:.2f}, 最高分: {scored_candidates[0][1]:.2f}")
|
|
||||||
|
|
||||||
return result_ids
|
|
||||||
|
|
||||||
def _search_strict(
|
|
||||||
self,
|
|
||||||
memory_types: list[str] | None = None,
|
|
||||||
subjects: list[str] | None = None,
|
|
||||||
keywords: list[str] | None = None,
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
importance_min: int | None = None,
|
|
||||||
importance_max: int | None = None,
|
|
||||||
created_after: float | None = None,
|
|
||||||
created_before: float | None = None,
|
|
||||||
user_id: str | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
"""严格搜索模式(原有逻辑)"""
|
|
||||||
# 初始候选集(所有记忆)
|
|
||||||
candidate_ids: set[str] | None = None
|
|
||||||
|
|
||||||
# 用户过滤(必选)
|
|
||||||
if user_id:
|
|
||||||
candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
|
|
||||||
else:
|
|
||||||
candidate_ids = set(self.index.keys())
|
|
||||||
|
|
||||||
# 类型过滤(OR关系)
|
|
||||||
if memory_types:
|
|
||||||
type_ids = set()
|
|
||||||
for mtype in memory_types:
|
|
||||||
type_ids.update(self.type_index.get(mtype, set()))
|
|
||||||
candidate_ids &= type_ids
|
|
||||||
|
|
||||||
# 主语过滤(OR关系,支持模糊匹配)
|
|
||||||
if subjects:
|
|
||||||
subject_ids = set()
|
|
||||||
for subject in subjects:
|
|
||||||
subject_norm = subject.strip().lower()
|
|
||||||
# 精确匹配
|
|
||||||
if subject_norm in self.subject_index:
|
|
||||||
subject_ids.update(self.subject_index[subject_norm])
|
|
||||||
# 模糊匹配(包含)
|
|
||||||
for indexed_subject, ids in self.subject_index.items():
|
|
||||||
if subject_norm in indexed_subject or indexed_subject in subject_norm:
|
|
||||||
subject_ids.update(ids)
|
|
||||||
candidate_ids &= subject_ids
|
|
||||||
|
|
||||||
# 关键词过滤(OR关系,支持模糊匹配)
|
|
||||||
if keywords:
|
|
||||||
keyword_ids = set()
|
|
||||||
for keyword in keywords:
|
|
||||||
keyword_norm = keyword.strip().lower()
|
|
||||||
# 精确匹配
|
|
||||||
if keyword_norm in self.keyword_index:
|
|
||||||
keyword_ids.update(self.keyword_index[keyword_norm])
|
|
||||||
# 模糊匹配(包含)
|
|
||||||
for indexed_keyword, ids in self.keyword_index.items():
|
|
||||||
if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm:
|
|
||||||
keyword_ids.update(ids)
|
|
||||||
candidate_ids &= keyword_ids
|
|
||||||
|
|
||||||
# 标签过滤(OR关系)
|
|
||||||
if tags:
|
|
||||||
tag_ids = set()
|
|
||||||
for tag in tags:
|
|
||||||
tag_norm = tag.strip().lower()
|
|
||||||
tag_ids.update(self.tag_index.get(tag_norm, set()))
|
|
||||||
candidate_ids &= tag_ids
|
|
||||||
|
|
||||||
# 重要性过滤
|
|
||||||
if importance_min is not None or importance_max is not None:
|
|
||||||
importance_ids = {
|
|
||||||
mid
|
|
||||||
for mid in candidate_ids
|
|
||||||
if (importance_min is None or self.index[mid].importance >= importance_min)
|
|
||||||
and (importance_max is None or self.index[mid].importance <= importance_max)
|
|
||||||
}
|
|
||||||
candidate_ids &= importance_ids
|
|
||||||
|
|
||||||
# 时间范围过滤
|
|
||||||
if created_after is not None or created_before is not None:
|
|
||||||
time_ids = {
|
|
||||||
mid
|
|
||||||
for mid in candidate_ids
|
|
||||||
if (created_after is None or self.index[mid].created_at >= created_after)
|
|
||||||
and (created_before is None or self.index[mid].created_at <= created_before)
|
|
||||||
}
|
|
||||||
candidate_ids &= time_ids
|
|
||||||
|
|
||||||
# 转换为列表并排序(按创建时间倒序)
|
|
||||||
result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True)
|
|
||||||
|
|
||||||
# 限制数量
|
|
||||||
if limit:
|
|
||||||
result_ids = result_ids[:limit]
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}条"
|
|
||||||
)
|
|
||||||
|
|
||||||
return result_ids
|
|
||||||
|
|
||||||
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
|
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
|
||||||
"""获取单个索引条目"""
|
|
||||||
return self.index.get(memory_id)
|
return self.index.get(memory_id)
|
||||||
|
|
||||||
def get_stats(self) -> dict[str, Any]:
|
def get_stats(self) -> dict[str, Any]:
|
||||||
"""获取索引统计信息"""
|
try:
|
||||||
with self.lock:
|
raw = self._rust.stats()
|
||||||
return {
|
return {
|
||||||
"total_memories": len(self.index),
|
"total_memories": raw.get("total", 0),
|
||||||
"types": {mtype: len(ids) for mtype, ids in self.type_index.items()},
|
"types": raw.get("types_dist", {}),
|
||||||
"subjects_count": len(self.subject_index),
|
"subjects_count": raw.get("subjects_indexed", 0),
|
||||||
"keywords_count": len(self.keyword_index),
|
"keywords_count": raw.get("keywords_indexed", 0),
|
||||||
"tags_count": len(self.tag_index),
|
"tags_count": raw.get("tags_indexed", 0),
|
||||||
}
|
}
|
||||||
|
except Exception as ex: # noqa: BLE001
|
||||||
|
logger.warning(f"读取 Rust stats 失败: {ex}")
|
||||||
|
return {"total_memories": 0}
|
||||||
|
|
||||||
|
def save(self): # 仅调用 rust save
|
||||||
|
try:
|
||||||
|
self._rust.save()
|
||||||
|
except Exception as ex: # noqa: BLE001
|
||||||
|
logger.warning(f"Rust save 失败: {ex}")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"MemoryMetadataIndexEntry",
|
||||||
|
"MemoryMetadataIndex",
|
||||||
|
]
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ class MessageRecv(Message):
|
|||||||
logger.warning("视频消息中没有base64数据")
|
logger.warning("视频消息中没有base64数据")
|
||||||
return "[收到视频消息,但数据异常]"
|
return "[收到视频消息,但数据异常]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"视频处理失败: {e!s}")
|
logger.error(f"视频处理失败: {str(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||||
@@ -277,7 +277,7 @@ class MessageRecv(Message):
|
|||||||
logger.info("未启用视频识别")
|
logger.info("未启用视频识别")
|
||||||
return "[视频]"
|
return "[视频]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||||
return f"[处理失败的{segment.type}消息]"
|
return f"[处理失败的{segment.type}消息]"
|
||||||
|
|
||||||
|
|
||||||
@@ -427,7 +427,7 @@ class MessageRecvS4U(MessageRecv):
|
|||||||
|
|
||||||
# 使用video analyzer分析视频
|
# 使用video analyzer分析视频
|
||||||
video_analyzer = get_video_analyzer()
|
video_analyzer = get_video_analyzer()
|
||||||
result = await video_analyzer.analyze_video_from_bytes(
|
result = await video_analyzer.analyze_video(
|
||||||
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
video_bytes, filename, prompt=global_config.video_analysis.batch_analysis_prompt
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ class ReplyerManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._repliers: dict[str, DefaultReplyer] = {}
|
self._repliers: dict[str, DefaultReplyer] = {}
|
||||||
|
|
||||||
|
async def get_replyer(
|
||||||
async def get_replyer(
|
async def get_replyer(
|
||||||
self,
|
self,
|
||||||
chat_stream: ChatStream | None = None,
|
chat_stream: ChatStream | None = None,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
"""纯 inkfox 视频关键帧分析工具
|
"""纯 inkfox 视频关键帧分析工具
|
||||||
|
|
||||||
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
|
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
|
||||||
@@ -13,27 +14,25 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import io
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import List, Tuple, Optional, Dict, Any
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from sqlalchemy import exc as sa_exc # type: ignore
|
|
||||||
from sqlalchemy import insert, select, update # type: ignore
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
|
||||||
|
|
||||||
# 简易并发控制:同一 hash 只处理一次
|
# 简易并发控制:同一 hash 只处理一次
|
||||||
_video_locks: dict[str, asyncio.Lock] = {}
|
_video_locks: Dict[str, asyncio.Lock] = {}
|
||||||
_locks_guard = asyncio.Lock()
|
_locks_guard = asyncio.Lock()
|
||||||
|
|
||||||
logger = get_logger("utils_video")
|
logger = get_logger("utils_video")
|
||||||
@@ -91,7 +90,7 @@ class VideoAnalyzer:
|
|||||||
logger.debug(f"获取系统信息失败: {e}")
|
logger.debug(f"获取系统信息失败: {e}")
|
||||||
|
|
||||||
# ---- 关键帧提取 ----
|
# ---- 关键帧提取 ----
|
||||||
async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
|
async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
|
||||||
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
|
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
|
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
|
||||||
@@ -106,7 +105,7 @@ class VideoAnalyzer:
|
|||||||
)
|
)
|
||||||
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
|
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
|
||||||
total_ms = getattr(result, "total_time_ms", 0)
|
total_ms = getattr(result, "total_time_ms", 0)
|
||||||
frames: list[tuple[str, float]] = []
|
frames: List[Tuple[str, float]] = []
|
||||||
for i, f in enumerate(files):
|
for i, f in enumerate(files):
|
||||||
img = Image.open(f).convert("RGB")
|
img = Image.open(f).convert("RGB")
|
||||||
if max(img.size) > self.max_image_size:
|
if max(img.size) > self.max_image_size:
|
||||||
@@ -120,41 +119,38 @@ class VideoAnalyzer:
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
# ---- 批量分析 ----
|
# ---- 批量分析 ----
|
||||||
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
||||||
from src.llm_models.payload_content.message import MessageBuilder
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||||
from src.llm_models.utils_model import RequestType
|
from src.llm_models.utils_model import RequestType
|
||||||
|
|
||||||
prompt = self.batch_analysis_prompt.format(
|
prompt = self.batch_analysis_prompt.format(
|
||||||
personality_core=self.personality_core, personality_side=self.personality_side
|
personality_core=self.personality_core, personality_side=self.personality_side
|
||||||
)
|
)
|
||||||
if question:
|
if question:
|
||||||
prompt += f"\n用户关注: {question}"
|
prompt += f"\n用户关注: {question}"
|
||||||
|
|
||||||
desc = [
|
desc = [
|
||||||
(f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
|
(f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
|
||||||
for i, (_b, ts) in enumerate(frames)
|
for i, (_b, ts) in enumerate(frames)
|
||||||
]
|
]
|
||||||
prompt += "\n帧列表: " + ", ".join(desc)
|
prompt += "\n帧列表: " + ", ".join(desc)
|
||||||
|
mb = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
|
||||||
message_builder = MessageBuilder().add_text_content(prompt)
|
|
||||||
for b64, _ in frames:
|
for b64, _ in frames:
|
||||||
message_builder.add_image_content(image_format="jpeg", image_base64=b64)
|
mb.add_image_content("jpeg", b64)
|
||||||
messages = [message_builder.build()]
|
message = mb.build()
|
||||||
|
model_info, api_provider, client = self.video_llm._select_model()
|
||||||
# 使用封装好的高级策略执行请求,而不是直接调用内部方法
|
resp = await self.video_llm._execute_request(
|
||||||
response, _ = await self.video_llm._strategy.execute_with_failover(
|
api_provider=api_provider,
|
||||||
RequestType.RESPONSE,
|
client=client,
|
||||||
raise_when_empty=False, # 即使失败也返回默认值,避免程序崩溃
|
request_type=RequestType.RESPONSE,
|
||||||
message_list=messages,
|
model_info=model_info,
|
||||||
temperature=self.video_llm.model_for_task.temperature,
|
message_list=[message],
|
||||||
max_tokens=self.video_llm.model_for_task.max_tokens,
|
temperature=None,
|
||||||
|
max_tokens=None,
|
||||||
)
|
)
|
||||||
|
return resp.content or "❌ 未获得响应"
|
||||||
return response.content or "❌ 未获得响应"
|
|
||||||
|
|
||||||
# ---- 逐帧分析 ----
|
# ---- 逐帧分析 ----
|
||||||
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
||||||
results: list[str] = []
|
results: List[str] = []
|
||||||
for i, (b64, ts) in enumerate(frames):
|
for i, (b64, ts) in enumerate(frames):
|
||||||
prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
||||||
if question:
|
if question:
|
||||||
@@ -178,7 +174,7 @@ class VideoAnalyzer:
|
|||||||
return "\n".join(results)
|
return "\n".join(results)
|
||||||
|
|
||||||
# ---- 主入口 ----
|
# ---- 主入口 ----
|
||||||
async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
|
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
|
||||||
if not os.path.exists(video_path):
|
if not os.path.exists(video_path):
|
||||||
return False, "❌ 文件不存在"
|
return False, "❌ 文件不存在"
|
||||||
frames = await self.extract_keyframes(video_path)
|
frames = await self.extract_keyframes(video_path)
|
||||||
@@ -193,10 +189,10 @@ class VideoAnalyzer:
|
|||||||
async def analyze_video_from_bytes(
|
async def analyze_video_from_bytes(
|
||||||
self,
|
self,
|
||||||
video_bytes: bytes,
|
video_bytes: bytes,
|
||||||
filename: str | None = None,
|
filename: Optional[str] = None,
|
||||||
prompt: str | None = None,
|
prompt: Optional[str] = None,
|
||||||
question: str | None = None,
|
question: Optional[str] = None,
|
||||||
) -> dict[str, str]:
|
) -> Dict[str, str]:
|
||||||
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
|
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
|
||||||
if not video_bytes:
|
if not video_bytes:
|
||||||
return {"summary": "❌ 空视频数据"}
|
return {"summary": "❌ 空视频数据"}
|
||||||
@@ -204,11 +200,17 @@ class VideoAnalyzer:
|
|||||||
q = prompt if prompt is not None else question
|
q = prompt if prompt is not None else question
|
||||||
video_hash = hashlib.sha256(video_bytes).hexdigest()
|
video_hash = hashlib.sha256(video_bytes).hexdigest()
|
||||||
|
|
||||||
# 查缓存(第一次,未加锁)
|
# 查缓存
|
||||||
cached = await self._get_cached(video_hash)
|
try:
|
||||||
if cached:
|
async with get_db_session() as session: # type: ignore
|
||||||
logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}")
|
row = await session.execute(
|
||||||
return {"summary": cached}
|
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
|
||||||
|
)
|
||||||
|
existing = row.first()
|
||||||
|
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
|
||||||
|
return {"summary": existing[Videos.description]} # type: ignore
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
pass
|
||||||
|
|
||||||
# 获取锁避免重复处理
|
# 获取锁避免重复处理
|
||||||
async with _locks_guard:
|
async with _locks_guard:
|
||||||
@@ -217,11 +219,17 @@ class VideoAnalyzer:
|
|||||||
lock = asyncio.Lock()
|
lock = asyncio.Lock()
|
||||||
_video_locks[video_hash] = lock
|
_video_locks[video_hash] = lock
|
||||||
async with lock:
|
async with lock:
|
||||||
# 双检缓存
|
# 双检:进入锁后再查一次,避免重复处理
|
||||||
cached2 = await self._get_cached(video_hash)
|
try:
|
||||||
if cached2:
|
async with get_db_session() as session: # type: ignore
|
||||||
logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}")
|
row = await session.execute(
|
||||||
return {"summary": cached2}
|
Videos.__table__.select().where(Videos.video_hash == video_hash) # type: ignore
|
||||||
|
)
|
||||||
|
existing = row.first()
|
||||||
|
if existing and existing[Videos.description] and existing[Videos.vlm_processed]: # type: ignore
|
||||||
|
return {"summary": existing[Videos.description]} # type: ignore
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with tempfile.NamedTemporaryFile(delete=False) as fp:
|
with tempfile.NamedTemporaryFile(delete=False) as fp:
|
||||||
@@ -231,7 +239,26 @@ class VideoAnalyzer:
|
|||||||
ok, summary = await self.analyze_video(temp_path, q)
|
ok, summary = await self.analyze_video(temp_path, q)
|
||||||
# 写入缓存(仅成功)
|
# 写入缓存(仅成功)
|
||||||
if ok:
|
if ok:
|
||||||
await self._save_cache(video_hash, summary, len(video_bytes))
|
try:
|
||||||
|
async with get_db_session() as session: # type: ignore
|
||||||
|
await session.execute(
|
||||||
|
Videos.__table__.insert().values(
|
||||||
|
video_id="",
|
||||||
|
video_hash=video_hash,
|
||||||
|
description=summary,
|
||||||
|
count=1,
|
||||||
|
timestamp=time.time(),
|
||||||
|
vlm_processed=True,
|
||||||
|
duration=None,
|
||||||
|
frame_count=None,
|
||||||
|
fps=None,
|
||||||
|
resolution=None,
|
||||||
|
file_size=len(video_bytes),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
pass
|
||||||
return {"summary": summary}
|
return {"summary": summary}
|
||||||
finally:
|
finally:
|
||||||
if os.path.exists(temp_path):
|
if os.path.exists(temp_path):
|
||||||
@@ -242,57 +269,9 @@ class VideoAnalyzer:
|
|||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
return {"summary": f"❌ 处理失败: {e}"}
|
return {"summary": f"❌ 处理失败: {e}"}
|
||||||
|
|
||||||
# ---- 缓存辅助 ----
|
|
||||||
async def _get_cached(self, video_hash: str) -> str | None:
|
|
||||||
try:
|
|
||||||
async with get_db_session() as session: # type: ignore
|
|
||||||
result = await session.execute(select(Videos).where(Videos.video_hash == video_hash)) # type: ignore
|
|
||||||
obj: Videos | None = result.scalar_one_or_none() # type: ignore
|
|
||||||
if obj and obj.vlm_processed and obj.description:
|
|
||||||
# 更新使用次数
|
|
||||||
try:
|
|
||||||
await session.execute(
|
|
||||||
update(Videos)
|
|
||||||
.where(Videos.id == obj.id) # type: ignore
|
|
||||||
.values(count=obj.count + 1 if obj.count is not None else 1)
|
|
||||||
)
|
|
||||||
await session.commit()
|
|
||||||
except Exception: # pragma: no cover
|
|
||||||
await session.rollback()
|
|
||||||
return obj.description
|
|
||||||
except Exception: # pragma: no cover
|
|
||||||
pass
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
|
|
||||||
try:
|
|
||||||
async with get_db_session() as session: # type: ignore
|
|
||||||
stmt = insert(Videos).values( # type: ignore
|
|
||||||
video_id="",
|
|
||||||
video_hash=video_hash,
|
|
||||||
description=summary,
|
|
||||||
count=1,
|
|
||||||
timestamp=time.time(),
|
|
||||||
vlm_processed=True,
|
|
||||||
duration=None,
|
|
||||||
frame_count=None,
|
|
||||||
fps=None,
|
|
||||||
resolution=None,
|
|
||||||
file_size=file_size,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
await session.execute(stmt)
|
|
||||||
await session.commit()
|
|
||||||
logger.debug(f"视频缓存写入 success hash={video_hash}")
|
|
||||||
except sa_exc.IntegrityError: # 可能并发已写入
|
|
||||||
await session.rollback()
|
|
||||||
logger.debug(f"视频缓存已存在 hash={video_hash}")
|
|
||||||
except Exception: # pragma: no cover
|
|
||||||
logger.debug("视频缓存写入失败")
|
|
||||||
|
|
||||||
|
|
||||||
# ---- 外部接口 ----
|
# ---- 外部接口 ----
|
||||||
_INSTANCE: VideoAnalyzer | None = None
|
_INSTANCE: Optional[VideoAnalyzer] = None
|
||||||
|
|
||||||
|
|
||||||
def get_video_analyzer() -> VideoAnalyzer:
|
def get_video_analyzer() -> VideoAnalyzer:
|
||||||
@@ -306,7 +285,7 @@ def is_video_analysis_available() -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_video_analysis_status() -> dict[str, Any]:
|
def get_video_analysis_status() -> Dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
info = video.get_system_info() # type: ignore[attr-defined]
|
info = video.get_system_info() # type: ignore[attr-defined]
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
@@ -318,4 +297,4 @@ def get_video_analysis_status() -> dict[str, Any]:
|
|||||||
"modes": ["auto", "batch", "sequential"],
|
"modes": ["auto", "batch", "sequential"],
|
||||||
"max_frames_default": inst.max_frames,
|
"max_frames_default": inst.max_frames,
|
||||||
"implementation": "inkfox",
|
"implementation": "inkfox",
|
||||||
}
|
}
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from src.llm_models.payload_content.tool_option import ToolCall
|
||||||
|
|
||||||
from . import BaseDataModel
|
from . import BaseDataModel
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
提供回复器相关功能,采用标准Python包设计模式
|
提供回复器相关功能,采用标准Python包设计模式
|
||||||
使用方式:
|
使用方式:
|
||||||
from src.plugin_system.apis import generator_api
|
from src.plugin_system.apis import generator_api
|
||||||
replyer = generator_api.get_replyer(chat_stream)
|
replyer = await generator_api.get_replyer(chat_stream)
|
||||||
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
success, reply_set, _ = await generator_api.generate_reply(chat_stream, action_data, reasoning)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -32,6 +32,7 @@ logger = get_logger("generator_api")
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
async def get_replyer(
|
||||||
async def get_replyer(
|
async def get_replyer(
|
||||||
chat_stream: ChatStream | None = None,
|
chat_stream: ChatStream | None = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||||
|
|
||||||
logger = get_logger("base_event")
|
logger = get_logger("base_event")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user