Refactor memory metadata index to use Rust backend

Replaces the Python implementation of MemoryMetadataIndex with a Rust-accelerated version, removing legacy code and fallback logic. Updates search, add, and stats methods to delegate to the Rust backend. Also adds missing imports in llm_data_model.py and base_event.py for improved type support and event handling.
This commit is contained in:
雅诺狐
2025-10-05 16:35:31 +08:00
committed by Windpicker-owo
parent a513aeb68e
commit a9c592b203
3 changed files with 89 additions and 441 deletions

View File

@@ -1,193 +1,71 @@
"""
记忆元数据索引管理器
使用JSON文件存储记忆元数据支持快速模糊搜索和过滤
记忆元数据索引
"""
import threading
from dataclasses import asdict, dataclass
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Any
import orjson
from time import time
from src.common.logger import get_logger
logger = get_logger(__name__)
# 仅允许规范导入路径from inkfox.memory import PyMetadataIndex
try: # pragma: no cover
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
logger.debug("已从 inkfox.memory 成功导入 PyMetadataIndex")
except Exception as ex: # noqa: BLE001
# 不再做任何回退;强制要求正确的 Rust 模块子模块注册
raise RuntimeError(
"无法导入 inkfox.memory.PyMetadataIndex: %s\n"
"请确认: 1) 已在当前虚拟环境下执行 'maturin develop --release' 安装扩展; "
"2) 运行进程使用同一个 venv; 3) 没有旧的 'inkfox' 目录遮蔽 so/pyd; 4) Python 版本与编译匹配" % ex
) from ex
@dataclass
class MemoryMetadataIndexEntry:
"""元数据索引条目(轻量级,只用于快速过滤)"""
memory_id: str
user_id: str
# 分类信息
memory_type: str # MemoryType.value
subjects: list[str] # 主语列表
objects: list[str] # 宾语列表
keywords: list[str] # 关键词列表
tags: list[str] # 标签列表
# 数值字段(用于范围过滤)
importance: int # ImportanceLevel.value (1-4)
confidence: int # ConfidenceLevel.value (1-4)
created_at: float # 创建时间戳
access_count: int # 访问次数
# 可选字段
memory_type: str
subjects: list[str]
objects: list[str]
keywords: list[str]
tags: list[str]
importance: int
confidence: int
created_at: float
access_count: int
chat_id: str | None = None
content_preview: str | None = None # 内容预览前100字符
content_preview: str | None = None
class MemoryMetadataIndex:
"""记忆元数据索引管理器"""
"""Rust 加速版本唯一实现。"""
def __init__(self, index_file: str = "data/memory_metadata_index.json"):
self.index_file = Path(index_file)
self.index: dict[str, MemoryMetadataIndexEntry] = {} # memory_id -> entry
self._rust = _RustIndex(index_file)
# 仅为向量层和调试提供最小缓存长度判断、get_entry 返回)
self.index: dict[str, MemoryMetadataIndexEntry] = {}
logger.info("✅ MemoryMetadataIndex (Rust) 初始化完成,仅支持加速实现")
# 倒排索引(用于快速查找)
self.type_index: dict[str, set[str]] = {} # type -> {memory_ids}
self.subject_index: dict[str, set[str]] = {} # subject -> {memory_ids}
self.keyword_index: dict[str, set[str]] = {} # keyword -> {memory_ids}
self.tag_index: dict[str, set[str]] = {} # tag -> {memory_ids}
self.lock = threading.RLock()
# 加载已有索引
self._load_index()
def _load_index(self):
"""从文件加载索引"""
if not self.index_file.exists():
logger.info(f"元数据索引文件不存在,将创建新索引: {self.index_file}")
# 向后代码仍调用的接口batch_add_or_update / add_or_update
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
if not entries:
return
try:
with open(self.index_file, "rb") as f:
data = orjson.loads(f.read())
# 重建内存索引
for entry_data in data.get("entries", []):
entry = MemoryMetadataIndexEntry(**entry_data)
self.index[entry.memory_id] = entry
self._update_inverted_indices(entry)
logger.info(f"✅ 加载元数据索引: {len(self.index)} 条记录")
except Exception as e:
logger.error(f"加载元数据索引失败: {e}", exc_info=True)
def _save_index(self):
    """Persist the in-memory index to the JSON index file.

    Serializes every entry to a plain dict, then writes via a temporary
    file followed by an atomic rename so a crash mid-write cannot leave
    a half-written index behind. Failures are logged, never raised
    (persistence is best-effort).
    """
    try:
        # Make sure the target directory exists before writing.
        self.index_file.parent.mkdir(parents=True, exist_ok=True)
        # Serialize all entries.
        entries = [asdict(entry) for entry in self.index.values()]
        data = {
            "version": "1.0",
            "count": len(entries),
            "last_updated": datetime.now().isoformat(),
            "entries": entries,
        }
        # Write to a temp file, then atomically replace the real file.
        temp_file = self.index_file.with_suffix(".tmp")
        with open(temp_file, "wb") as f:
            f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
        temp_file.replace(self.index_file)
        logger.debug(f"元数据索引已保存: {len(entries)} 条记录")
    except Exception as e:
        logger.error(f"保存元数据索引失败: {e}", exc_info=True)
