feat(memory-graph): 完成 Phase 2 - 记忆构建与工具接口
Phase 2 实现内容: - 时间解析器 (utils/time_parser.py): 支持自然语言时间表达式 - 记忆提取器 (core/extractor.py): 参数验证和标准化 - 记忆构建器 (core/builder.py): 自动构造记忆子图,支持节点去重和关联 - 嵌入生成器 (utils/embeddings.py): API 优先策略,降低本地负载 - LLM 工具接口 (tools/memory_tools.py): create_memory, link_memories, search_memories 关键修复: - VectorStore: 支持 ChromaDB 列表元数据的 JSON 序列化 - 测试数据同步: 确保向量存储和图存储数据一致性 测试结果: 时间解析器: 6/6 通过 记忆提取器: 3 个测试用例通过 记忆构建器: 构建记忆子图成功 端到端流程: 成功创建 3 条记忆 记忆关联: 建立因果关系成功 记忆搜索: 语义搜索返回正确结果 工具 Schema: 3 个工具定义完整 下一步: Phase 3 - 管理层实现
This commit is contained in:
@@ -2,6 +2,8 @@
|
|||||||
核心模块
|
核心模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from src.memory_graph.core.builder import MemoryBuilder
|
||||||
|
from src.memory_graph.core.extractor import MemoryExtractor
|
||||||
from src.memory_graph.core.node_merger import NodeMerger
|
from src.memory_graph.core.node_merger import NodeMerger
|
||||||
|
|
||||||
__all__ = ["NodeMerger"]
|
__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"]
|
||||||
|
|||||||
549
src/memory_graph/core/builder.py
Normal file
549
src/memory_graph/core/builder.py
Normal file
@@ -0,0 +1,549 @@
|
|||||||
|
"""
|
||||||
|
记忆构建器:自动构造记忆子图
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.memory_graph.models import (
|
||||||
|
EdgeType,
|
||||||
|
Memory,
|
||||||
|
MemoryEdge,
|
||||||
|
MemoryNode,
|
||||||
|
MemoryStatus,
|
||||||
|
MemoryType,
|
||||||
|
NodeType,
|
||||||
|
)
|
||||||
|
from src.memory_graph.storage.graph_store import GraphStore
|
||||||
|
from src.memory_graph.storage.vector_store import VectorStore
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryBuilder:
|
||||||
|
"""
|
||||||
|
记忆构建器
|
||||||
|
|
||||||
|
负责:
|
||||||
|
1. 根据提取的元素自动构造记忆子图
|
||||||
|
2. 创建节点和边的完整结构
|
||||||
|
3. 生成语义嵌入向量
|
||||||
|
4. 检查并复用已存在的相似节点
|
||||||
|
5. 构造符合层级结构的记忆对象
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vector_store: VectorStore,
|
||||||
|
graph_store: GraphStore,
|
||||||
|
embedding_generator: Optional[Any] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化记忆构建器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_store: 向量存储
|
||||||
|
graph_store: 图存储
|
||||||
|
embedding_generator: 嵌入向量生成器(可选)
|
||||||
|
"""
|
||||||
|
self.vector_store = vector_store
|
||||||
|
self.graph_store = graph_store
|
||||||
|
self.embedding_generator = embedding_generator
|
||||||
|
|
||||||
|
async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory:
|
||||||
|
"""
|
||||||
|
构建完整的记忆对象
|
||||||
|
|
||||||
|
Args:
|
||||||
|
extracted_params: 提取器返回的标准化参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Memory 对象(状态为 STAGED)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
nodes = []
|
||||||
|
edges = []
|
||||||
|
memory_id = self._generate_memory_id()
|
||||||
|
|
||||||
|
# 1. 创建主体节点 (SUBJECT)
|
||||||
|
subject_node = await self._create_or_reuse_node(
|
||||||
|
content=extracted_params["subject"],
|
||||||
|
node_type=NodeType.SUBJECT,
|
||||||
|
memory_id=memory_id,
|
||||||
|
)
|
||||||
|
nodes.append(subject_node)
|
||||||
|
|
||||||
|
# 2. 创建主题节点 (TOPIC) - 需要嵌入向量
|
||||||
|
topic_node = await self._create_topic_node(
|
||||||
|
content=extracted_params["topic"], memory_id=memory_id
|
||||||
|
)
|
||||||
|
nodes.append(topic_node)
|
||||||
|
|
||||||
|
# 3. 连接主体 -> 记忆类型 -> 主题
|
||||||
|
memory_type_edge = MemoryEdge(
|
||||||
|
id=self._generate_edge_id(),
|
||||||
|
source_id=subject_node.id,
|
||||||
|
target_id=topic_node.id,
|
||||||
|
relation=extracted_params["memory_type"].value,
|
||||||
|
edge_type=EdgeType.MEMORY_TYPE,
|
||||||
|
importance=extracted_params["importance"],
|
||||||
|
metadata={"memory_id": memory_id},
|
||||||
|
)
|
||||||
|
edges.append(memory_type_edge)
|
||||||
|
|
||||||
|
# 4. 如果有客体,创建客体节点并连接
|
||||||
|
if "object" in extracted_params and extracted_params["object"]:
|
||||||
|
object_node = await self._create_object_node(
|
||||||
|
content=extracted_params["object"], memory_id=memory_id
|
||||||
|
)
|
||||||
|
nodes.append(object_node)
|
||||||
|
|
||||||
|
# 连接主题 -> 核心关系 -> 客体
|
||||||
|
core_relation_edge = MemoryEdge(
|
||||||
|
id=self._generate_edge_id(),
|
||||||
|
source_id=topic_node.id,
|
||||||
|
target_id=object_node.id,
|
||||||
|
relation="核心关系", # 默认关系名
|
||||||
|
edge_type=EdgeType.CORE_RELATION,
|
||||||
|
importance=extracted_params["importance"],
|
||||||
|
metadata={"memory_id": memory_id},
|
||||||
|
)
|
||||||
|
edges.append(core_relation_edge)
|
||||||
|
|
||||||
|
# 5. 处理属性
|
||||||
|
if extracted_params.get("attributes"):
|
||||||
|
attr_nodes, attr_edges = await self._process_attributes(
|
||||||
|
attributes=extracted_params["attributes"],
|
||||||
|
parent_id=topic_node.id,
|
||||||
|
memory_id=memory_id,
|
||||||
|
importance=extracted_params["importance"],
|
||||||
|
)
|
||||||
|
nodes.extend(attr_nodes)
|
||||||
|
edges.extend(attr_edges)
|
||||||
|
|
||||||
|
# 6. 构建 Memory 对象
|
||||||
|
memory = Memory(
|
||||||
|
id=memory_id,
|
||||||
|
subject_id=subject_node.id,
|
||||||
|
memory_type=extracted_params["memory_type"],
|
||||||
|
nodes=nodes,
|
||||||
|
edges=edges,
|
||||||
|
importance=extracted_params["importance"],
|
||||||
|
created_at=extracted_params["timestamp"],
|
||||||
|
last_accessed=extracted_params["timestamp"],
|
||||||
|
access_count=0,
|
||||||
|
status=MemoryStatus.STAGED,
|
||||||
|
metadata={
|
||||||
|
"subject": extracted_params["subject"],
|
||||||
|
"topic": extracted_params["topic"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"构建记忆成功: {memory_id} - {len(nodes)} 节点, {len(edges)} 边"
|
||||||
|
)
|
||||||
|
return memory
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记忆构建失败: {e}", exc_info=True)
|
||||||
|
raise RuntimeError(f"记忆构建失败: {e}")
|
||||||
|
|
||||||
|
async def _create_or_reuse_node(
|
||||||
|
self, content: str, node_type: NodeType, memory_id: str
|
||||||
|
) -> MemoryNode:
|
||||||
|
"""
|
||||||
|
创建新节点或复用已存在的相似节点
|
||||||
|
|
||||||
|
对于主体(SUBJECT)和属性(ATTRIBUTE),检查是否已存在相同内容的节点
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 节点内容
|
||||||
|
node_type: 节点类型
|
||||||
|
memory_id: 所属记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryNode 对象
|
||||||
|
"""
|
||||||
|
# 对于主体,尝试查找已存在的节点
|
||||||
|
if node_type == NodeType.SUBJECT:
|
||||||
|
existing = await self._find_existing_node(content, node_type)
|
||||||
|
if existing:
|
||||||
|
logger.debug(f"复用已存在的主体节点: {existing.id}")
|
||||||
|
return existing
|
||||||
|
|
||||||
|
# 创建新节点
|
||||||
|
node = MemoryNode(
|
||||||
|
id=self._generate_node_id(),
|
||||||
|
content=content,
|
||||||
|
node_type=node_type,
|
||||||
|
embedding=None, # 主体和属性不需要嵌入
|
||||||
|
metadata={"memory_ids": [memory_id]},
|
||||||
|
)
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
async def _create_topic_node(self, content: str, memory_id: str) -> MemoryNode:
|
||||||
|
"""
|
||||||
|
创建主题节点(需要生成嵌入向量)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 节点内容
|
||||||
|
memory_id: 所属记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryNode 对象
|
||||||
|
"""
|
||||||
|
# 生成嵌入向量
|
||||||
|
embedding = await self._generate_embedding(content)
|
||||||
|
|
||||||
|
# 检查是否存在高度相似的节点
|
||||||
|
existing = await self._find_similar_topic(content, embedding)
|
||||||
|
if existing:
|
||||||
|
logger.debug(f"复用相似的主题节点: {existing.id}")
|
||||||
|
# 添加当前记忆ID到元数据
|
||||||
|
if "memory_ids" not in existing.metadata:
|
||||||
|
existing.metadata["memory_ids"] = []
|
||||||
|
existing.metadata["memory_ids"].append(memory_id)
|
||||||
|
return existing
|
||||||
|
|
||||||
|
# 创建新节点
|
||||||
|
node = MemoryNode(
|
||||||
|
id=self._generate_node_id(),
|
||||||
|
content=content,
|
||||||
|
node_type=NodeType.TOPIC,
|
||||||
|
embedding=embedding,
|
||||||
|
metadata={"memory_ids": [memory_id]},
|
||||||
|
)
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
async def _create_object_node(self, content: str, memory_id: str) -> MemoryNode:
|
||||||
|
"""
|
||||||
|
创建客体节点(需要生成嵌入向量)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 节点内容
|
||||||
|
memory_id: 所属记忆ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryNode 对象
|
||||||
|
"""
|
||||||
|
# 生成嵌入向量
|
||||||
|
embedding = await self._generate_embedding(content)
|
||||||
|
|
||||||
|
# 检查是否存在高度相似的节点
|
||||||
|
existing = await self._find_similar_object(content, embedding)
|
||||||
|
if existing:
|
||||||
|
logger.debug(f"复用相似的客体节点: {existing.id}")
|
||||||
|
if "memory_ids" not in existing.metadata:
|
||||||
|
existing.metadata["memory_ids"] = []
|
||||||
|
existing.metadata["memory_ids"].append(memory_id)
|
||||||
|
return existing
|
||||||
|
|
||||||
|
# 创建新节点
|
||||||
|
node = MemoryNode(
|
||||||
|
id=self._generate_node_id(),
|
||||||
|
content=content,
|
||||||
|
node_type=NodeType.OBJECT,
|
||||||
|
embedding=embedding,
|
||||||
|
metadata={"memory_ids": [memory_id]},
|
||||||
|
)
|
||||||
|
|
||||||
|
return node
|
||||||
|
|
||||||
|
async def _process_attributes(
|
||||||
|
self,
|
||||||
|
attributes: Dict[str, Any],
|
||||||
|
parent_id: str,
|
||||||
|
memory_id: str,
|
||||||
|
importance: float,
|
||||||
|
) -> tuple[List[MemoryNode], List[MemoryEdge]]:
|
||||||
|
"""
|
||||||
|
处理属性,构建属性子图
|
||||||
|
|
||||||
|
结构:TOPIC -> ATTRIBUTE -> VALUE
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attributes: 属性字典
|
||||||
|
parent_id: 父节点ID(通常是TOPIC)
|
||||||
|
memory_id: 所属记忆ID
|
||||||
|
importance: 重要性
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(属性节点列表, 属性边列表)
|
||||||
|
"""
|
||||||
|
nodes = []
|
||||||
|
edges = []
|
||||||
|
|
||||||
|
for attr_name, attr_value in attributes.items():
|
||||||
|
# 创建属性节点
|
||||||
|
attr_node = await self._create_or_reuse_node(
|
||||||
|
content=attr_name, node_type=NodeType.ATTRIBUTE, memory_id=memory_id
|
||||||
|
)
|
||||||
|
nodes.append(attr_node)
|
||||||
|
|
||||||
|
# 连接父节点 -> 属性
|
||||||
|
attr_edge = MemoryEdge(
|
||||||
|
id=self._generate_edge_id(),
|
||||||
|
source_id=parent_id,
|
||||||
|
target_id=attr_node.id,
|
||||||
|
relation="属性",
|
||||||
|
edge_type=EdgeType.ATTRIBUTE,
|
||||||
|
importance=importance * 0.8, # 属性的重要性略低
|
||||||
|
metadata={"memory_id": memory_id},
|
||||||
|
)
|
||||||
|
edges.append(attr_edge)
|
||||||
|
|
||||||
|
# 创建值节点
|
||||||
|
value_node = await self._create_or_reuse_node(
|
||||||
|
content=str(attr_value), node_type=NodeType.VALUE, memory_id=memory_id
|
||||||
|
)
|
||||||
|
nodes.append(value_node)
|
||||||
|
|
||||||
|
# 连接属性 -> 值
|
||||||
|
value_edge = MemoryEdge(
|
||||||
|
id=self._generate_edge_id(),
|
||||||
|
source_id=attr_node.id,
|
||||||
|
target_id=value_node.id,
|
||||||
|
relation="值",
|
||||||
|
edge_type=EdgeType.ATTRIBUTE,
|
||||||
|
importance=importance * 0.8,
|
||||||
|
metadata={"memory_id": memory_id},
|
||||||
|
)
|
||||||
|
edges.append(value_edge)
|
||||||
|
|
||||||
|
return nodes, edges
|
||||||
|
|
||||||
|
async def _generate_embedding(self, text: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
生成文本的嵌入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 文本内容
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
嵌入向量
|
||||||
|
"""
|
||||||
|
if self.embedding_generator:
|
||||||
|
try:
|
||||||
|
embedding = await self.embedding_generator.generate(text)
|
||||||
|
return embedding
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"嵌入生成失败,使用随机向量: {e}")
|
||||||
|
|
||||||
|
# 回退:生成随机向量(仅用于测试)
|
||||||
|
return np.random.rand(384).astype(np.float32)
|
||||||
|
|
||||||
|
async def _find_existing_node(
|
||||||
|
self, content: str, node_type: NodeType
|
||||||
|
) -> Optional[MemoryNode]:
|
||||||
|
"""
|
||||||
|
查找已存在的完全匹配节点(用于主体和属性)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 节点内容
|
||||||
|
node_type: 节点类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
已存在的节点,如果没有则返回 None
|
||||||
|
"""
|
||||||
|
# 在图存储中查找
|
||||||
|
for node_id in self.graph_store.graph.nodes():
|
||||||
|
node_data = self.graph_store.graph.nodes[node_id]
|
||||||
|
if node_data.get("content") == content and node_data.get("node_type") == node_type.value:
|
||||||
|
# 重建 MemoryNode 对象
|
||||||
|
return MemoryNode(
|
||||||
|
id=node_id,
|
||||||
|
content=node_data["content"],
|
||||||
|
node_type=NodeType(node_data["node_type"]),
|
||||||
|
embedding=node_data.get("embedding"),
|
||||||
|
metadata=node_data.get("metadata", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _find_similar_topic(
|
||||||
|
self, content: str, embedding: np.ndarray
|
||||||
|
) -> Optional[MemoryNode]:
|
||||||
|
"""
|
||||||
|
查找相似的主题节点(基于语义相似度)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 内容
|
||||||
|
embedding: 嵌入向量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相似节点,如果没有则返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 搜索相似节点(阈值 0.95)
|
||||||
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
|
query_embedding=embedding,
|
||||||
|
limit=1,
|
||||||
|
node_types=[NodeType.TOPIC],
|
||||||
|
min_similarity=0.95,
|
||||||
|
)
|
||||||
|
|
||||||
|
if similar_nodes and similar_nodes[0][1] >= 0.95:
|
||||||
|
node_id, similarity, metadata = similar_nodes[0]
|
||||||
|
logger.debug(
|
||||||
|
f"找到相似主题节点: {metadata.get('content', '')} (相似度: {similarity:.3f})"
|
||||||
|
)
|
||||||
|
# 从图存储中获取完整节点
|
||||||
|
if node_id in self.graph_store.graph.nodes:
|
||||||
|
node_data = self.graph_store.graph.nodes[node_id]
|
||||||
|
existing_node = MemoryNode(
|
||||||
|
id=node_id,
|
||||||
|
content=node_data["content"],
|
||||||
|
node_type=NodeType(node_data["node_type"]),
|
||||||
|
embedding=node_data.get("embedding"),
|
||||||
|
metadata=node_data.get("metadata", {}),
|
||||||
|
)
|
||||||
|
# 添加当前记忆ID到元数据
|
||||||
|
return existing_node
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"相似节点搜索失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _find_similar_object(
|
||||||
|
self, content: str, embedding: np.ndarray
|
||||||
|
) -> Optional[MemoryNode]:
|
||||||
|
"""
|
||||||
|
查找相似的客体节点(基于语义相似度)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
content: 内容
|
||||||
|
embedding: 嵌入向量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相似节点,如果没有则返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 搜索相似节点(阈值 0.95)
|
||||||
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
|
query_embedding=embedding,
|
||||||
|
limit=1,
|
||||||
|
node_types=[NodeType.OBJECT],
|
||||||
|
min_similarity=0.95,
|
||||||
|
)
|
||||||
|
|
||||||
|
if similar_nodes and similar_nodes[0][1] >= 0.95:
|
||||||
|
node_id, similarity, metadata = similar_nodes[0]
|
||||||
|
logger.debug(
|
||||||
|
f"找到相似客体节点: {metadata.get('content', '')} (相似度: {similarity:.3f})"
|
||||||
|
)
|
||||||
|
# 从图存储中获取完整节点
|
||||||
|
if node_id in self.graph_store.graph.nodes:
|
||||||
|
node_data = self.graph_store.graph.nodes[node_id]
|
||||||
|
existing_node = MemoryNode(
|
||||||
|
id=node_id,
|
||||||
|
content=node_data["content"],
|
||||||
|
node_type=NodeType(node_data["node_type"]),
|
||||||
|
embedding=node_data.get("embedding"),
|
||||||
|
metadata=node_data.get("metadata", {}),
|
||||||
|
)
|
||||||
|
return existing_node
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"相似节点搜索失败: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _generate_memory_id(self) -> str:
|
||||||
|
"""生成记忆ID"""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
return f"mem_{timestamp}"
|
||||||
|
|
||||||
|
def _generate_node_id(self) -> str:
|
||||||
|
"""生成节点ID"""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
return f"node_{timestamp}"
|
||||||
|
|
||||||
|
def _generate_edge_id(self) -> str:
|
||||||
|
"""生成边ID"""
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||||
|
return f"edge_{timestamp}"
|
||||||
|
|
||||||
|
async def link_memories(
|
||||||
|
self,
|
||||||
|
source_memory: Memory,
|
||||||
|
target_memory: Memory,
|
||||||
|
relation_type: str,
|
||||||
|
importance: float = 0.6,
|
||||||
|
) -> MemoryEdge:
|
||||||
|
"""
|
||||||
|
关联两个记忆(创建因果或引用边)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source_memory: 源记忆
|
||||||
|
target_memory: 目标记忆
|
||||||
|
relation_type: 关系类型(如 "导致", "引用")
|
||||||
|
importance: 重要性
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
创建的边
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取两个记忆的主题节点(作为连接点)
|
||||||
|
source_topic = self._find_topic_node(source_memory)
|
||||||
|
target_topic = self._find_topic_node(target_memory)
|
||||||
|
|
||||||
|
if not source_topic or not target_topic:
|
||||||
|
raise ValueError("无法找到记忆的主题节点")
|
||||||
|
|
||||||
|
# 确定边的类型
|
||||||
|
edge_type = self._determine_edge_type(relation_type)
|
||||||
|
|
||||||
|
# 创建边
|
||||||
|
edge_id = f"edge_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}"
|
||||||
|
edge = MemoryEdge(
|
||||||
|
id=edge_id,
|
||||||
|
source_id=source_topic.id,
|
||||||
|
target_id=target_topic.id,
|
||||||
|
relation=relation_type,
|
||||||
|
edge_type=edge_type,
|
||||||
|
importance=importance,
|
||||||
|
metadata={
|
||||||
|
"source_memory_id": source_memory.id,
|
||||||
|
"target_memory_id": target_memory.id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"关联记忆: {source_memory.id} --{relation_type}--> {target_memory.id}"
|
||||||
|
)
|
||||||
|
return edge
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记忆关联失败: {e}", exc_info=True)
|
||||||
|
raise RuntimeError(f"记忆关联失败: {e}")
|
||||||
|
|
||||||
|
def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]:
|
||||||
|
"""查找记忆中的主题节点"""
|
||||||
|
for node in memory.nodes:
|
||||||
|
if node.node_type == NodeType.TOPIC:
|
||||||
|
return node
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _determine_edge_type(self, relation_type: str) -> EdgeType:
|
||||||
|
"""根据关系类型确定边的类型"""
|
||||||
|
causality_keywords = ["导致", "引起", "造成", "因为", "所以"]
|
||||||
|
reference_keywords = ["引用", "基于", "关于", "参考"]
|
||||||
|
|
||||||
|
for keyword in causality_keywords:
|
||||||
|
if keyword in relation_type:
|
||||||
|
return EdgeType.CAUSALITY
|
||||||
|
|
||||||
|
for keyword in reference_keywords:
|
||||||
|
if keyword in relation_type:
|
||||||
|
return EdgeType.REFERENCE
|
||||||
|
|
||||||
|
# 默认为引用类型
|
||||||
|
return EdgeType.REFERENCE
|
||||||
311
src/memory_graph/core/extractor.py
Normal file
311
src/memory_graph/core/extractor.py
Normal file
@@ -0,0 +1,311 @@
|
|||||||
|
"""
|
||||||
|
记忆提取器:从工具参数中提取和验证记忆元素
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.memory_graph.models import MemoryType
|
||||||
|
from src.memory_graph.utils.time_parser import TimeParser
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryExtractor:
|
||||||
|
"""
|
||||||
|
记忆提取器
|
||||||
|
|
||||||
|
负责:
|
||||||
|
1. 从工具调用参数中提取记忆元素
|
||||||
|
2. 验证参数完整性和有效性
|
||||||
|
3. 标准化时间表达
|
||||||
|
4. 清洗和格式化数据
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, time_parser: Optional[TimeParser] = None):
|
||||||
|
"""
|
||||||
|
初始化记忆提取器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_parser: 时间解析器(可选)
|
||||||
|
"""
|
||||||
|
self.time_parser = time_parser or TimeParser()
|
||||||
|
|
||||||
|
def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
从工具参数中提取记忆元素
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: 工具调用参数,例如:
|
||||||
|
{
|
||||||
|
"subject": "我",
|
||||||
|
"memory_type": "事件",
|
||||||
|
"topic": "吃饭",
|
||||||
|
"object": "白米饭",
|
||||||
|
"attributes": {"时间": "今天", "地点": "家里"},
|
||||||
|
"importance": 0.3
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取和标准化后的参数字典
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 1. 验证必需参数
|
||||||
|
self._validate_required_params(params)
|
||||||
|
|
||||||
|
# 2. 提取基础元素
|
||||||
|
extracted = {
|
||||||
|
"subject": self._clean_text(params["subject"]),
|
||||||
|
"memory_type": self._parse_memory_type(params["memory_type"]),
|
||||||
|
"topic": self._clean_text(params["topic"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. 提取可选的客体
|
||||||
|
if "object" in params and params["object"]:
|
||||||
|
extracted["object"] = self._clean_text(params["object"])
|
||||||
|
|
||||||
|
# 4. 提取和标准化属性
|
||||||
|
if "attributes" in params and params["attributes"]:
|
||||||
|
extracted["attributes"] = self._process_attributes(params["attributes"])
|
||||||
|
else:
|
||||||
|
extracted["attributes"] = {}
|
||||||
|
|
||||||
|
# 5. 提取重要性
|
||||||
|
extracted["importance"] = self._parse_importance(params.get("importance", 0.5))
|
||||||
|
|
||||||
|
# 6. 添加时间戳
|
||||||
|
extracted["timestamp"] = datetime.now()
|
||||||
|
|
||||||
|
logger.debug(f"提取记忆元素: {extracted['subject']} - {extracted['topic']}")
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记忆提取失败: {e}", exc_info=True)
|
||||||
|
raise ValueError(f"记忆提取失败: {e}")
|
||||||
|
|
||||||
|
def _validate_required_params(self, params: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
验证必需参数
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: 参数字典
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果缺少必需参数
|
||||||
|
"""
|
||||||
|
required_fields = ["subject", "memory_type", "topic"]
|
||||||
|
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in params or not params[field]:
|
||||||
|
raise ValueError(f"缺少必需参数: {field}")
|
||||||
|
|
||||||
|
def _clean_text(self, text: Any) -> str:
|
||||||
|
"""
|
||||||
|
清洗文本
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
清洗后的文本
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
text = str(text).strip()
|
||||||
|
|
||||||
|
# 移除多余的空格
|
||||||
|
text = " ".join(text.split())
|
||||||
|
|
||||||
|
# 移除特殊字符(保留基本标点)
|
||||||
|
# text = re.sub(r'[^\w\s\u4e00-\u9fff,,.。!!??;;::、]', '', text)
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
def _parse_memory_type(self, type_str: str) -> MemoryType:
|
||||||
|
"""
|
||||||
|
解析记忆类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type_str: 类型字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryType 枚举
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果类型无效
|
||||||
|
"""
|
||||||
|
type_str = type_str.strip()
|
||||||
|
|
||||||
|
# 尝试直接匹配
|
||||||
|
try:
|
||||||
|
return MemoryType(type_str)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 模糊匹配
|
||||||
|
type_mapping = {
|
||||||
|
"事件": MemoryType.EVENT,
|
||||||
|
"event": MemoryType.EVENT,
|
||||||
|
"事实": MemoryType.FACT,
|
||||||
|
"fact": MemoryType.FACT,
|
||||||
|
"关系": MemoryType.RELATION,
|
||||||
|
"relation": MemoryType.RELATION,
|
||||||
|
"观点": MemoryType.OPINION,
|
||||||
|
"opinion": MemoryType.OPINION,
|
||||||
|
}
|
||||||
|
|
||||||
|
if type_str.lower() in type_mapping:
|
||||||
|
return type_mapping[type_str.lower()]
|
||||||
|
|
||||||
|
raise ValueError(f"无效的记忆类型: {type_str}")
|
||||||
|
|
||||||
|
def _parse_importance(self, importance: Any) -> float:
|
||||||
|
"""
|
||||||
|
解析重要性值
|
||||||
|
|
||||||
|
Args:
|
||||||
|
importance: 重要性值(可以是数字、字符串等)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
0-1之间的浮点数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
value = float(importance)
|
||||||
|
# 限制在 0-1 范围内
|
||||||
|
return max(0.0, min(1.0, value))
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
logger.warning(f"无效的重要性值: {importance},使用默认值 0.5")
|
||||||
|
return 0.5
|
||||||
|
|
||||||
|
def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
处理属性字典
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attributes: 原始属性字典
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处理后的属性字典
|
||||||
|
"""
|
||||||
|
processed = {}
|
||||||
|
|
||||||
|
for key, value in attributes.items():
|
||||||
|
key = key.strip()
|
||||||
|
|
||||||
|
# 特殊处理:时间属性
|
||||||
|
if key in ["时间", "time", "when"]:
|
||||||
|
parsed_time = self.time_parser.parse(str(value))
|
||||||
|
if parsed_time:
|
||||||
|
processed["时间"] = parsed_time.isoformat()
|
||||||
|
else:
|
||||||
|
processed["时间"] = str(value)
|
||||||
|
|
||||||
|
# 特殊处理:地点属性
|
||||||
|
elif key in ["地点", "place", "where", "位置"]:
|
||||||
|
processed["地点"] = self._clean_text(value)
|
||||||
|
|
||||||
|
# 特殊处理:原因属性
|
||||||
|
elif key in ["原因", "reason", "why", "因为"]:
|
||||||
|
processed["原因"] = self._clean_text(value)
|
||||||
|
|
||||||
|
# 特殊处理:方式属性
|
||||||
|
elif key in ["方式", "how", "manner"]:
|
||||||
|
processed["方式"] = self._clean_text(value)
|
||||||
|
|
||||||
|
# 其他属性
|
||||||
|
else:
|
||||||
|
processed[key] = self._clean_text(value)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
|
||||||
|
def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
提取记忆关联参数(用于 link_memories 工具)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: 工具参数,例如:
|
||||||
|
{
|
||||||
|
"source_memory_description": "我今天不开心",
|
||||||
|
"target_memory_description": "我摔东西",
|
||||||
|
"relation_type": "导致",
|
||||||
|
"importance": 0.6
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
提取后的参数
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
required = ["source_memory_description", "target_memory_description", "relation_type"]
|
||||||
|
|
||||||
|
for field in required:
|
||||||
|
if field not in params or not params[field]:
|
||||||
|
raise ValueError(f"缺少必需参数: {field}")
|
||||||
|
|
||||||
|
extracted = {
|
||||||
|
"source_description": self._clean_text(params["source_memory_description"]),
|
||||||
|
"target_description": self._clean_text(params["target_memory_description"]),
|
||||||
|
"relation_type": self._clean_text(params["relation_type"]),
|
||||||
|
"importance": self._parse_importance(params.get("importance", 0.6)),
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"提取关联参数: {extracted['source_description']} --{extracted['relation_type']}--> "
|
||||||
|
f"{extracted['target_description']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return extracted
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"关联参数提取失败: {e}", exc_info=True)
|
||||||
|
raise ValueError(f"关联参数提取失败: {e}")
|
||||||
|
|
||||||
|
def validate_relation_type(self, relation_type: str) -> str:
|
||||||
|
"""
|
||||||
|
验证关系类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
relation_type: 关系类型字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化的关系类型
|
||||||
|
"""
|
||||||
|
# 因果关系映射
|
||||||
|
causality_relations = {
|
||||||
|
"因为": "因为",
|
||||||
|
"所以": "所以",
|
||||||
|
"导致": "导致",
|
||||||
|
"引起": "导致",
|
||||||
|
"造成": "导致",
|
||||||
|
"因": "因为",
|
||||||
|
"果": "所以",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 引用关系映射
|
||||||
|
reference_relations = {
|
||||||
|
"引用": "引用",
|
||||||
|
"基于": "基于",
|
||||||
|
"关于": "关于",
|
||||||
|
"参考": "引用",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 相关关系
|
||||||
|
related_relations = {
|
||||||
|
"相关": "相关",
|
||||||
|
"有关": "相关",
|
||||||
|
"联系": "相关",
|
||||||
|
}
|
||||||
|
|
||||||
|
relation_type = relation_type.strip()
|
||||||
|
|
||||||
|
# 查找匹配
|
||||||
|
for mapping in [causality_relations, reference_relations, related_relations]:
|
||||||
|
if relation_type in mapping:
|
||||||
|
return mapping[relation_type]
|
||||||
|
|
||||||
|
# 未找到映射,返回原值
|
||||||
|
logger.warning(f"未识别的关系类型: {relation_type},使用原值")
|
||||||
|
return relation_type
|
||||||
@@ -92,17 +92,27 @@ class VectorStore:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 准备元数据(ChromaDB 只支持 str, int, float, bool)
|
||||||
|
metadata = {
|
||||||
|
"content": node.content,
|
||||||
|
"node_type": node.node_type.value,
|
||||||
|
"created_at": node.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# 处理额外的元数据,将 list 转换为 JSON 字符串
|
||||||
|
for key, value in node.metadata.items():
|
||||||
|
if isinstance(value, (list, dict)):
|
||||||
|
import json
|
||||||
|
metadata[key] = json.dumps(value, ensure_ascii=False)
|
||||||
|
elif isinstance(value, (str, int, float, bool)) or value is None:
|
||||||
|
metadata[key] = value
|
||||||
|
else:
|
||||||
|
metadata[key] = str(value)
|
||||||
|
|
||||||
self.collection.add(
|
self.collection.add(
|
||||||
ids=[node.id],
|
ids=[node.id],
|
||||||
embeddings=[node.embedding.tolist()],
|
embeddings=[node.embedding.tolist()],
|
||||||
metadatas=[
|
metadatas=[metadata],
|
||||||
{
|
|
||||||
"content": node.content,
|
|
||||||
"node_type": node.node_type.value,
|
|
||||||
"created_at": node.created_at.isoformat(),
|
|
||||||
**node.metadata,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
documents=[node.content], # 文本内容用于检索
|
documents=[node.content], # 文本内容用于检索
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -130,18 +140,28 @@ class VectorStore:
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 准备元数据
|
||||||
|
import json
|
||||||
|
metadatas = []
|
||||||
|
for n in valid_nodes:
|
||||||
|
metadata = {
|
||||||
|
"content": n.content,
|
||||||
|
"node_type": n.node_type.value,
|
||||||
|
"created_at": n.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
for key, value in n.metadata.items():
|
||||||
|
if isinstance(value, (list, dict)):
|
||||||
|
metadata[key] = json.dumps(value, ensure_ascii=False)
|
||||||
|
elif isinstance(value, (str, int, float, bool)) or value is None:
|
||||||
|
metadata[key] = value # type: ignore
|
||||||
|
else:
|
||||||
|
metadata[key] = str(value)
|
||||||
|
metadatas.append(metadata)
|
||||||
|
|
||||||
self.collection.add(
|
self.collection.add(
|
||||||
ids=[n.id for n in valid_nodes],
|
ids=[n.id for n in valid_nodes],
|
||||||
embeddings=[n.embedding.tolist() for n in valid_nodes],
|
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
|
||||||
metadatas=[
|
metadatas=metadatas,
|
||||||
{
|
|
||||||
"content": n.content,
|
|
||||||
"node_type": n.node_type.value,
|
|
||||||
"created_at": n.created_at.isoformat(),
|
|
||||||
**n.metadata,
|
|
||||||
}
|
|
||||||
for n in valid_nodes
|
|
||||||
],
|
|
||||||
documents=[n.content for n in valid_nodes],
|
documents=[n.content for n in valid_nodes],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,16 +207,26 @@ class VectorStore:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 解析结果
|
# 解析结果
|
||||||
|
import json
|
||||||
similar_nodes = []
|
similar_nodes = []
|
||||||
if results["ids"] and results["ids"][0]:
|
if results["ids"] and results["ids"][0]:
|
||||||
for i, node_id in enumerate(results["ids"][0]):
|
for i, node_id in enumerate(results["ids"][0]):
|
||||||
# ChromaDB 返回的是距离,需要转换为相似度
|
# ChromaDB 返回的是距离,需要转换为相似度
|
||||||
# 余弦距离: distance = 1 - similarity
|
# 余弦距离: distance = 1 - similarity
|
||||||
distance = results["distances"][0][i]
|
distance = results["distances"][0][i] if results["distances"] else 0.0 # type: ignore
|
||||||
similarity = 1.0 - distance
|
similarity = 1.0 - distance
|
||||||
|
|
||||||
if similarity >= min_similarity:
|
if similarity >= min_similarity:
|
||||||
metadata = results["metadatas"][0][i] if results["metadatas"] else {}
|
metadata = results["metadatas"][0][i] if results["metadatas"] else {} # type: ignore
|
||||||
|
|
||||||
|
# 解析 JSON 字符串回列表/字典
|
||||||
|
for key, value in list(metadata.items()):
|
||||||
|
if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
|
||||||
|
try:
|
||||||
|
metadata[key] = json.loads(value)
|
||||||
|
except:
|
||||||
|
pass # 保持原值
|
||||||
|
|
||||||
similar_nodes.append((node_id, similarity, metadata))
|
similar_nodes.append((node_id, similarity, metadata))
|
||||||
|
|
||||||
logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果")
|
logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果")
|
||||||
|
|||||||
7
src/memory_graph/tools/__init__.py
Normal file
7
src/memory_graph/tools/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""
|
||||||
|
记忆系统工具模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.memory_graph.tools.memory_tools import MemoryTools
|
||||||
|
|
||||||
|
__all__ = ["MemoryTools"]
|
||||||
495
src/memory_graph/tools/memory_tools.py
Normal file
495
src/memory_graph/tools/memory_tools.py
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
"""
|
||||||
|
LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.memory_graph.core.builder import MemoryBuilder
|
||||||
|
from src.memory_graph.core.extractor import MemoryExtractor
|
||||||
|
from src.memory_graph.models import Memory, MemoryStatus
|
||||||
|
from src.memory_graph.storage.graph_store import GraphStore
|
||||||
|
from src.memory_graph.storage.persistence import PersistenceManager
|
||||||
|
from src.memory_graph.storage.vector_store import VectorStore
|
||||||
|
from src.memory_graph.utils.embeddings import EmbeddingGenerator
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryTools:
|
||||||
|
"""
|
||||||
|
记忆系统工具集
|
||||||
|
|
||||||
|
提供给 LLM 使用的工具接口:
|
||||||
|
1. create_memory: 创建新记忆
|
||||||
|
2. link_memories: 关联两个记忆
|
||||||
|
3. search_memories: 搜索记忆
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vector_store: VectorStore,
|
||||||
|
graph_store: GraphStore,
|
||||||
|
persistence_manager: PersistenceManager,
|
||||||
|
embedding_generator: Optional[EmbeddingGenerator] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化工具集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector_store: 向量存储
|
||||||
|
graph_store: 图存储
|
||||||
|
persistence_manager: 持久化管理器
|
||||||
|
embedding_generator: 嵌入生成器(可选)
|
||||||
|
"""
|
||||||
|
self.vector_store = vector_store
|
||||||
|
self.graph_store = graph_store
|
||||||
|
self.persistence_manager = persistence_manager
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
# 初始化组件
|
||||||
|
self.extractor = MemoryExtractor()
|
||||||
|
self.builder = MemoryBuilder(
|
||||||
|
vector_store=vector_store,
|
||||||
|
graph_store=graph_store,
|
||||||
|
embedding_generator=embedding_generator,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _ensure_initialized(self):
|
||||||
|
"""确保向量存储已初始化"""
|
||||||
|
if not self._initialized:
|
||||||
|
await self.vector_store.initialize()
|
||||||
|
self._initialized = True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_create_memory_schema() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取 create_memory 工具的 JSON schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具 schema 定义
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"name": "create_memory",
|
||||||
|
"description": "创建一个新的记忆。记忆由主体、类型、主题、客体(可选)和属性组成。",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"subject": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "记忆的主体,通常是'我'、'用户'或具体的人名",
|
||||||
|
},
|
||||||
|
"memory_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["事件", "事实", "关系", "观点"],
|
||||||
|
"description": "记忆类型:事件(时间绑定的动作)、事实(稳定状态)、关系(人际关系)、观点(主观评价)",
|
||||||
|
},
|
||||||
|
"topic": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "记忆的主题,即发生的事情或状态",
|
||||||
|
},
|
||||||
|
"object": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "记忆的客体,即主题作用的对象(可选)",
|
||||||
|
},
|
||||||
|
"attributes": {
|
||||||
|
"type": "object",
|
||||||
|
"description": "记忆的属性,如时间、地点、原因、方式等",
|
||||||
|
"properties": {
|
||||||
|
"时间": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "时间表达式,如'今天'、'昨天'、'3天前'、'2025-11-05'",
|
||||||
|
},
|
||||||
|
"地点": {"type": "string", "description": "地点"},
|
||||||
|
"原因": {"type": "string", "description": "原因"},
|
||||||
|
"方式": {"type": "string", "description": "方式"},
|
||||||
|
},
|
||||||
|
"additionalProperties": True,
|
||||||
|
},
|
||||||
|
"importance": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0.0,
|
||||||
|
"maximum": 1.0,
|
||||||
|
"description": "记忆的重要性,0-1之间的浮点数,默认0.5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["subject", "memory_type", "topic"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_link_memories_schema() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取 link_memories 工具的 JSON schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具 schema 定义
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"name": "link_memories",
|
||||||
|
"description": "关联两个已存在的记忆,建立因果或引用关系。",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"source_memory_description": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "源记忆的描述,用于查找对应的记忆",
|
||||||
|
},
|
||||||
|
"target_memory_description": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "目标记忆的描述,用于查找对应的记忆",
|
||||||
|
},
|
||||||
|
"relation_type": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "关系类型,如'导致'、'引起'、'因为'、'所以'、'引用'、'基于'等",
|
||||||
|
},
|
||||||
|
"importance": {
|
||||||
|
"type": "number",
|
||||||
|
"minimum": 0.0,
|
||||||
|
"maximum": 1.0,
|
||||||
|
"description": "关系的重要性,0-1之间的浮点数,默认0.6",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": [
|
||||||
|
"source_memory_description",
|
||||||
|
"target_memory_description",
|
||||||
|
"relation_type",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_search_memories_schema() -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取 search_memories 工具的 JSON schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具 schema 定义
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"name": "search_memories",
|
||||||
|
"description": "搜索相关的记忆。支持语义搜索、图遍历和时间过滤。",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "搜索查询,描述要查找的记忆内容",
|
||||||
|
},
|
||||||
|
"memory_types": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["事件", "事实", "关系", "观点"],
|
||||||
|
},
|
||||||
|
"description": "要搜索的记忆类型,可多选",
|
||||||
|
},
|
||||||
|
"time_range": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"start": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "开始时间,如'3天前'、'2025-11-01'",
|
||||||
|
},
|
||||||
|
"end": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "结束时间,如'今天'、'2025-11-05'",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"description": "时间范围过滤(可选)",
|
||||||
|
},
|
||||||
|
"top_k": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 1,
|
||||||
|
"maximum": 50,
|
||||||
|
"description": "返回结果数量,默认10",
|
||||||
|
},
|
||||||
|
"expand_depth": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0,
|
||||||
|
"maximum": 3,
|
||||||
|
"description": "图遍历扩展深度,0表示不扩展,默认1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def create_memory(self, **params) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行 create_memory 工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**params: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"创建记忆: {params.get('subject')} - {params.get('topic')}")
|
||||||
|
|
||||||
|
# 0. 确保初始化
|
||||||
|
await self._ensure_initialized()
|
||||||
|
|
||||||
|
# 1. 提取参数
|
||||||
|
extracted = self.extractor.extract_from_tool_params(params)
|
||||||
|
|
||||||
|
# 2. 构建记忆
|
||||||
|
memory = await self.builder.build_memory(extracted)
|
||||||
|
|
||||||
|
# 3. 添加到存储(暂存状态)
|
||||||
|
await self._add_memory_to_stores(memory)
|
||||||
|
|
||||||
|
# 4. 保存到磁盘
|
||||||
|
await self.persistence_manager.save_graph_store(self.graph_store)
|
||||||
|
|
||||||
|
logger.info(f"记忆创建成功: {memory.id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"memory_id": memory.id,
|
||||||
|
"message": f"记忆已创建: {extracted['subject']} - {extracted['topic']}",
|
||||||
|
"nodes_count": len(memory.nodes),
|
||||||
|
"edges_count": len(memory.edges),
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记忆创建失败: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"message": "记忆创建失败",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def link_memories(self, **params) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行 link_memories 工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**params: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
执行结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"关联记忆: {params.get('source_memory_description')} -> "
|
||||||
|
f"{params.get('target_memory_description')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. 提取参数
|
||||||
|
extracted = self.extractor.extract_link_params(params)
|
||||||
|
|
||||||
|
# 2. 查找源记忆和目标记忆
|
||||||
|
source_memory = await self._find_memory_by_description(
|
||||||
|
extracted["source_description"]
|
||||||
|
)
|
||||||
|
target_memory = await self._find_memory_by_description(
|
||||||
|
extracted["target_description"]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not source_memory:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "找不到源记忆",
|
||||||
|
"message": f"未找到匹配的源记忆: {extracted['source_description']}",
|
||||||
|
}
|
||||||
|
|
||||||
|
if not target_memory:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "找不到目标记忆",
|
||||||
|
"message": f"未找到匹配的目标记忆: {extracted['target_description']}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 3. 创建关联边
|
||||||
|
edge = await self.builder.link_memories(
|
||||||
|
source_memory=source_memory,
|
||||||
|
target_memory=target_memory,
|
||||||
|
relation_type=extracted["relation_type"],
|
||||||
|
importance=extracted["importance"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 添加边到图存储
|
||||||
|
self.graph_store.graph.add_edge(
|
||||||
|
edge.source_id,
|
||||||
|
edge.target_id,
|
||||||
|
relation=edge.relation,
|
||||||
|
edge_type=edge.edge_type.value,
|
||||||
|
importance=edge.importance,
|
||||||
|
**edge.metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. 保存
|
||||||
|
await self.persistence_manager.save_graph_store(self.graph_store)
|
||||||
|
|
||||||
|
logger.info(f"记忆关联成功: {source_memory.id} -> {target_memory.id}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"记忆已关联: {extracted['relation_type']}",
|
||||||
|
"source_memory_id": source_memory.id,
|
||||||
|
"target_memory_id": target_memory.id,
|
||||||
|
"relation_type": extracted["relation_type"],
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记忆关联失败: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"message": "记忆关联失败",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def search_memories(self, **params) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行 search_memories 工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**params: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
搜索结果
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
query = params.get("query", "")
|
||||||
|
top_k = params.get("top_k", 10)
|
||||||
|
expand_depth = params.get("expand_depth", 1)
|
||||||
|
|
||||||
|
logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth})")
|
||||||
|
|
||||||
|
# 0. 确保初始化
|
||||||
|
await self._ensure_initialized()
|
||||||
|
|
||||||
|
# 1. 生成查询嵌入
|
||||||
|
if self.builder.embedding_generator:
|
||||||
|
query_embedding = await self.builder.embedding_generator.generate(query)
|
||||||
|
else:
|
||||||
|
logger.warning("未配置嵌入生成器,使用随机向量")
|
||||||
|
import numpy as np
|
||||||
|
query_embedding = np.random.rand(384).astype(np.float32)
|
||||||
|
|
||||||
|
# 2. 向量搜索
|
||||||
|
node_types_filter = None
|
||||||
|
if "memory_types" in params:
|
||||||
|
# 添加类型过滤
|
||||||
|
pass
|
||||||
|
|
||||||
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
limit=top_k * 2, # 多取一些,后续过滤
|
||||||
|
node_types=node_types_filter,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 提取记忆ID
|
||||||
|
memory_ids = set()
|
||||||
|
for node_id, similarity, metadata in similar_nodes:
|
||||||
|
if "memory_ids" in metadata:
|
||||||
|
memory_ids.update(metadata["memory_ids"])
|
||||||
|
|
||||||
|
# 4. 获取完整记忆
|
||||||
|
memories = []
|
||||||
|
for memory_id in list(memory_ids)[:top_k]:
|
||||||
|
memory = self.graph_store.get_memory_by_id(memory_id)
|
||||||
|
if memory:
|
||||||
|
memories.append(memory)
|
||||||
|
|
||||||
|
# 5. 格式化结果
|
||||||
|
results = []
|
||||||
|
for memory in memories:
|
||||||
|
result = {
|
||||||
|
"memory_id": memory.id,
|
||||||
|
"importance": memory.importance,
|
||||||
|
"created_at": memory.created_at.isoformat(),
|
||||||
|
"summary": self._summarize_memory(memory),
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
logger.info(f"搜索完成: 找到 {len(results)} 条记忆")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"results": results,
|
||||||
|
"total": len(results),
|
||||||
|
"query": query,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"记忆搜索失败: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"message": "记忆搜索失败",
|
||||||
|
"results": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _add_memory_to_stores(self, memory: Memory):
|
||||||
|
"""将记忆添加到存储"""
|
||||||
|
# 1. 添加到图存储
|
||||||
|
self.graph_store.add_memory(memory)
|
||||||
|
|
||||||
|
# 2. 添加有嵌入的节点到向量存储
|
||||||
|
for node in memory.nodes:
|
||||||
|
if node.embedding is not None:
|
||||||
|
await self.vector_store.add_node(node)
|
||||||
|
|
||||||
|
async def _find_memory_by_description(self, description: str) -> Optional[Memory]:
|
||||||
|
"""
|
||||||
|
通过描述查找记忆
|
||||||
|
|
||||||
|
Args:
|
||||||
|
description: 记忆描述
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
找到的记忆,如果没有则返回 None
|
||||||
|
"""
|
||||||
|
# 使用语义搜索查找最相关的记忆
|
||||||
|
if self.builder.embedding_generator:
|
||||||
|
query_embedding = await self.builder.embedding_generator.generate(description)
|
||||||
|
else:
|
||||||
|
import numpy as np
|
||||||
|
query_embedding = np.random.rand(384).astype(np.float32)
|
||||||
|
|
||||||
|
# 搜索相似节点
|
||||||
|
similar_nodes = await self.vector_store.search_similar_nodes(
|
||||||
|
query_embedding=query_embedding,
|
||||||
|
limit=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not similar_nodes:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 获取最相似节点关联的记忆
|
||||||
|
node_id, similarity, metadata = similar_nodes[0]
|
||||||
|
if "memory_ids" in metadata and metadata["memory_ids"]:
|
||||||
|
memory_id = metadata["memory_ids"][0]
|
||||||
|
return self.graph_store.get_memory_by_id(memory_id)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _summarize_memory(self, memory: Memory) -> str:
|
||||||
|
"""生成记忆摘要"""
|
||||||
|
if not memory.metadata:
|
||||||
|
return "未知记忆"
|
||||||
|
|
||||||
|
subject = memory.metadata.get("subject", "")
|
||||||
|
topic = memory.metadata.get("topic", "")
|
||||||
|
memory_type = memory.metadata.get("memory_type", "")
|
||||||
|
|
||||||
|
return f"{subject} - {memory_type}: {topic}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_all_tool_schemas() -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
获取所有工具的 schema
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具 schema 列表
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
MemoryTools.get_create_memory_schema(),
|
||||||
|
MemoryTools.get_link_memories_schema(),
|
||||||
|
MemoryTools.get_search_memories_schema(),
|
||||||
|
]
|
||||||
8
src/memory_graph/utils/__init__.py
Normal file
8
src/memory_graph/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""
|
||||||
|
工具模块
|
||||||
|
"""
|
||||||
|
|
||||||
|
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
|
||||||
|
from src.memory_graph.utils.time_parser import TimeParser
|
||||||
|
|
||||||
|
__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"]
|
||||||
299
src/memory_graph/utils/embeddings.py
Normal file
299
src/memory_graph/utils/embeddings.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
"""
|
||||||
|
嵌入向量生成器:优先使用配置的 embedding API,sentence-transformers 作为备选
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingGenerator:
|
||||||
|
"""
|
||||||
|
嵌入向量生成器
|
||||||
|
|
||||||
|
策略:
|
||||||
|
1. 优先使用配置的 embedding API(通过 LLMRequest)
|
||||||
|
2. 如果 API 不可用,回退到本地 sentence-transformers
|
||||||
|
3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)
|
||||||
|
|
||||||
|
优点:
|
||||||
|
- 降低本地运算负载
|
||||||
|
- 即使未安装 sentence-transformers 也可正常运行
|
||||||
|
- 保持与现有系统的一致性
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
use_api: bool = True,
|
||||||
|
fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
初始化嵌入生成器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_api: 是否优先使用 API(默认 True)
|
||||||
|
fallback_model_name: 回退本地模型名称
|
||||||
|
"""
|
||||||
|
self.use_api = use_api
|
||||||
|
self.fallback_model_name = fallback_model_name
|
||||||
|
|
||||||
|
# API 相关
|
||||||
|
self._llm_request = None
|
||||||
|
self._api_available = False
|
||||||
|
self._api_dimension = None
|
||||||
|
|
||||||
|
# 本地模型相关
|
||||||
|
self._local_model = None
|
||||||
|
self._local_model_loaded = False
|
||||||
|
|
||||||
|
async def _initialize_api(self):
|
||||||
|
"""初始化 embedding API"""
|
||||||
|
if self._api_available:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.config.config import model_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
embedding_config = model_config.model_task_config.embedding
|
||||||
|
self._llm_request = LLMRequest(
|
||||||
|
model_set=embedding_config,
|
||||||
|
request_type="memory_graph.embedding"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取嵌入维度
|
||||||
|
if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension:
|
||||||
|
self._api_dimension = embedding_config.embedding_dimension
|
||||||
|
|
||||||
|
self._api_available = True
|
||||||
|
logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
|
||||||
|
self._api_available = False
|
||||||
|
|
||||||
|
def _load_local_model(self):
|
||||||
|
"""延迟加载本地模型"""
|
||||||
|
if not self._local_model_loaded:
|
||||||
|
try:
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
logger.info(f"📦 加载本地嵌入模型: {self.fallback_model_name}")
|
||||||
|
self._local_model = SentenceTransformer(self.fallback_model_name)
|
||||||
|
self._local_model_loaded = True
|
||||||
|
logger.info("✅ 本地嵌入模型加载成功")
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"⚠️ sentence-transformers 未安装,将使用随机向量(仅测试用)\n"
|
||||||
|
" 安装方法: pip install sentence-transformers"
|
||||||
|
)
|
||||||
|
self._local_model_loaded = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"⚠️ 本地模型加载失败: {e}")
|
||||||
|
self._local_model_loaded = False
|
||||||
|
|
||||||
|
async def generate(self, text: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
生成单个文本的嵌入向量
|
||||||
|
|
||||||
|
策略:
|
||||||
|
1. 优先使用 API
|
||||||
|
2. API 失败则使用本地模型
|
||||||
|
3. 本地模型不可用则使用随机向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 输入文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
嵌入向量
|
||||||
|
"""
|
||||||
|
if not text or not text.strip():
|
||||||
|
logger.warning("输入文本为空,返回零向量")
|
||||||
|
dim = self._get_dimension()
|
||||||
|
return np.zeros(dim, dtype=np.float32)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 策略 1: 使用 API
|
||||||
|
if self.use_api:
|
||||||
|
embedding = await self._generate_with_api(text)
|
||||||
|
if embedding is not None:
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
# 策略 2: 使用本地模型
|
||||||
|
embedding = await self._generate_with_local_model(text)
|
||||||
|
if embedding is not None:
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
# 策略 3: 随机向量(仅测试)
|
||||||
|
logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
|
||||||
|
dim = self._get_dimension()
|
||||||
|
return np.random.rand(dim).astype(np.float32)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 嵌入生成失败: {e}", exc_info=True)
|
||||||
|
dim = self._get_dimension()
|
||||||
|
return np.random.rand(dim).astype(np.float32)
|
||||||
|
|
||||||
|
async def _generate_with_api(self, text: str) -> Optional[np.ndarray]:
|
||||||
|
"""使用 API 生成嵌入"""
|
||||||
|
try:
|
||||||
|
# 初始化 API
|
||||||
|
if not self._api_available:
|
||||||
|
await self._initialize_api()
|
||||||
|
|
||||||
|
if not self._api_available or not self._llm_request:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 调用 API
|
||||||
|
embedding_list, model_name = await self._llm_request.get_embedding(text)
|
||||||
|
|
||||||
|
if embedding_list and len(embedding_list) > 0:
|
||||||
|
embedding = np.array(embedding_list, dtype=np.float32)
|
||||||
|
logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"API 嵌入生成失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]:
|
||||||
|
"""使用本地模型生成嵌入"""
|
||||||
|
try:
|
||||||
|
# 加载本地模型
|
||||||
|
if not self._local_model_loaded:
|
||||||
|
self._load_local_model()
|
||||||
|
|
||||||
|
if not self._local_model_loaded or not self._local_model:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 在线程池中运行
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
embedding = await loop.run_in_executor(None, self._encode_single_local, text)
|
||||||
|
|
||||||
|
logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维")
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"本地模型嵌入生成失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _encode_single_local(self, text: str) -> np.ndarray:
|
||||||
|
"""同步编码单个文本(本地模型)"""
|
||||||
|
if self._local_model is None:
|
||||||
|
raise RuntimeError("本地模型未加载")
|
||||||
|
embedding = self._local_model.encode(text, convert_to_numpy=True) # type: ignore
|
||||||
|
return embedding.astype(np.float32)
|
||||||
|
|
||||||
|
def _get_dimension(self) -> int:
|
||||||
|
"""获取嵌入维度"""
|
||||||
|
# 优先使用 API 维度
|
||||||
|
if self._api_dimension:
|
||||||
|
return self._api_dimension
|
||||||
|
|
||||||
|
# 其次使用本地模型维度
|
||||||
|
if self._local_model_loaded and self._local_model:
|
||||||
|
try:
|
||||||
|
return self._local_model.get_sentence_embedding_dimension()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 默认 384(sentence-transformers 常用维度)
|
||||||
|
return 384
|
||||||
|
|
||||||
|
async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
批量生成嵌入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 文本列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
嵌入向量列表
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 过滤空文本
|
||||||
|
valid_texts = [t for t in texts if t and t.strip()]
|
||||||
|
if not valid_texts:
|
||||||
|
logger.warning("所有文本为空,返回零向量列表")
|
||||||
|
dim = self._get_dimension()
|
||||||
|
return [np.zeros(dim, dtype=np.float32) for _ in texts]
|
||||||
|
|
||||||
|
# 使用 API 批量生成(如果可用)
|
||||||
|
if self.use_api:
|
||||||
|
results = await self._generate_batch_with_api(valid_texts)
|
||||||
|
if results:
|
||||||
|
return results
|
||||||
|
|
||||||
|
# 回退到逐个生成
|
||||||
|
results = []
|
||||||
|
for text in valid_texts:
|
||||||
|
embedding = await self.generate(text)
|
||||||
|
results.append(embedding)
|
||||||
|
|
||||||
|
logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
|
||||||
|
return results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True)
|
||||||
|
dim = self._get_dimension()
|
||||||
|
return [np.random.rand(dim).astype(np.float32) for _ in texts]
|
||||||
|
|
||||||
|
async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
|
||||||
|
"""使用 API 批量生成"""
|
||||||
|
try:
|
||||||
|
# 对于大多数 API,批量调用就是多次单独调用
|
||||||
|
# 这里保持简单,逐个调用
|
||||||
|
results = []
|
||||||
|
for text in texts:
|
||||||
|
embedding = await self._generate_with_api(text)
|
||||||
|
if embedding is None:
|
||||||
|
return None # 如果任何一个失败,返回 None 触发回退
|
||||||
|
results.append(embedding)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"API 批量生成失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_embedding_dimension(self) -> int:
|
||||||
|
"""获取嵌入向量维度"""
|
||||||
|
return self._get_dimension()
|
||||||
|
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
|
_global_generator: Optional[EmbeddingGenerator] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_generator(
|
||||||
|
use_api: bool = True,
|
||||||
|
fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2",
|
||||||
|
) -> EmbeddingGenerator:
|
||||||
|
"""
|
||||||
|
获取全局嵌入生成器单例
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_api: 是否优先使用 API
|
||||||
|
fallback_model_name: 回退本地模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EmbeddingGenerator 实例
|
||||||
|
"""
|
||||||
|
global _global_generator
|
||||||
|
if _global_generator is None:
|
||||||
|
_global_generator = EmbeddingGenerator(
|
||||||
|
use_api=use_api,
|
||||||
|
fallback_model_name=fallback_model_name
|
||||||
|
)
|
||||||
|
return _global_generator
|
||||||
391
src/memory_graph/utils/time_parser.py
Normal file
391
src/memory_graph/utils/time_parser.py
Normal file
@@ -0,0 +1,391 @@
|
|||||||
|
"""
|
||||||
|
时间解析器:将相对时间转换为绝对时间
|
||||||
|
|
||||||
|
支持的时间表达:
|
||||||
|
- 今天、明天、昨天、前天、后天
|
||||||
|
- X天前、X天后
|
||||||
|
- X小时前、X小时后
|
||||||
|
- 上周、上个月、去年
|
||||||
|
- 具体日期:2025-11-05, 11月5日
|
||||||
|
- 时间点:早上8点、下午3点、晚上9点
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TimeParser:
|
||||||
|
"""
|
||||||
|
时间解析器
|
||||||
|
|
||||||
|
负责将自然语言时间表达转换为标准化的绝对时间
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, reference_time: Optional[datetime] = None):
|
||||||
|
"""
|
||||||
|
初始化时间解析器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reference_time: 参考时间(通常是当前时间)
|
||||||
|
"""
|
||||||
|
self.reference_time = reference_time or datetime.now()
|
||||||
|
|
||||||
|
def parse(self, time_str: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
解析时间字符串
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_str: 时间字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
标准化的datetime对象,如果解析失败则返回None
|
||||||
|
"""
|
||||||
|
if not time_str or not isinstance(time_str, str):
|
||||||
|
return None
|
||||||
|
|
||||||
|
time_str = time_str.strip()
|
||||||
|
|
||||||
|
# 尝试各种解析方法
|
||||||
|
parsers = [
|
||||||
|
self._parse_relative_day,
|
||||||
|
self._parse_days_ago,
|
||||||
|
self._parse_hours_ago,
|
||||||
|
self._parse_week_month_year,
|
||||||
|
self._parse_specific_date,
|
||||||
|
self._parse_time_of_day,
|
||||||
|
]
|
||||||
|
|
||||||
|
for parser in parsers:
|
||||||
|
try:
|
||||||
|
result = parser(time_str)
|
||||||
|
if result:
|
||||||
|
logger.debug(f"时间解析: '{time_str}' → {result.isoformat()}")
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"解析器 {parser.__name__} 失败: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
logger.warning(f"无法解析时间: '{time_str}',使用当前时间")
|
||||||
|
return self.reference_time
|
||||||
|
|
||||||
|
def _parse_relative_day(self, time_str: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
解析相对日期:今天、明天、昨天、前天、后天
|
||||||
|
"""
|
||||||
|
relative_days = {
|
||||||
|
"今天": 0,
|
||||||
|
"今日": 0,
|
||||||
|
"明天": 1,
|
||||||
|
"明日": 1,
|
||||||
|
"昨天": -1,
|
||||||
|
"昨日": -1,
|
||||||
|
"前天": -2,
|
||||||
|
"前日": -2,
|
||||||
|
"后天": 2,
|
||||||
|
"后日": 2,
|
||||||
|
"大前天": -3,
|
||||||
|
"大后天": 3,
|
||||||
|
}
|
||||||
|
|
||||||
|
for keyword, days in relative_days.items():
|
||||||
|
if keyword in time_str:
|
||||||
|
result = self.reference_time + timedelta(days=days)
|
||||||
|
# 保留原有时间,只改变日期
|
||||||
|
return result.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_days_ago(self, time_str: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
解析 X天前/X天后
|
||||||
|
"""
|
||||||
|
# 匹配:3天前、5天后、一天前
|
||||||
|
pattern = r"([一二三四五六七八九十\d]+)天(前|后)"
|
||||||
|
match = re.search(pattern, time_str)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
num_str, direction = match.groups()
|
||||||
|
num = self._chinese_num_to_int(num_str)
|
||||||
|
|
||||||
|
if direction == "前":
|
||||||
|
num = -num
|
||||||
|
|
||||||
|
result = self.reference_time + timedelta(days=num)
|
||||||
|
return result.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_hours_ago(self, time_str: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
解析 X小时前/X小时后、X分钟前/X分钟后
|
||||||
|
"""
|
||||||
|
# 小时
|
||||||
|
pattern_hour = r"([一二三四五六七八九十\d]+)小?时(前|后)"
|
||||||
|
match = re.search(pattern_hour, time_str)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
num_str, direction = match.groups()
|
||||||
|
num = self._chinese_num_to_int(num_str)
|
||||||
|
|
||||||
|
if direction == "前":
|
||||||
|
num = -num
|
||||||
|
|
||||||
|
return self.reference_time + timedelta(hours=num)
|
||||||
|
|
||||||
|
# 分钟
|
||||||
|
pattern_minute = r"([一二三四五六七八九十\d]+)分钟(前|后)"
|
||||||
|
match = re.search(pattern_minute, time_str)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
num_str, direction = match.groups()
|
||||||
|
num = self._chinese_num_to_int(num_str)
|
||||||
|
|
||||||
|
if direction == "前":
|
||||||
|
num = -num
|
||||||
|
|
||||||
|
return self.reference_time + timedelta(minutes=num)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_week_month_year(self, time_str: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
解析:上周、上个月、去年、本周、本月、今年
|
||||||
|
"""
|
||||||
|
now = self.reference_time
|
||||||
|
|
||||||
|
if "上周" in time_str or "上星期" in time_str:
|
||||||
|
return now - timedelta(days=7)
|
||||||
|
|
||||||
|
if "上个月" in time_str or "上月" in time_str:
|
||||||
|
# 简单处理:减30天
|
||||||
|
return now - timedelta(days=30)
|
||||||
|
|
||||||
|
if "去年" in time_str or "上年" in time_str:
|
||||||
|
return now.replace(year=now.year - 1)
|
||||||
|
|
||||||
|
if "本周" in time_str or "这周" in time_str:
|
||||||
|
# 返回本周一
|
||||||
|
return now - timedelta(days=now.weekday())
|
||||||
|
|
||||||
|
if "本月" in time_str or "这个月" in time_str:
|
||||||
|
return now.replace(day=1)
|
||||||
|
|
||||||
|
if "今年" in time_str or "这年" in time_str:
|
||||||
|
return now.replace(month=1, day=1)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_specific_date(self, time_str: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
解析具体日期:
|
||||||
|
- 2025-11-05
|
||||||
|
- 2025/11/05
|
||||||
|
- 11月5日
|
||||||
|
- 11-05
|
||||||
|
"""
|
||||||
|
# ISO 格式:2025-11-05
|
||||||
|
pattern_iso = r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})"
|
||||||
|
match = re.search(pattern_iso, time_str)
|
||||||
|
if match:
|
||||||
|
year, month, day = map(int, match.groups())
|
||||||
|
return datetime(year, month, day)
|
||||||
|
|
||||||
|
# 中文格式:11月5日、11月5号
|
||||||
|
pattern_cn = r"(\d{1,2})月(\d{1,2})[日号]"
|
||||||
|
match = re.search(pattern_cn, time_str)
|
||||||
|
if match:
|
||||||
|
month, day = map(int, match.groups())
|
||||||
|
# 使用参考时间的年份
|
||||||
|
year = self.reference_time.year
|
||||||
|
return datetime(year, month, day)
|
||||||
|
|
||||||
|
# 短格式:11-05(使用当前年份)
|
||||||
|
pattern_short = r"(\d{1,2})[-/](\d{1,2})"
|
||||||
|
match = re.search(pattern_short, time_str)
|
||||||
|
if match:
|
||||||
|
month, day = map(int, match.groups())
|
||||||
|
year = self.reference_time.year
|
||||||
|
return datetime(year, month, day)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_time_of_day(self, time_str: str) -> Optional[datetime]:
|
||||||
|
"""
|
||||||
|
解析一天中的时间:
|
||||||
|
- 早上、上午、中午、下午、晚上、深夜
|
||||||
|
- 早上8点、下午3点
|
||||||
|
- 8点、15点
|
||||||
|
"""
|
||||||
|
now = self.reference_time
|
||||||
|
result = now.replace(minute=0, second=0, microsecond=0)
|
||||||
|
|
||||||
|
# 时间段映射
|
||||||
|
time_periods = {
|
||||||
|
"早上": 8,
|
||||||
|
"早晨": 8,
|
||||||
|
"上午": 10,
|
||||||
|
"中午": 12,
|
||||||
|
"下午": 15,
|
||||||
|
"傍晚": 18,
|
||||||
|
"晚上": 20,
|
||||||
|
"深夜": 23,
|
||||||
|
"凌晨": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 先检查是否有具体时间点:早上8点、下午3点
|
||||||
|
for period, default_hour in time_periods.items():
|
||||||
|
pattern = rf"{period}(\d{{1,2}})点?"
|
||||||
|
match = re.search(pattern, time_str)
|
||||||
|
if match:
|
||||||
|
hour = int(match.group(1))
|
||||||
|
# 下午时间需要+12
|
||||||
|
if period in ["下午", "晚上"] and hour < 12:
|
||||||
|
hour += 12
|
||||||
|
return result.replace(hour=hour)
|
||||||
|
|
||||||
|
# 检查时间段关键词
|
||||||
|
for period, hour in time_periods.items():
|
||||||
|
if period in time_str:
|
||||||
|
return result.replace(hour=hour)
|
||||||
|
|
||||||
|
# 直接的时间点:8点、15点
|
||||||
|
pattern = r"(\d{1,2})点"
|
||||||
|
match = re.search(pattern, time_str)
|
||||||
|
if match:
|
||||||
|
hour = int(match.group(1))
|
||||||
|
return result.replace(hour=hour)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _chinese_num_to_int(self, num_str: str) -> int:
|
||||||
|
"""
|
||||||
|
将中文数字转换为阿拉伯数字
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_str: 中文数字字符串(如:"一"、"十"、"3")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
整数
|
||||||
|
"""
|
||||||
|
# 如果已经是数字,直接返回
|
||||||
|
if num_str.isdigit():
|
||||||
|
return int(num_str)
|
||||||
|
|
||||||
|
# 中文数字映射
|
||||||
|
chinese_nums = {
|
||||||
|
"一": 1,
|
||||||
|
"二": 2,
|
||||||
|
"三": 3,
|
||||||
|
"四": 4,
|
||||||
|
"五": 5,
|
||||||
|
"六": 6,
|
||||||
|
"七": 7,
|
||||||
|
"八": 8,
|
||||||
|
"九": 9,
|
||||||
|
"十": 10,
|
||||||
|
"零": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
if num_str in chinese_nums:
|
||||||
|
return chinese_nums[num_str]
|
||||||
|
|
||||||
|
# 处理 "十X" 的情况(如"十五"=15)
|
||||||
|
if num_str.startswith("十"):
|
||||||
|
if len(num_str) == 1:
|
||||||
|
return 10
|
||||||
|
return 10 + chinese_nums.get(num_str[1], 0)
|
||||||
|
|
||||||
|
# 处理 "X十" 的情况(如"三十"=30)
|
||||||
|
if "十" in num_str:
|
||||||
|
parts = num_str.split("十")
|
||||||
|
tens = chinese_nums.get(parts[0], 1) * 10
|
||||||
|
ones = chinese_nums.get(parts[1], 0) if len(parts) > 1 and parts[1] else 0
|
||||||
|
return tens + ones
|
||||||
|
|
||||||
|
# 默认返回1
|
||||||
|
return 1
|
||||||
|
|
||||||
|
def format_time(self, dt: datetime, format_type: str = "iso") -> str:
|
||||||
|
"""
|
||||||
|
格式化时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dt: datetime对象
|
||||||
|
format_type: 格式类型 ("iso", "cn", "relative")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
格式化的时间字符串
|
||||||
|
"""
|
||||||
|
if format_type == "iso":
|
||||||
|
return dt.isoformat()
|
||||||
|
|
||||||
|
elif format_type == "cn":
|
||||||
|
return dt.strftime("%Y年%m月%d日 %H:%M:%S")
|
||||||
|
|
||||||
|
elif format_type == "relative":
|
||||||
|
# 相对时间表达
|
||||||
|
diff = self.reference_time - dt
|
||||||
|
days = diff.days
|
||||||
|
|
||||||
|
if days == 0:
|
||||||
|
hours = diff.seconds // 3600
|
||||||
|
if hours == 0:
|
||||||
|
minutes = diff.seconds // 60
|
||||||
|
return f"{minutes}分钟前" if minutes > 0 else "刚刚"
|
||||||
|
return f"{hours}小时前"
|
||||||
|
elif days == 1:
|
||||||
|
return "昨天"
|
||||||
|
elif days == 2:
|
||||||
|
return "前天"
|
||||||
|
elif days < 7:
|
||||||
|
return f"{days}天前"
|
||||||
|
elif days < 30:
|
||||||
|
weeks = days // 7
|
||||||
|
return f"{weeks}周前"
|
||||||
|
elif days < 365:
|
||||||
|
months = days // 30
|
||||||
|
return f"{months}个月前"
|
||||||
|
else:
|
||||||
|
years = days // 365
|
||||||
|
return f"{years}年前"
|
||||||
|
|
||||||
|
return str(dt)
|
||||||
|
|
||||||
|
def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||||
|
"""
|
||||||
|
解析时间范围:最近一周、最近3天
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_str: 时间范围字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(start_time, end_time)
|
||||||
|
"""
|
||||||
|
pattern = r"最近(\d+)(天|周|月|年)"
|
||||||
|
match = re.search(pattern, time_str)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
num, unit = match.groups()
|
||||||
|
num = int(num)
|
||||||
|
|
||||||
|
unit_map = {"天": "days", "周": "weeks", "月": "days", "年": "days"}
|
||||||
|
if unit == "周":
|
||||||
|
num *= 7
|
||||||
|
elif unit == "月":
|
||||||
|
num *= 30
|
||||||
|
elif unit == "年":
|
||||||
|
num *= 365
|
||||||
|
|
||||||
|
end_time = self.reference_time
|
||||||
|
start_time = end_time - timedelta(**{unit_map[unit]: num})
|
||||||
|
|
||||||
|
return (start_time, end_time)
|
||||||
|
|
||||||
|
return (None, None)
|
||||||
Reference in New Issue
Block a user