Refactor memory metadata index to use Rust backend
Replaces the Python implementation of MemoryMetadataIndex with a Rust-accelerated version, removing legacy code and fallback logic. Updates search, add, and stats methods to delegate to the Rust backend. Also adds missing imports in llm_data_model.py and base_event.py for improved type support and event handling.
This commit is contained in:
@@ -1,193 +1,71 @@
|
||||
"""
|
||||
记忆元数据索引管理器
|
||||
使用JSON文件存储记忆元数据,支持快速模糊搜索和过滤
|
||||
记忆元数据索引。
|
||||
"""
|
||||
|
||||
import threading
|
||||
from dataclasses import asdict, dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from time import time
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 仅允许规范导入路径:from inkfox.memory import PyMetadataIndex
|
||||
try: # pragma: no cover
|
||||
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
|
||||
logger.debug("已从 inkfox.memory 成功导入 PyMetadataIndex")
|
||||
except Exception as ex: # noqa: BLE001
|
||||
# 不再做任何回退;强制要求正确的 Rust 模块子模块注册
|
||||
raise RuntimeError(
|
||||
"无法导入 inkfox.memory.PyMetadataIndex: %s\n"
|
||||
"请确认: 1) 已在当前虚拟环境下执行 'maturin develop --release' 安装扩展; "
|
||||
"2) 运行进程使用同一个 venv; 3) 没有旧的 'inkfox' 目录遮蔽 so/pyd; 4) Python 版本与编译匹配" % ex
|
||||
) from ex
|
||||
|
||||
@dataclass
|
||||
class MemoryMetadataIndexEntry:
|
||||
"""元数据索引条目(轻量级,只用于快速过滤)"""
|
||||
|
||||
memory_id: str
|
||||
user_id: str
|
||||
|
||||
# 分类信息
|
||||
memory_type: str # MemoryType.value
|
||||
subjects: list[str] # 主语列表
|
||||
objects: list[str] # 宾语列表
|
||||
keywords: list[str] # 关键词列表
|
||||
tags: list[str] # 标签列表
|
||||
|
||||
# 数值字段(用于范围过滤)
|
||||
importance: int # ImportanceLevel.value (1-4)
|
||||
confidence: int # ConfidenceLevel.value (1-4)
|
||||
created_at: float # 创建时间戳
|
||||
access_count: int # 访问次数
|
||||
|
||||
# 可选字段
|
||||
memory_type: str
|
||||
subjects: list[str]
|
||||
objects: list[str]
|
||||
keywords: list[str]
|
||||
tags: list[str]
|
||||
importance: int
|
||||
confidence: int
|
||||
created_at: float
|
||||
access_count: int
|
||||
chat_id: str | None = None
|
||||
content_preview: str | None = None # 内容预览(前100字符)
|
||||
content_preview: str | None = None
|
||||
|
||||
|
||||
class MemoryMetadataIndex:
|
||||
"""记忆元数据索引管理器"""
|
||||
"""Rust 加速版本唯一实现。"""
|
||||
|
||||
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
|
||||
self.index_file = Path(index_file)
|
||||
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
|
||||
self._rust = _RustIndex(index_file)
|
||||
# 仅为向量层和调试提供最小缓存(长度判断、get_entry 返回)
|
||||
self.index: dict[str, MemoryMetadataIndexEntry] = {}
|
||||
logger.info("✅ MemoryMetadataIndex (Rust) 初始化完成,仅支持加速实现")
|
||||
|
||||
# 倒排索引(用于快速查找)
|
||||
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
|
||||
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
|
||||
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
|
||||
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
|
||||
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# 加载已有索引
|
||||
self._load_index()
|
||||
|
||||
def _load_index(self):
|
||||
"""从文件加载索引"""
|
||||
if not self.index_file.exists():
|
||||
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
|
||||
# 向后代码仍调用的接口:batch_add_or_update / add_or_update
|
||||
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
|
||||
if not entries:
|
||||
return
|
||||
|
||||
try:
|
||||
with open(self.index_file, "rb") as f:
|
||||
data = orjson.loads(f.read())
|
||||
|
||||
# 重建内存索引
|
||||
for entry_data in data.get("entries", []):
|
||||
entry = MemoryMetadataIndexEntry(**entry_data)
|
||||
self.index[entry.memory_id] = entry
|
||||
self._update_inverted_indices(entry)
|
||||
|
||||
logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载元数据索引失败: {e}", exc_info=True)
|
||||
|
||||
def _save_index(self):
|
||||
"""保存索引到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.index_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 序列化所有条目
|
||||
entries = [asdict(entry) for entry in self.index.values()]
|
||||
data = {
|
||||
"version": "1.0",
|
||||
"count": len(entries),
|
||||
"last_updated": datetime.now().isoformat(),
|
||||
"entries": entries,
|
||||
}
|
||||
|
||||
# 写入文件(使用临时文件 + 原子重命名)
|
||||
temp_file = self.index_file.with_suffix(".tmp")
|
||||
with open(temp_file, "wb") as f:
|
||||
f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
|
||||
|
||||
temp_file.replace(self.index_file)
|
||||
logger.debug(f"元数据索引已保存: {len(entries)} 条记录")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存元数据索引失败: {e}", exc_info=True)
|
||||
|
||||
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
|
||||
"""更新倒排索引"""
|
||||
memory_id = entry.memory_id
|
||||
|
||||
# 类型索引
|
||||
self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
|
||||
|
||||
# 主语索引
|
||||
for subject in entry.subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
if subject_norm:
|
||||
self.subject_index.setdefault(subject_norm, set()).add(memory_id)
|
||||
|
||||
# 关键词索引
|
||||
for keyword in entry.keywords:
|
||||
keyword_norm = keyword.strip().lower()
|
||||
if keyword_norm:
|
||||
self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
|
||||
|
||||
# 标签索引
|
||||
for tag in entry.tags:
|
||||
tag_norm = tag.strip().lower()
|
||||
if tag_norm:
|
||||
self.tag_index.setdefault(tag_norm, set()).add(memory_id)
|
||||
payload = []
|
||||
for e in entries:
|
||||
if not e.memory_id:
|
||||
continue
|
||||
self.index[e.memory_id] = e
|
||||
payload.append(asdict(e))
|
||||
if payload:
|
||||
try:
|
||||
self._rust.batch_add(payload)
|
||||
except Exception as ex: # noqa: BLE001
|
||||
logger.error(f"Rust 元数据批量添加失败: {ex}")
|
||||
|
||||
def add_or_update(self, entry: MemoryMetadataIndexEntry):
|
||||
"""添加或更新索引条目"""
|
||||
with self.lock:
|
||||
# 如果已存在,先从倒排索引中移除旧记录
|
||||
if entry.memory_id in self.index:
|
||||
self._remove_from_inverted_indices(entry.memory_id)
|
||||
|
||||
# 添加新记录
|
||||
self.index[entry.memory_id] = entry
|
||||
self._update_inverted_indices(entry)
|
||||
|
||||
def _remove_from_inverted_indices(self, memory_id: str):
|
||||
"""从倒排索引中移除记录"""
|
||||
if memory_id not in self.index:
|
||||
return
|
||||
|
||||
entry = self.index[memory_id]
|
||||
|
||||
# 从类型索引移除
|
||||
if entry.memory_type in self.type_index:
|
||||
self.type_index[entry.memory_type].discard(memory_id)
|
||||
|
||||
# 从主语索引移除
|
||||
for subject in entry.subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
if subject_norm in self.subject_index:
|
||||
self.subject_index[subject_norm].discard(memory_id)
|
||||
|
||||
# 从关键词索引移除
|
||||
for keyword in entry.keywords:
|
||||
keyword_norm = keyword.strip().lower()
|
||||
if keyword_norm in self.keyword_index:
|
||||
self.keyword_index[keyword_norm].discard(memory_id)
|
||||
|
||||
# 从标签索引移除
|
||||
for tag in entry.tags:
|
||||
tag_norm = tag.strip().lower()
|
||||
if tag_norm in self.tag_index:
|
||||
self.tag_index[tag_norm].discard(memory_id)
|
||||
|
||||
def remove(self, memory_id: str):
|
||||
"""移除索引条目"""
|
||||
with self.lock:
|
||||
if memory_id in self.index:
|
||||
self._remove_from_inverted_indices(memory_id)
|
||||
del self.index[memory_id]
|
||||
|
||||
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
|
||||
"""批量添加或更新"""
|
||||
with self.lock:
|
||||
for entry in entries:
|
||||
self.add_or_update(entry)
|
||||
|
||||
def save(self):
|
||||
"""保存索引到磁盘"""
|
||||
with self.lock:
|
||||
self._save_index()
|
||||
self.batch_add_or_update([entry])
|
||||
|
||||
def search(
|
||||
self,
|
||||
@@ -201,287 +79,54 @@ class MemoryMetadataIndex:
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
flexible_mode: bool = True, # 新增:灵活匹配模式
|
||||
flexible_mode: bool = True,
|
||||
) -> list[str]:
|
||||
"""
|
||||
搜索符合条件的记忆ID列表(支持模糊匹配)
|
||||
|
||||
Returns:
|
||||
List[str]: 符合条件的 memory_id 列表
|
||||
"""
|
||||
with self.lock:
|
||||
params: dict[str, Any] = {
|
||||
"user_id": user_id,
|
||||
"memory_types": memory_types,
|
||||
"subjects": subjects,
|
||||
"keywords": keywords,
|
||||
"tags": tags,
|
||||
"importance_min": importance_min,
|
||||
"importance_max": importance_max,
|
||||
"created_after": created_after,
|
||||
"created_before": created_before,
|
||||
"limit": limit,
|
||||
}
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
try:
|
||||
if flexible_mode:
|
||||
return self._search_flexible(
|
||||
memory_types=memory_types,
|
||||
subjects=subjects,
|
||||
keywords=keywords, # 保留用于兼容性
|
||||
tags=tags, # 保留用于兼容性
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
)
|
||||
else:
|
||||
return self._search_strict(
|
||||
memory_types=memory_types,
|
||||
subjects=subjects,
|
||||
keywords=keywords,
|
||||
tags=tags,
|
||||
importance_min=importance_min,
|
||||
importance_max=importance_max,
|
||||
created_after=created_after,
|
||||
created_before=created_before,
|
||||
user_id=user_id,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
def _search_flexible(
|
||||
self,
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
**kwargs, # 接受但不使用的参数
|
||||
) -> list[str]:
|
||||
"""
|
||||
灵活搜索模式:2/4项匹配即可,支持部分匹配
|
||||
|
||||
评分维度:
|
||||
1. 记忆类型匹配 (0-1分)
|
||||
2. 主语匹配 (0-1分)
|
||||
3. 宾语匹配 (0-1分)
|
||||
4. 时间范围匹配 (0-1分)
|
||||
|
||||
总分 >= 2分即视为有效
|
||||
"""
|
||||
# 用户过滤(必选)
|
||||
if user_id:
|
||||
base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
|
||||
else:
|
||||
base_candidates = set(self.index.keys())
|
||||
|
||||
scored_candidates = []
|
||||
|
||||
for memory_id in base_candidates:
|
||||
entry = self.index[memory_id]
|
||||
score = 0
|
||||
match_details = []
|
||||
|
||||
# 1. 记忆类型匹配
|
||||
if memory_types:
|
||||
type_score = 0
|
||||
for mtype in memory_types:
|
||||
if entry.memory_type == mtype:
|
||||
type_score = 1
|
||||
break
|
||||
# 部分匹配:类型名称包含
|
||||
if mtype.lower() in entry.memory_type.lower() or entry.memory_type.lower() in mtype.lower():
|
||||
type_score = 0.5
|
||||
break
|
||||
score += type_score
|
||||
if type_score > 0:
|
||||
match_details.append(f"类型:{entry.memory_type}")
|
||||
else:
|
||||
match_details.append("类型:未指定")
|
||||
|
||||
# 2. 主语匹配(支持部分匹配)
|
||||
if subjects:
|
||||
subject_score = 0
|
||||
for subject in subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
for entry_subject in entry.subjects:
|
||||
entry_subject_norm = entry_subject.strip().lower()
|
||||
# 精确匹配
|
||||
if subject_norm == entry_subject_norm:
|
||||
subject_score = 1
|
||||
break
|
||||
# 部分匹配:包含关系
|
||||
if subject_norm in entry_subject_norm or entry_subject_norm in subject_norm:
|
||||
subject_score = 0.6
|
||||
break
|
||||
if subject_score == 1:
|
||||
break
|
||||
score += subject_score
|
||||
if subject_score > 0:
|
||||
match_details.append("主语:匹配")
|
||||
else:
|
||||
match_details.append("主语:未指定")
|
||||
|
||||
# 3. 宾语匹配(支持部分匹配)
|
||||
object_score = 0
|
||||
if entry.objects:
|
||||
for entry_object in entry.objects:
|
||||
entry_object_norm = str(entry_object).strip().lower()
|
||||
# 检查是否与主语相关(主宾关联)
|
||||
for subject in subjects or []:
|
||||
subject_norm = subject.strip().lower()
|
||||
if subject_norm in entry_object_norm or entry_object_norm in subject_norm:
|
||||
object_score = 0.8
|
||||
match_details.append("宾语:主宾关联")
|
||||
break
|
||||
if object_score > 0:
|
||||
break
|
||||
|
||||
score += object_score
|
||||
if object_score > 0:
|
||||
match_details.append("宾语:匹配")
|
||||
elif not entry.objects:
|
||||
match_details.append("宾语:无")
|
||||
|
||||
# 4. 时间范围匹配
|
||||
time_score = 0
|
||||
if created_after is not None or created_before is not None:
|
||||
time_match = True
|
||||
if created_after is not None and entry.created_at < created_after:
|
||||
time_match = False
|
||||
if created_before is not None and entry.created_at > created_before:
|
||||
time_match = False
|
||||
if time_match:
|
||||
time_score = 1
|
||||
match_details.append("时间:匹配")
|
||||
else:
|
||||
match_details.append("时间:不匹配")
|
||||
else:
|
||||
match_details.append("时间:未指定")
|
||||
|
||||
score += time_score
|
||||
|
||||
# 只有总分 >= 2 的记忆才会被返回
|
||||
if score >= 2:
|
||||
scored_candidates.append((memory_id, score, match_details))
|
||||
|
||||
# 按分数和时间排序
|
||||
scored_candidates.sort(key=lambda x: (x[1], self.index[x[0]].created_at), reverse=True)
|
||||
|
||||
if limit:
|
||||
result_ids = [mid for mid, _, _ in scored_candidates[:limit]]
|
||||
else:
|
||||
result_ids = [mid for mid, _, _ in scored_candidates]
|
||||
|
||||
logger.debug(
|
||||
f"[灵活搜索] 过滤条件: types={memory_types}, subjects={subjects}, "
|
||||
f"time_range=[{created_after}, {created_before}], 返回={len(result_ids)}条"
|
||||
)
|
||||
|
||||
# 记录匹配统计
|
||||
if scored_candidates and len(scored_candidates) > 0:
|
||||
avg_score = sum(score for _, score, _ in scored_candidates) / len(scored_candidates)
|
||||
logger.debug(f"[灵活搜索] 平均匹配分数: {avg_score:.2f}, 最高分: {scored_candidates[0][1]:.2f}")
|
||||
|
||||
return result_ids
|
||||
|
||||
def _search_strict(
|
||||
self,
|
||||
memory_types: list[str] | None = None,
|
||||
subjects: list[str] | None = None,
|
||||
keywords: list[str] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
importance_min: int | None = None,
|
||||
importance_max: int | None = None,
|
||||
created_after: float | None = None,
|
||||
created_before: float | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[str]:
|
||||
"""严格搜索模式(原有逻辑)"""
|
||||
# 初始候选集(所有记忆)
|
||||
candidate_ids: set[str] | None = None
|
||||
|
||||
# 用户过滤(必选)
|
||||
if user_id:
|
||||
candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
|
||||
else:
|
||||
candidate_ids = set(self.index.keys())
|
||||
|
||||
# 类型过滤(OR关系)
|
||||
if memory_types:
|
||||
type_ids = set()
|
||||
for mtype in memory_types:
|
||||
type_ids.update(self.type_index.get(mtype, set()))
|
||||
candidate_ids &= type_ids
|
||||
|
||||
# 主语过滤(OR关系,支持模糊匹配)
|
||||
if subjects:
|
||||
subject_ids = set()
|
||||
for subject in subjects:
|
||||
subject_norm = subject.strip().lower()
|
||||
# 精确匹配
|
||||
if subject_norm in self.subject_index:
|
||||
subject_ids.update(self.subject_index[subject_norm])
|
||||
# 模糊匹配(包含)
|
||||
for indexed_subject, ids in self.subject_index.items():
|
||||
if subject_norm in indexed_subject or indexed_subject in subject_norm:
|
||||
subject_ids.update(ids)
|
||||
candidate_ids &= subject_ids
|
||||
|
||||
# 关键词过滤(OR关系,支持模糊匹配)
|
||||
if keywords:
|
||||
keyword_ids = set()
|
||||
for keyword in keywords:
|
||||
keyword_norm = keyword.strip().lower()
|
||||
# 精确匹配
|
||||
if keyword_norm in self.keyword_index:
|
||||
keyword_ids.update(self.keyword_index[keyword_norm])
|
||||
# 模糊匹配(包含)
|
||||
for indexed_keyword, ids in self.keyword_index.items():
|
||||
if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm:
|
||||
keyword_ids.update(ids)
|
||||
candidate_ids &= keyword_ids
|
||||
|
||||
# 标签过滤(OR关系)
|
||||
if tags:
|
||||
tag_ids = set()
|
||||
for tag in tags:
|
||||
tag_norm = tag.strip().lower()
|
||||
tag_ids.update(self.tag_index.get(tag_norm, set()))
|
||||
candidate_ids &= tag_ids
|
||||
|
||||
# 重要性过滤
|
||||
if importance_min is not None or importance_max is not None:
|
||||
importance_ids = {
|
||||
mid
|
||||
for mid in candidate_ids
|
||||
if (importance_min is None or self.index[mid].importance >= importance_min)
|
||||
and (importance_max is None or self.index[mid].importance <= importance_max)
|
||||
}
|
||||
candidate_ids &= importance_ids
|
||||
|
||||
# 时间范围过滤
|
||||
if created_after is not None or created_before is not None:
|
||||
time_ids = {
|
||||
mid
|
||||
for mid in candidate_ids
|
||||
if (created_after is None or self.index[mid].created_at >= created_after)
|
||||
and (created_before is None or self.index[mid].created_at <= created_before)
|
||||
}
|
||||
candidate_ids &= time_ids
|
||||
|
||||
# 转换为列表并排序(按创建时间倒序)
|
||||
result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True)
|
||||
|
||||
# 限制数量
|
||||
if limit:
|
||||
result_ids = result_ids[:limit]
|
||||
|
||||
logger.debug(
|
||||
f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}条"
|
||||
)
|
||||
|
||||
return result_ids
|
||||
return list(self._rust.search_flexible(params))
|
||||
return list(self._rust.search_strict(params))
|
||||
except Exception as ex: # noqa: BLE001
|
||||
logger.error(f"Rust 搜索失败返回空: {ex}")
|
||||
return []
|
||||
|
||||
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
|
||||
"""获取单个索引条目"""
|
||||
return self.index.get(memory_id)
|
||||
|
||||
def get_stats(self) -> dict[str, Any]:
|
||||
"""获取索引统计信息"""
|
||||
with self.lock:
|
||||
try:
|
||||
raw = self._rust.stats()
|
||||
return {
|
||||
"total_memories": len(self.index),
|
||||
"types": {mtype: len(ids) for mtype, ids in self.type_index.items()},
|
||||
"subjects_count": len(self.subject_index),
|
||||
"keywords_count": len(self.keyword_index),
|
||||
"tags_count": len(self.tag_index),
|
||||
"total_memories": raw.get("total", 0),
|
||||
"types": raw.get("types_dist", {}),
|
||||
"subjects_count": raw.get("subjects_indexed", 0),
|
||||
"keywords_count": raw.get("keywords_indexed", 0),
|
||||
"tags_count": raw.get("tags_indexed", 0),
|
||||
}
|
||||
except Exception as ex: # noqa: BLE001
|
||||
logger.warning(f"读取 Rust stats 失败: {ex}")
|
||||
return {"total_memories": 0}
|
||||
|
||||
def save(self): # 仅调用 rust save
|
||||
try:
|
||||
self._rust.save()
|
||||
except Exception as ex: # noqa: BLE001
|
||||
logger.warning(f"Rust save 失败: {ex}")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MemoryMetadataIndexEntry",
|
||||
"MemoryMetadataIndex",
|
||||
]
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.llm_models.payload_content.tool_option import ToolCall
|
||||
|
||||
from . import BaseDataModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_events_handler import BaseEventHandler
|
||||
|
||||
logger = get_logger("base_event")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user