def _update_inverted_indices(self, entry: MemoryMetadataIndexEntry):
    """Register *entry* in every inverted index (type/subject/keyword/tag).

    Subject, keyword and tag keys are normalized (stripped + lower-cased)
    before insertion; empty values after normalization are skipped.
    """
    memory_id = entry.memory_id
    # Type index: memory_type -> {memory_ids}
    self.type_index.setdefault(entry.memory_type, set()).add(memory_id)
    # Subject index: normalized subject -> {memory_ids}
    for subject in entry.subjects:
        subject_norm = subject.strip().lower()
        if subject_norm:
            self.subject_index.setdefault(subject_norm, set()).add(memory_id)
    # Keyword index: normalized keyword -> {memory_ids}
    for keyword in entry.keywords:
        keyword_norm = keyword.strip().lower()
        if keyword_norm:
            self.keyword_index.setdefault(keyword_norm, set()).add(memory_id)
    # Tag index: normalized tag -> {memory_ids}
    for tag in entry.tags:
        tag_norm = tag.strip().lower()
        if tag_norm:
            self.tag_index.setdefault(tag_norm, set()).add(memory_id)
payload = []
for e in entries:
if not e.memory_id:
continue
self.index[e.memory_id] = e
payload.append(asdict(e))
if payload:
try:
self._rust.batch_add(payload)
except Exception as ex: # noqa: BLE001
logger.error(f"Rust 元数据批量添加失败: {ex}")
def add_or_update(self, entry: MemoryMetadataIndexEntry):
    """Insert or replace a single index entry (thread-safe).

    When the memory_id already exists, its stale inverted-index records
    are removed first so the inverted indices stay consistent with the
    replacement entry.
    """
    with self.lock:
        # Drop old inverted-index records before re-indexing.
        if entry.memory_id in self.index:
            self._remove_from_inverted_indices(entry.memory_id)
        # Store the new entry and index it.
        self.index[entry.memory_id] = entry
        self._update_inverted_indices(entry)
def _remove_from_inverted_indices(self, memory_id: str):
    """Remove *memory_id* from all inverted indices.

    No-op when the id is not in the main index. Uses set.discard so a
    missing member never raises; emptied sets are left in place rather
    than pruned.
    """
    if memory_id not in self.index:
        return
    entry = self.index[memory_id]
    # Type index
    if entry.memory_type in self.type_index:
        self.type_index[entry.memory_type].discard(memory_id)
    # Subject index (keys normalized the same way they were added)
    for subject in entry.subjects:
        subject_norm = subject.strip().lower()
        if subject_norm in self.subject_index:
            self.subject_index[subject_norm].discard(memory_id)
    # Keyword index
    for keyword in entry.keywords:
        keyword_norm = keyword.strip().lower()
        if keyword_norm in self.keyword_index:
            self.keyword_index[keyword_norm].discard(memory_id)
    # Tag index
    for tag in entry.tags:
        tag_norm = tag.strip().lower()
        if tag_norm in self.tag_index:
            self.tag_index[tag_norm].discard(memory_id)
def remove(self, memory_id: str):
    """Delete an entry and its inverted-index records (thread-safe).

    Silently does nothing when memory_id is unknown.
    """
    with self.lock:
        if memory_id in self.index:
            self._remove_from_inverted_indices(memory_id)
            del self.index[memory_id]
def batch_add_or_update(self, entries: list[MemoryMetadataIndexEntry]):
    """Add or update many entries under a single lock acquisition.

    The lock is an RLock, so the nested acquisition inside
    add_or_update is safe (reentrant).
    """
    with self.lock:
        for entry in entries:
            self.add_or_update(entry)
def save(self):
    """Flush the index to disk (thread-safe wrapper around _save_index)."""
    with self.lock:
        self._save_index()
self.batch_add_or_update([entry])
def search(
self,
@@ -201,287 +79,54 @@ class MemoryMetadataIndex:
created_before: float | None = None,
user_id: str | None = None,
limit: int | None = None,
flexible_mode: bool = True, # 新增:灵活匹配模式
flexible_mode: bool = True,
) -> list[str]:
"""
搜索符合条件的记忆ID列表支持模糊匹配
Returns:
List[str]: 符合条件的 memory_id 列表
"""
with self.lock:
params: dict[str, Any] = {
"user_id": user_id,
"memory_types": memory_types,
"subjects": subjects,
"keywords": keywords,
"tags": tags,
"importance_min": importance_min,
"importance_max": importance_max,
"created_after": created_after,
"created_before": created_before,
"limit": limit,
}
params = {k: v for k, v in params.items() if v is not None}
try:
if flexible_mode:
return self._search_flexible(
memory_types=memory_types,
subjects=subjects,
keywords=keywords, # 保留用于兼容性
tags=tags, # 保留用于兼容性
created_after=created_after,
created_before=created_before,
user_id=user_id,
limit=limit,
)
else:
return self._search_strict(
memory_types=memory_types,
subjects=subjects,
keywords=keywords,
tags=tags,
importance_min=importance_min,
importance_max=importance_max,
created_after=created_after,
created_before=created_before,
user_id=user_id,
limit=limit,
)
def _search_flexible(
    self,
    memory_types: list[str] | None = None,
    subjects: list[str] | None = None,
    created_after: float | None = None,
    created_before: float | None = None,
    user_id: str | None = None,
    limit: int | None = None,
    **kwargs,  # accepted for signature compatibility; intentionally unused
) -> list[str]:
    """Flexible (scored) search: an entry qualifies when its score >= 2.

    Scoring dimensions (each contributes 0-1 points):
      1. memory-type match (exact 1.0, substring 0.5)
      2. subject match (exact 1.0, substring 0.6)
      3. object match (subject/object cross-association 0.8)
      4. time-range match (1.0 when within [created_after, created_before])

    Results are ordered by (score, created_at) descending and truncated
    to *limit* when given. Returns a list of memory ids.
    """
    # Mandatory user filter (when a user_id is supplied).
    if user_id:
        base_candidates = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
    else:
        base_candidates = set(self.index.keys())
    scored_candidates = []
    for memory_id in base_candidates:
        entry = self.index[memory_id]
        score = 0
        match_details = []
        # 1. Memory-type match
        if memory_types:
            type_score = 0
            for mtype in memory_types:
                if entry.memory_type == mtype:
                    type_score = 1
                    break
                # Partial match: one type name contains the other.
                if mtype.lower() in entry.memory_type.lower() or entry.memory_type.lower() in mtype.lower():
                    type_score = 0.5
                    break
            score += type_score
            if type_score > 0:
                match_details.append(f"类型:{entry.memory_type}")
        else:
            match_details.append("类型:未指定")
        # 2. Subject match (supports partial/substring matching)
        if subjects:
            subject_score = 0
            for subject in subjects:
                subject_norm = subject.strip().lower()
                for entry_subject in entry.subjects:
                    entry_subject_norm = entry_subject.strip().lower()
                    # Exact match
                    if subject_norm == entry_subject_norm:
                        subject_score = 1
                        break
                    # Partial match: containment either way
                    if subject_norm in entry_subject_norm or entry_subject_norm in subject_norm:
                        subject_score = 0.6
                        break
                if subject_score == 1:
                    break
            score += subject_score
            if subject_score > 0:
                match_details.append("主语:匹配")
        else:
            match_details.append("主语:未指定")
        # 3. Object match (queried subjects cross-checked against objects)
        object_score = 0
        if entry.objects:
            for entry_object in entry.objects:
                entry_object_norm = str(entry_object).strip().lower()
                # Subject/object association: a query subject contained in
                # (or containing) an object counts as a related hit.
                for subject in subjects or []:
                    subject_norm = subject.strip().lower()
                    if subject_norm in entry_object_norm or entry_object_norm in subject_norm:
                        object_score = 0.8
                        match_details.append("宾语:主宾关联")
                        break
                if object_score > 0:
                    break
        score += object_score
        if object_score > 0:
            match_details.append("宾语:匹配")
        elif not entry.objects:
            match_details.append("宾语:无")
        # 4. Time-range match
        time_score = 0
        if created_after is not None or created_before is not None:
            time_match = True
            if created_after is not None and entry.created_at < created_after:
                time_match = False
            if created_before is not None and entry.created_at > created_before:
                time_match = False
            if time_match:
                time_score = 1
                match_details.append("时间:匹配")
            else:
                match_details.append("时间:不匹配")
        else:
            match_details.append("时间:未指定")
        score += time_score
        # Only entries scoring >= 2 overall are returned.
        if score >= 2:
            scored_candidates.append((memory_id, score, match_details))
    # Sort by score, then creation time, newest/highest first.
    scored_candidates.sort(key=lambda x: (x[1], self.index[x[0]].created_at), reverse=True)
    if limit:
        result_ids = [mid for mid, _, _ in scored_candidates[:limit]]
    else:
        result_ids = [mid for mid, _, _ in scored_candidates]
    logger.debug(
        f"[灵活搜索] 过滤条件: types={memory_types}, subjects={subjects}, "
        f"time_range=[{created_after}, {created_before}], 返回={len(result_ids)}"
    )
    # Debug-only match statistics.
    if scored_candidates and len(scored_candidates) > 0:
        avg_score = sum(score for _, score, _ in scored_candidates) / len(scored_candidates)
        logger.debug(f"[灵活搜索] 平均匹配分数: {avg_score:.2f}, 最高分: {scored_candidates[0][1]:.2f}")
    return result_ids
def _search_strict(
    self,
    memory_types: list[str] | None = None,
    subjects: list[str] | None = None,
    keywords: list[str] | None = None,
    tags: list[str] | None = None,
    importance_min: int | None = None,
    importance_max: int | None = None,
    created_after: float | None = None,
    created_before: float | None = None,
    user_id: str | None = None,
    limit: int | None = None,
) -> list[str]:
    """Strict search mode (original logic): all supplied filters are AND-ed.

    Values inside a single filter are OR-ed; subject and keyword filters
    additionally match by substring against the inverted-index keys.
    Returns memory ids sorted by created_at descending, truncated to
    *limit* when given.
    """
    # Initial candidate set (all memories, optionally one user's).
    candidate_ids: set[str] | None = None
    # User filter
    if user_id:
        candidate_ids = {mid for mid, entry in self.index.items() if entry.user_id == user_id}
    else:
        candidate_ids = set(self.index.keys())
    # Type filter (OR within the list)
    if memory_types:
        type_ids = set()
        for mtype in memory_types:
            type_ids.update(self.type_index.get(mtype, set()))
        candidate_ids &= type_ids
    # Subject filter (OR, with substring fuzzy matching)
    if subjects:
        subject_ids = set()
        for subject in subjects:
            subject_norm = subject.strip().lower()
            # Exact match
            if subject_norm in self.subject_index:
                subject_ids.update(self.subject_index[subject_norm])
            # Fuzzy match: containment either way
            for indexed_subject, ids in self.subject_index.items():
                if subject_norm in indexed_subject or indexed_subject in subject_norm:
                    subject_ids.update(ids)
        candidate_ids &= subject_ids
    # Keyword filter (OR, with substring fuzzy matching)
    if keywords:
        keyword_ids = set()
        for keyword in keywords:
            keyword_norm = keyword.strip().lower()
            # Exact match
            if keyword_norm in self.keyword_index:
                keyword_ids.update(self.keyword_index[keyword_norm])
            # Fuzzy match: containment either way
            for indexed_keyword, ids in self.keyword_index.items():
                if keyword_norm in indexed_keyword or indexed_keyword in keyword_norm:
                    keyword_ids.update(ids)
        candidate_ids &= keyword_ids
    # Tag filter (OR, exact match on normalized tag)
    if tags:
        tag_ids = set()
        for tag in tags:
            tag_norm = tag.strip().lower()
            tag_ids.update(self.tag_index.get(tag_norm, set()))
        candidate_ids &= tag_ids
    # Importance range filter
    if importance_min is not None or importance_max is not None:
        importance_ids = {
            mid
            for mid in candidate_ids
            if (importance_min is None or self.index[mid].importance >= importance_min)
            and (importance_max is None or self.index[mid].importance <= importance_max)
        }
        candidate_ids &= importance_ids
    # Creation-time range filter
    if created_after is not None or created_before is not None:
        time_ids = {
            mid
            for mid in candidate_ids
            if (created_after is None or self.index[mid].created_at >= created_after)
            and (created_before is None or self.index[mid].created_at <= created_before)
        }
        candidate_ids &= time_ids
    # Sort newest first.
    result_ids = sorted(candidate_ids, key=lambda mid: self.index[mid].created_at, reverse=True)
    # Apply result cap.
    if limit:
        result_ids = result_ids[:limit]
    logger.debug(
        f"[严格搜索] types={memory_types}, subjects={subjects}, keywords={keywords}, 返回={len(result_ids)}"
    )
    return result_ids
return list(self._rust.search_flexible(params))
return list(self._rust.search_strict(params))
except Exception as ex: # noqa: BLE001
logger.error(f"Rust 搜索失败返回空: {ex}")
return []
def get_entry(self, memory_id: str) -> MemoryMetadataIndexEntry | None:
    """Return the cached entry for *memory_id*, or None when absent."""
    return self.index.get(memory_id)
def get_stats(self) -> dict[str, Any]:
"""获取索引统计信息"""
with self.lock:
try:
raw = self._rust.stats()
return {
"total_memories": len(self.index),
"types": {mtype: len(ids) for mtype, ids in self.type_index.items()},
"subjects_count": len(self.subject_index),
"keywords_count": len(self.keyword_index),
"tags_count": len(self.tag_index),
"total_memories": raw.get("total", 0),
"types": raw.get("types_dist", {}),
"subjects_count": raw.get("subjects_indexed", 0),
"keywords_count": raw.get("keywords_indexed", 0),
"tags_count": raw.get("tags_indexed", 0),
}
except Exception as ex: # noqa: BLE001
logger.warning(f"读取 Rust stats 失败: {ex}")
return {"total_memories": 0}
def save(self):  # delegates entirely to the Rust backend's save
    """Ask the Rust index to persist itself; failures are logged only."""
    try:
        self._rust.save()
    except Exception as ex:  # noqa: BLE001
        logger.warning(f"Rust save 失败: {ex}")
# Explicit public API of this module.
__all__ = [
    "MemoryMetadataIndexEntry",
    "MemoryMetadataIndex",
]

View File

@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from src.llm_models.payload_content.tool_option import ToolCall
from . import BaseDataModel
if TYPE_CHECKING:

View File

@@ -2,6 +2,7 @@ import asyncio
from typing import Any
from src.common.logger import get_logger
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("base_event")