diff --git a/src/memory_graph/core/__init__.py b/src/memory_graph/core/__init__.py index 8089247df..c6dc426db 100644 --- a/src/memory_graph/core/__init__.py +++ b/src/memory_graph/core/__init__.py @@ -2,6 +2,8 @@ 核心模块 """ +from src.memory_graph.core.builder import MemoryBuilder +from src.memory_graph.core.extractor import MemoryExtractor from src.memory_graph.core.node_merger import NodeMerger -__all__ = ["NodeMerger"] +__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"] diff --git a/src/memory_graph/core/builder.py b/src/memory_graph/core/builder.py new file mode 100644 index 000000000..df3494d2e --- /dev/null +++ b/src/memory_graph/core/builder.py @@ -0,0 +1,549 @@ +""" +记忆构建器:自动构造记忆子图 +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, List, Optional + +import numpy as np + +from src.common.logger import get_logger +from src.memory_graph.models import ( + EdgeType, + Memory, + MemoryEdge, + MemoryNode, + MemoryStatus, + MemoryType, + NodeType, +) +from src.memory_graph.storage.graph_store import GraphStore +from src.memory_graph.storage.vector_store import VectorStore + +logger = get_logger(__name__) + + +class MemoryBuilder: + """ + 记忆构建器 + + 负责: + 1. 根据提取的元素自动构造记忆子图 + 2. 创建节点和边的完整结构 + 3. 生成语义嵌入向量 + 4. 检查并复用已存在的相似节点 + 5. 构造符合层级结构的记忆对象 + """ + + def __init__( + self, + vector_store: VectorStore, + graph_store: GraphStore, + embedding_generator: Optional[Any] = None, + ): + """ + 初始化记忆构建器 + + Args: + vector_store: 向量存储 + graph_store: 图存储 + embedding_generator: 嵌入向量生成器(可选) + """ + self.vector_store = vector_store + self.graph_store = graph_store + self.embedding_generator = embedding_generator + + async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory: + """ + 构建完整的记忆对象 + + Args: + extracted_params: 提取器返回的标准化参数 + + Returns: + Memory 对象(状态为 STAGED) + """ + try: + nodes = [] + edges = [] + memory_id = self._generate_memory_id() + + # 1. 创建主体节点 (SUBJECT) + subject_node = await self._create_or_reuse_node( + content=extracted_params["subject"], + node_type=NodeType.SUBJECT, + memory_id=memory_id, + ) + nodes.append(subject_node) + + # 2. 创建主题节点 (TOPIC) - 需要嵌入向量 + topic_node = await self._create_topic_node( + content=extracted_params["topic"], memory_id=memory_id + ) + nodes.append(topic_node) + + # 3. 连接主体 -> 记忆类型 -> 主题 + memory_type_edge = MemoryEdge( + id=self._generate_edge_id(), + source_id=subject_node.id, + target_id=topic_node.id, + relation=extracted_params["memory_type"].value, + edge_type=EdgeType.MEMORY_TYPE, + importance=extracted_params["importance"], + metadata={"memory_id": memory_id}, + ) + edges.append(memory_type_edge) + + # 4. 如果有客体,创建客体节点并连接 + if "object" in extracted_params and extracted_params["object"]: + object_node = await self._create_object_node( + content=extracted_params["object"], memory_id=memory_id + ) + nodes.append(object_node) + + # 连接主题 -> 核心关系 -> 客体 + core_relation_edge = MemoryEdge( + id=self._generate_edge_id(), + source_id=topic_node.id, + target_id=object_node.id, + relation="核心关系", # 默认关系名 + edge_type=EdgeType.CORE_RELATION, + importance=extracted_params["importance"], + metadata={"memory_id": memory_id}, + ) + edges.append(core_relation_edge) + + # 5. 处理属性 + if extracted_params.get("attributes"): + attr_nodes, attr_edges = await self._process_attributes( + attributes=extracted_params["attributes"], + parent_id=topic_node.id, + memory_id=memory_id, + importance=extracted_params["importance"], + ) + nodes.extend(attr_nodes) + edges.extend(attr_edges) + + # 6. 
构建 Memory 对象 + memory = Memory( + id=memory_id, + subject_id=subject_node.id, + memory_type=extracted_params["memory_type"], + nodes=nodes, + edges=edges, + importance=extracted_params["importance"], + created_at=extracted_params["timestamp"], + last_accessed=extracted_params["timestamp"], + access_count=0, + status=MemoryStatus.STAGED, + metadata={ + "subject": extracted_params["subject"], + "topic": extracted_params["topic"], + }, + ) + + logger.info( + f"构建记忆成功: {memory_id} - {len(nodes)} 节点, {len(edges)} 边" + ) + return memory + + except Exception as e: + logger.error(f"记忆构建失败: {e}", exc_info=True) + raise RuntimeError(f"记忆构建失败: {e}") + + async def _create_or_reuse_node( + self, content: str, node_type: NodeType, memory_id: str + ) -> MemoryNode: + """ + 创建新节点或复用已存在的相似节点 + + 对于主体(SUBJECT)和属性(ATTRIBUTE),检查是否已存在相同内容的节点 + + Args: + content: 节点内容 + node_type: 节点类型 + memory_id: 所属记忆ID + + Returns: + MemoryNode 对象 + """ + # 对于主体,尝试查找已存在的节点 + if node_type == NodeType.SUBJECT: + existing = await self._find_existing_node(content, node_type) + if existing: + logger.debug(f"复用已存在的主体节点: {existing.id}") + return existing + + # 创建新节点 + node = MemoryNode( + id=self._generate_node_id(), + content=content, + node_type=node_type, + embedding=None, # 主体和属性不需要嵌入 + metadata={"memory_ids": [memory_id]}, + ) + + return node + + async def _create_topic_node(self, content: str, memory_id: str) -> MemoryNode: + """ + 创建主题节点(需要生成嵌入向量) + + Args: + content: 节点内容 + memory_id: 所属记忆ID + + Returns: + MemoryNode 对象 + """ + # 生成嵌入向量 + embedding = await self._generate_embedding(content) + + # 检查是否存在高度相似的节点 + existing = await self._find_similar_topic(content, embedding) + if existing: + logger.debug(f"复用相似的主题节点: {existing.id}") + # 添加当前记忆ID到元数据 + if "memory_ids" not in existing.metadata: + existing.metadata["memory_ids"] = [] + existing.metadata["memory_ids"].append(memory_id) + return existing + + # 创建新节点 + node = MemoryNode( + id=self._generate_node_id(), + content=content, + node_type=NodeType.TOPIC, + embedding=embedding, + metadata={"memory_ids": [memory_id]}, + ) + + return node + + async def _create_object_node(self, content: str, memory_id: str) -> MemoryNode: + """ + 创建客体节点(需要生成嵌入向量) + + Args: + content: 节点内容 + memory_id: 所属记忆ID + + Returns: + MemoryNode 对象 + """ + # 生成嵌入向量 + embedding = await self._generate_embedding(content) + + # 检查是否存在高度相似的节点 + existing = await self._find_similar_object(content, embedding) + if existing: + logger.debug(f"复用相似的客体节点: {existing.id}") + if "memory_ids" not in existing.metadata: + existing.metadata["memory_ids"] = [] + existing.metadata["memory_ids"].append(memory_id) + return existing + + # 创建新节点 + node = MemoryNode( + id=self._generate_node_id(), + content=content, + node_type=NodeType.OBJECT, + embedding=embedding, + metadata={"memory_ids": [memory_id]}, + ) + + return node + + async def _process_attributes( + self, + attributes: Dict[str, Any], + parent_id: str, + memory_id: str, + importance: float, + ) -> tuple[List[MemoryNode], List[MemoryEdge]]: + """ + 处理属性,构建属性子图 + + 结构:TOPIC -> ATTRIBUTE -> VALUE + + Args: + attributes: 属性字典 + parent_id: 父节点ID(通常是TOPIC) + memory_id: 所属记忆ID + importance: 重要性 + + Returns: + (属性节点列表, 属性边列表) + """ + nodes = [] + edges = [] + + for attr_name, attr_value in attributes.items(): + # 创建属性节点 + attr_node = await self._create_or_reuse_node( + content=attr_name, node_type=NodeType.ATTRIBUTE, memory_id=memory_id + ) + nodes.append(attr_node) + + # 连接父节点 -> 属性 + attr_edge = MemoryEdge( + id=self._generate_edge_id(), + source_id=parent_id, + 
target_id=attr_node.id, + relation="属性", + edge_type=EdgeType.ATTRIBUTE, + importance=importance * 0.8, # 属性的重要性略低 + metadata={"memory_id": memory_id}, + ) + edges.append(attr_edge) + + # 创建值节点 + value_node = await self._create_or_reuse_node( + content=str(attr_value), node_type=NodeType.VALUE, memory_id=memory_id + ) + nodes.append(value_node) + + # 连接属性 -> 值 + value_edge = MemoryEdge( + id=self._generate_edge_id(), + source_id=attr_node.id, + target_id=value_node.id, + relation="值", + edge_type=EdgeType.ATTRIBUTE, + importance=importance * 0.8, + metadata={"memory_id": memory_id}, + ) + edges.append(value_edge) + + return nodes, edges + + async def _generate_embedding(self, text: str) -> np.ndarray: + """ + 生成文本的嵌入向量 + + Args: + text: 文本内容 + + Returns: + 嵌入向量 + """ + if self.embedding_generator: + try: + embedding = await self.embedding_generator.generate(text) + return embedding + except Exception as e: + logger.warning(f"嵌入生成失败,使用随机向量: {e}") + + # 回退:生成随机向量(仅用于测试) + return np.random.rand(384).astype(np.float32) + + async def _find_existing_node( + self, content: str, node_type: NodeType + ) -> Optional[MemoryNode]: + """ + 查找已存在的完全匹配节点(用于主体和属性) + + Args: + content: 节点内容 + node_type: 节点类型 + + Returns: + 已存在的节点,如果没有则返回 None + """ + # 在图存储中查找 + for node_id in self.graph_store.graph.nodes(): + node_data = self.graph_store.graph.nodes[node_id] + if node_data.get("content") == content and node_data.get("node_type") == node_type.value: + # 重建 MemoryNode 对象 + return MemoryNode( + id=node_id, + content=node_data["content"], + node_type=NodeType(node_data["node_type"]), + embedding=node_data.get("embedding"), + metadata=node_data.get("metadata", {}), + ) + + return None + + async def _find_similar_topic( + self, content: str, embedding: np.ndarray + ) -> Optional[MemoryNode]: + """ + 查找相似的主题节点(基于语义相似度) + + Args: + content: 内容 + embedding: 嵌入向量 + + Returns: + 相似节点,如果没有则返回 None + """ + try: + # 搜索相似节点(阈值 0.95) + similar_nodes = await self.vector_store.search_similar_nodes( + query_embedding=embedding, + limit=1, + node_types=[NodeType.TOPIC], + min_similarity=0.95, + ) + + if similar_nodes and similar_nodes[0][1] >= 0.95: + node_id, similarity, metadata = similar_nodes[0] + logger.debug( + f"找到相似主题节点: {metadata.get('content', '')} (相似度: {similarity:.3f})" + ) + # 从图存储中获取完整节点 + if node_id in self.graph_store.graph.nodes: + node_data = self.graph_store.graph.nodes[node_id] + existing_node = MemoryNode( + id=node_id, + content=node_data["content"], + node_type=NodeType(node_data["node_type"]), + embedding=node_data.get("embedding"), + metadata=node_data.get("metadata", {}), + ) + # 添加当前记忆ID到元数据 + return existing_node + + except Exception as e: + logger.warning(f"相似节点搜索失败: {e}") + + return None + + async def _find_similar_object( + self, content: str, embedding: np.ndarray + ) -> Optional[MemoryNode]: + """ + 查找相似的客体节点(基于语义相似度) + + Args: + content: 内容 + embedding: 嵌入向量 + + Returns: + 相似节点,如果没有则返回 None + """ + try: + # 搜索相似节点(阈值 0.95) + similar_nodes = await self.vector_store.search_similar_nodes( + query_embedding=embedding, + limit=1, + node_types=[NodeType.OBJECT], + min_similarity=0.95, + ) + + if similar_nodes and similar_nodes[0][1] >= 0.95: + node_id, similarity, metadata = similar_nodes[0] + logger.debug( + f"找到相似客体节点: {metadata.get('content', '')} (相似度: {similarity:.3f})" + ) + # 从图存储中获取完整节点 + if node_id in self.graph_store.graph.nodes: + node_data = self.graph_store.graph.nodes[node_id] + existing_node = MemoryNode( + id=node_id, + content=node_data["content"], + 
node_type=NodeType(node_data["node_type"]), + embedding=node_data.get("embedding"), + metadata=node_data.get("metadata", {}), + ) + return existing_node + + except Exception as e: + logger.warning(f"相似节点搜索失败: {e}") + + return None + + def _generate_memory_id(self) -> str: + """生成记忆ID""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + return f"mem_{timestamp}" + + def _generate_node_id(self) -> str: + """生成节点ID""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + return f"node_{timestamp}" + + def _generate_edge_id(self) -> str: + """生成边ID""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + return f"edge_{timestamp}" + + async def link_memories( + self, + source_memory: Memory, + target_memory: Memory, + relation_type: str, + importance: float = 0.6, + ) -> MemoryEdge: + """ + 关联两个记忆(创建因果或引用边) + + Args: + source_memory: 源记忆 + target_memory: 目标记忆 + relation_type: 关系类型(如 "导致", "引用") + importance: 重要性 + + Returns: + 创建的边 + """ + try: + # 获取两个记忆的主题节点(作为连接点) + source_topic = self._find_topic_node(source_memory) + target_topic = self._find_topic_node(target_memory) + + if not source_topic or not target_topic: + raise ValueError("无法找到记忆的主题节点") + + # 确定边的类型 + edge_type = self._determine_edge_type(relation_type) + + # 创建边 + edge_id = f"edge_{datetime.now().strftime('%Y%m%d_%H%M%S_%f')}" + edge = MemoryEdge( + id=edge_id, + source_id=source_topic.id, + target_id=target_topic.id, + relation=relation_type, + edge_type=edge_type, + importance=importance, + metadata={ + "source_memory_id": source_memory.id, + "target_memory_id": target_memory.id, + }, + ) + + logger.info( + f"关联记忆: {source_memory.id} --{relation_type}--> {target_memory.id}" + ) + return edge + + except Exception as e: + logger.error(f"记忆关联失败: {e}", exc_info=True) + raise RuntimeError(f"记忆关联失败: {e}") + + def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]: + """查找记忆中的主题节点""" + for node in memory.nodes: + if node.node_type == NodeType.TOPIC: + return node + return None + + def _determine_edge_type(self, relation_type: str) -> EdgeType: + """根据关系类型确定边的类型""" + causality_keywords = ["导致", "引起", "造成", "因为", "所以"] + reference_keywords = ["引用", "基于", "关于", "参考"] + + for keyword in causality_keywords: + if keyword in relation_type: + return EdgeType.CAUSALITY + + for keyword in reference_keywords: + if keyword in relation_type: + return EdgeType.REFERENCE + + # 默认为引用类型 + return EdgeType.REFERENCE diff --git a/src/memory_graph/core/extractor.py b/src/memory_graph/core/extractor.py new file mode 100644 index 000000000..afe2ad370 --- /dev/null +++ b/src/memory_graph/core/extractor.py @@ -0,0 +1,311 @@ +""" +记忆提取器:从工具参数中提取和验证记忆元素 +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, Optional + +from src.common.logger import get_logger +from src.memory_graph.models import MemoryType +from src.memory_graph.utils.time_parser import TimeParser + +logger = get_logger(__name__) + + +class MemoryExtractor: + """ + 记忆提取器 + + 负责: + 1. 从工具调用参数中提取记忆元素 + 2. 验证参数完整性和有效性 + 3. 标准化时间表达 + 4. 
清洗和格式化数据 + """ + + def __init__(self, time_parser: Optional[TimeParser] = None): + """ + 初始化记忆提取器 + + Args: + time_parser: 时间解析器(可选) + """ + self.time_parser = time_parser or TimeParser() + + def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + """ + 从工具参数中提取记忆元素 + + Args: + params: 工具调用参数,例如: + { + "subject": "我", + "memory_type": "事件", + "topic": "吃饭", + "object": "白米饭", + "attributes": {"时间": "今天", "地点": "家里"}, + "importance": 0.3 + } + + Returns: + 提取和标准化后的参数字典 + """ + try: + # 1. 验证必需参数 + self._validate_required_params(params) + + # 2. 提取基础元素 + extracted = { + "subject": self._clean_text(params["subject"]), + "memory_type": self._parse_memory_type(params["memory_type"]), + "topic": self._clean_text(params["topic"]), + } + + # 3. 提取可选的客体 + if "object" in params and params["object"]: + extracted["object"] = self._clean_text(params["object"]) + + # 4. 提取和标准化属性 + if "attributes" in params and params["attributes"]: + extracted["attributes"] = self._process_attributes(params["attributes"]) + else: + extracted["attributes"] = {} + + # 5. 提取重要性 + extracted["importance"] = self._parse_importance(params.get("importance", 0.5)) + + # 6. 添加时间戳 + extracted["timestamp"] = datetime.now() + + logger.debug(f"提取记忆元素: {extracted['subject']} - {extracted['topic']}") + return extracted + + except Exception as e: + logger.error(f"记忆提取失败: {e}", exc_info=True) + raise ValueError(f"记忆提取失败: {e}") + + def _validate_required_params(self, params: Dict[str, Any]) -> None: + """ + 验证必需参数 + + Args: + params: 参数字典 + + Raises: + ValueError: 如果缺少必需参数 + """ + required_fields = ["subject", "memory_type", "topic"] + + for field in required_fields: + if field not in params or not params[field]: + raise ValueError(f"缺少必需参数: {field}") + + def _clean_text(self, text: Any) -> str: + """ + 清洗文本 + + Args: + text: 输入文本 + + Returns: + 清洗后的文本 + """ + if not text: + return "" + + text = str(text).strip() + + # 移除多余的空格 + text = " ".join(text.split()) + + # 移除特殊字符(保留基本标点) + # text = re.sub(r'[^\w\s\u4e00-\u9fff,,.。!!??;;::、]', '', text) + + return text + + def _parse_memory_type(self, type_str: str) -> MemoryType: + """ + 解析记忆类型 + + Args: + type_str: 类型字符串 + + Returns: + MemoryType 枚举 + + Raises: + ValueError: 如果类型无效 + """ + type_str = type_str.strip() + + # 尝试直接匹配 + try: + return MemoryType(type_str) + except ValueError: + pass + + # 模糊匹配 + type_mapping = { + "事件": MemoryType.EVENT, + "event": MemoryType.EVENT, + "事实": MemoryType.FACT, + "fact": MemoryType.FACT, + "关系": MemoryType.RELATION, + "relation": MemoryType.RELATION, + "观点": MemoryType.OPINION, + "opinion": MemoryType.OPINION, + } + + if type_str.lower() in type_mapping: + return type_mapping[type_str.lower()] + + raise ValueError(f"无效的记忆类型: {type_str}") + + def _parse_importance(self, importance: Any) -> float: + """ + 解析重要性值 + + Args: + importance: 重要性值(可以是数字、字符串等) + + Returns: + 0-1之间的浮点数 + """ + try: + value = float(importance) + # 限制在 0-1 范围内 + return max(0.0, min(1.0, value)) + except (ValueError, TypeError): + logger.warning(f"无效的重要性值: {importance},使用默认值 0.5") + return 0.5 + + def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]: + """ + 处理属性字典 + + Args: + attributes: 原始属性字典 + + Returns: + 处理后的属性字典 + """ + processed = {} + + for key, value in attributes.items(): + key = key.strip() + + # 特殊处理:时间属性 + if key in ["时间", "time", "when"]: + parsed_time = self.time_parser.parse(str(value)) + if parsed_time: + processed["时间"] = parsed_time.isoformat() + else: + processed["时间"] = str(value) + + # 特殊处理:地点属性 + elif key in ["地点", 
"place", "where", "位置"]: + processed["地点"] = self._clean_text(value) + + # 特殊处理:原因属性 + elif key in ["原因", "reason", "why", "因为"]: + processed["原因"] = self._clean_text(value) + + # 特殊处理:方式属性 + elif key in ["方式", "how", "manner"]: + processed["方式"] = self._clean_text(value) + + # 其他属性 + else: + processed[key] = self._clean_text(value) + + return processed + + def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]: + """ + 提取记忆关联参数(用于 link_memories 工具) + + Args: + params: 工具参数,例如: + { + "source_memory_description": "我今天不开心", + "target_memory_description": "我摔东西", + "relation_type": "导致", + "importance": 0.6 + } + + Returns: + 提取后的参数 + """ + try: + required = ["source_memory_description", "target_memory_description", "relation_type"] + + for field in required: + if field not in params or not params[field]: + raise ValueError(f"缺少必需参数: {field}") + + extracted = { + "source_description": self._clean_text(params["source_memory_description"]), + "target_description": self._clean_text(params["target_memory_description"]), + "relation_type": self._clean_text(params["relation_type"]), + "importance": self._parse_importance(params.get("importance", 0.6)), + } + + logger.debug( + f"提取关联参数: {extracted['source_description']} --{extracted['relation_type']}--> " + f"{extracted['target_description']}" + ) + + return extracted + + except Exception as e: + logger.error(f"关联参数提取失败: {e}", exc_info=True) + raise ValueError(f"关联参数提取失败: {e}") + + def validate_relation_type(self, relation_type: str) -> str: + """ + 验证关系类型 + + Args: + relation_type: 关系类型字符串 + + Returns: + 标准化的关系类型 + """ + # 因果关系映射 + causality_relations = { + "因为": "因为", + "所以": "所以", + "导致": "导致", + "引起": "导致", + "造成": "导致", + "因": "因为", + "果": "所以", + } + + # 引用关系映射 + reference_relations = { + "引用": "引用", + "基于": "基于", + "关于": "关于", + "参考": "引用", + } + + # 相关关系 + related_relations = { + "相关": "相关", + "有关": "相关", + "联系": "相关", + } + + relation_type = relation_type.strip() + + # 查找匹配 + for mapping in [causality_relations, reference_relations, related_relations]: + if relation_type in mapping: + return mapping[relation_type] + + # 未找到映射,返回原值 + logger.warning(f"未识别的关系类型: {relation_type},使用原值") + return relation_type diff --git a/src/memory_graph/storage/vector_store.py b/src/memory_graph/storage/vector_store.py index 96e60d7f4..5a791916b 100644 --- a/src/memory_graph/storage/vector_store.py +++ b/src/memory_graph/storage/vector_store.py @@ -92,17 +92,27 @@ class VectorStore: return try: + # 准备元数据(ChromaDB 只支持 str, int, float, bool) + metadata = { + "content": node.content, + "node_type": node.node_type.value, + "created_at": node.created_at.isoformat(), + } + + # 处理额外的元数据,将 list 转换为 JSON 字符串 + for key, value in node.metadata.items(): + if isinstance(value, (list, dict)): + import json + metadata[key] = json.dumps(value, ensure_ascii=False) + elif isinstance(value, (str, int, float, bool)) or value is None: + metadata[key] = value + else: + metadata[key] = str(value) + self.collection.add( ids=[node.id], embeddings=[node.embedding.tolist()], - metadatas=[ - { - "content": node.content, - "node_type": node.node_type.value, - "created_at": node.created_at.isoformat(), - **node.metadata, - } - ], + metadatas=[metadata], documents=[node.content], # 文本内容用于检索 ) @@ -130,18 +140,28 @@ class VectorStore: return try: + # 准备元数据 + import json + metadatas = [] + for n in valid_nodes: + metadata = { + "content": n.content, + "node_type": n.node_type.value, + "created_at": n.created_at.isoformat(), + } + for key, value in n.metadata.items(): + if 
isinstance(value, (list, dict)): + metadata[key] = json.dumps(value, ensure_ascii=False) + elif isinstance(value, (str, int, float, bool)) or value is None: + metadata[key] = value # type: ignore + else: + metadata[key] = str(value) + metadatas.append(metadata) + self.collection.add( ids=[n.id for n in valid_nodes], - embeddings=[n.embedding.tolist() for n in valid_nodes], - metadatas=[ - { - "content": n.content, - "node_type": n.node_type.value, - "created_at": n.created_at.isoformat(), - **n.metadata, - } - for n in valid_nodes - ], + embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore + metadatas=metadatas, documents=[n.content for n in valid_nodes], ) @@ -187,16 +207,26 @@ class VectorStore: ) # 解析结果 + import json similar_nodes = [] if results["ids"] and results["ids"][0]: for i, node_id in enumerate(results["ids"][0]): # ChromaDB 返回的是距离,需要转换为相似度 # 余弦距离: distance = 1 - similarity - distance = results["distances"][0][i] + distance = results["distances"][0][i] if results["distances"] else 0.0 # type: ignore similarity = 1.0 - distance if similarity >= min_similarity: - metadata = results["metadatas"][0][i] if results["metadatas"] else {} + metadata = results["metadatas"][0][i] if results["metadatas"] else {} # type: ignore + + # 解析 JSON 字符串回列表/字典 + for key, value in list(metadata.items()): + if isinstance(value, str) and (value.startswith('[') or value.startswith('{')): + try: + metadata[key] = json.loads(value) + except: + pass # 保持原值 + similar_nodes.append((node_id, similarity, metadata)) logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果") diff --git a/src/memory_graph/tools/__init__.py b/src/memory_graph/tools/__init__.py new file mode 100644 index 000000000..8bb8d538a --- /dev/null +++ b/src/memory_graph/tools/__init__.py @@ -0,0 +1,7 @@ +""" +记忆系统工具模块 +""" + +from src.memory_graph.tools.memory_tools import MemoryTools + +__all__ = ["MemoryTools"] diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py new file mode 100644 index 000000000..1a1910d43 --- /dev/null +++ b/src/memory_graph/tools/memory_tools.py @@ -0,0 +1,495 @@ +""" +LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑 +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +from src.common.logger import get_logger +from src.memory_graph.core.builder import MemoryBuilder +from src.memory_graph.core.extractor import MemoryExtractor +from src.memory_graph.models import Memory, MemoryStatus +from src.memory_graph.storage.graph_store import GraphStore +from src.memory_graph.storage.persistence import PersistenceManager +from src.memory_graph.storage.vector_store import VectorStore +from src.memory_graph.utils.embeddings import EmbeddingGenerator + +logger = get_logger(__name__) + + +class MemoryTools: + """ + 记忆系统工具集 + + 提供给 LLM 使用的工具接口: + 1. create_memory: 创建新记忆 + 2. link_memories: 关联两个记忆 + 3. 
search_memories: 搜索记忆 + """ + + def __init__( + self, + vector_store: VectorStore, + graph_store: GraphStore, + persistence_manager: PersistenceManager, + embedding_generator: Optional[EmbeddingGenerator] = None, + ): + """ + 初始化工具集 + + Args: + vector_store: 向量存储 + graph_store: 图存储 + persistence_manager: 持久化管理器 + embedding_generator: 嵌入生成器(可选) + """ + self.vector_store = vector_store + self.graph_store = graph_store + self.persistence_manager = persistence_manager + self._initialized = False + + # 初始化组件 + self.extractor = MemoryExtractor() + self.builder = MemoryBuilder( + vector_store=vector_store, + graph_store=graph_store, + embedding_generator=embedding_generator, + ) + + async def _ensure_initialized(self): + """确保向量存储已初始化""" + if not self._initialized: + await self.vector_store.initialize() + self._initialized = True + + @staticmethod + def get_create_memory_schema() -> Dict[str, Any]: + """ + 获取 create_memory 工具的 JSON schema + + Returns: + 工具 schema 定义 + """ + return { + "name": "create_memory", + "description": "创建一个新的记忆。记忆由主体、类型、主题、客体(可选)和属性组成。", + "parameters": { + "type": "object", + "properties": { + "subject": { + "type": "string", + "description": "记忆的主体,通常是'我'、'用户'或具体的人名", + }, + "memory_type": { + "type": "string", + "enum": ["事件", "事实", "关系", "观点"], + "description": "记忆类型:事件(时间绑定的动作)、事实(稳定状态)、关系(人际关系)、观点(主观评价)", + }, + "topic": { + "type": "string", + "description": "记忆的主题,即发生的事情或状态", + }, + "object": { + "type": "string", + "description": "记忆的客体,即主题作用的对象(可选)", + }, + "attributes": { + "type": "object", + "description": "记忆的属性,如时间、地点、原因、方式等", + "properties": { + "时间": { + "type": "string", + "description": "时间表达式,如'今天'、'昨天'、'3天前'、'2025-11-05'", + }, + "地点": {"type": "string", "description": "地点"}, + "原因": {"type": "string", "description": "原因"}, + "方式": {"type": "string", "description": "方式"}, + }, + "additionalProperties": True, + }, + "importance": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + "description": "记忆的重要性,0-1之间的浮点数,默认0.5", + }, + }, + "required": ["subject", "memory_type", "topic"], + }, + } + + @staticmethod + def get_link_memories_schema() -> Dict[str, Any]: + """ + 获取 link_memories 工具的 JSON schema + + Returns: + 工具 schema 定义 + """ + return { + "name": "link_memories", + "description": "关联两个已存在的记忆,建立因果或引用关系。", + "parameters": { + "type": "object", + "properties": { + "source_memory_description": { + "type": "string", + "description": "源记忆的描述,用于查找对应的记忆", + }, + "target_memory_description": { + "type": "string", + "description": "目标记忆的描述,用于查找对应的记忆", + }, + "relation_type": { + "type": "string", + "description": "关系类型,如'导致'、'引起'、'因为'、'所以'、'引用'、'基于'等", + }, + "importance": { + "type": "number", + "minimum": 0.0, + "maximum": 1.0, + "description": "关系的重要性,0-1之间的浮点数,默认0.6", + }, + }, + "required": [ + "source_memory_description", + "target_memory_description", + "relation_type", + ], + }, + } + + @staticmethod + def get_search_memories_schema() -> Dict[str, Any]: + """ + 获取 search_memories 工具的 JSON schema + + Returns: + 工具 schema 定义 + """ + return { + "name": "search_memories", + "description": "搜索相关的记忆。支持语义搜索、图遍历和时间过滤。", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "搜索查询,描述要查找的记忆内容", + }, + "memory_types": { + "type": "array", + "items": { + "type": "string", + "enum": ["事件", "事实", "关系", "观点"], + }, + "description": "要搜索的记忆类型,可多选", + }, + "time_range": { + "type": "object", + "properties": { + "start": { + "type": "string", + "description": "开始时间,如'3天前'、'2025-11-01'", + }, + "end": { + 
"type": "string", + "description": "结束时间,如'今天'、'2025-11-05'", + }, + }, + "description": "时间范围过滤(可选)", + }, + "top_k": { + "type": "integer", + "minimum": 1, + "maximum": 50, + "description": "返回结果数量,默认10", + }, + "expand_depth": { + "type": "integer", + "minimum": 0, + "maximum": 3, + "description": "图遍历扩展深度,0表示不扩展,默认1", + }, + }, + "required": ["query"], + }, + } + + async def create_memory(self, **params) -> Dict[str, Any]: + """ + 执行 create_memory 工具 + + Args: + **params: 工具参数 + + Returns: + 执行结果 + """ + try: + logger.info(f"创建记忆: {params.get('subject')} - {params.get('topic')}") + + # 0. 确保初始化 + await self._ensure_initialized() + + # 1. 提取参数 + extracted = self.extractor.extract_from_tool_params(params) + + # 2. 构建记忆 + memory = await self.builder.build_memory(extracted) + + # 3. 添加到存储(暂存状态) + await self._add_memory_to_stores(memory) + + # 4. 保存到磁盘 + await self.persistence_manager.save_graph_store(self.graph_store) + + logger.info(f"记忆创建成功: {memory.id}") + + return { + "success": True, + "memory_id": memory.id, + "message": f"记忆已创建: {extracted['subject']} - {extracted['topic']}", + "nodes_count": len(memory.nodes), + "edges_count": len(memory.edges), + } + + except Exception as e: + logger.error(f"记忆创建失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "message": "记忆创建失败", + } + + async def link_memories(self, **params) -> Dict[str, Any]: + """ + 执行 link_memories 工具 + + Args: + **params: 工具参数 + + Returns: + 执行结果 + """ + try: + logger.info( + f"关联记忆: {params.get('source_memory_description')} -> " + f"{params.get('target_memory_description')}" + ) + + # 1. 提取参数 + extracted = self.extractor.extract_link_params(params) + + # 2. 查找源记忆和目标记忆 + source_memory = await self._find_memory_by_description( + extracted["source_description"] + ) + target_memory = await self._find_memory_by_description( + extracted["target_description"] + ) + + if not source_memory: + return { + "success": False, + "error": "找不到源记忆", + "message": f"未找到匹配的源记忆: {extracted['source_description']}", + } + + if not target_memory: + return { + "success": False, + "error": "找不到目标记忆", + "message": f"未找到匹配的目标记忆: {extracted['target_description']}", + } + + # 3. 创建关联边 + edge = await self.builder.link_memories( + source_memory=source_memory, + target_memory=target_memory, + relation_type=extracted["relation_type"], + importance=extracted["importance"], + ) + + # 4. 添加边到图存储 + self.graph_store.graph.add_edge( + edge.source_id, + edge.target_id, + relation=edge.relation, + edge_type=edge.edge_type.value, + importance=edge.importance, + **edge.metadata + ) + + # 5. 保存 + await self.persistence_manager.save_graph_store(self.graph_store) + + logger.info(f"记忆关联成功: {source_memory.id} -> {target_memory.id}") + + return { + "success": True, + "message": f"记忆已关联: {extracted['relation_type']}", + "source_memory_id": source_memory.id, + "target_memory_id": target_memory.id, + "relation_type": extracted["relation_type"], + } + + except Exception as e: + logger.error(f"记忆关联失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "message": "记忆关联失败", + } + + async def search_memories(self, **params) -> Dict[str, Any]: + """ + 执行 search_memories 工具 + + Args: + **params: 工具参数 + + Returns: + 搜索结果 + """ + try: + query = params.get("query", "") + top_k = params.get("top_k", 10) + expand_depth = params.get("expand_depth", 1) + + logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth})") + + # 0. 确保初始化 + await self._ensure_initialized() + + # 1. 
生成查询嵌入 + if self.builder.embedding_generator: + query_embedding = await self.builder.embedding_generator.generate(query) + else: + logger.warning("未配置嵌入生成器,使用随机向量") + import numpy as np + query_embedding = np.random.rand(384).astype(np.float32) + + # 2. 向量搜索 + node_types_filter = None + if "memory_types" in params: + # 添加类型过滤 + pass + + similar_nodes = await self.vector_store.search_similar_nodes( + query_embedding=query_embedding, + limit=top_k * 2, # 多取一些,后续过滤 + node_types=node_types_filter, + ) + + # 3. 提取记忆ID + memory_ids = set() + for node_id, similarity, metadata in similar_nodes: + if "memory_ids" in metadata: + memory_ids.update(metadata["memory_ids"]) + + # 4. 获取完整记忆 + memories = [] + for memory_id in list(memory_ids)[:top_k]: + memory = self.graph_store.get_memory_by_id(memory_id) + if memory: + memories.append(memory) + + # 5. 格式化结果 + results = [] + for memory in memories: + result = { + "memory_id": memory.id, + "importance": memory.importance, + "created_at": memory.created_at.isoformat(), + "summary": self._summarize_memory(memory), + } + results.append(result) + + logger.info(f"搜索完成: 找到 {len(results)} 条记忆") + + return { + "success": True, + "results": results, + "total": len(results), + "query": query, + } + + except Exception as e: + logger.error(f"记忆搜索失败: {e}", exc_info=True) + return { + "success": False, + "error": str(e), + "message": "记忆搜索失败", + "results": [], + } + + async def _add_memory_to_stores(self, memory: Memory): + """将记忆添加到存储""" + # 1. 添加到图存储 + self.graph_store.add_memory(memory) + + # 2. 添加有嵌入的节点到向量存储 + for node in memory.nodes: + if node.embedding is not None: + await self.vector_store.add_node(node) + + async def _find_memory_by_description(self, description: str) -> Optional[Memory]: + """ + 通过描述查找记忆 + + Args: + description: 记忆描述 + + Returns: + 找到的记忆,如果没有则返回 None + """ + # 使用语义搜索查找最相关的记忆 + if self.builder.embedding_generator: + query_embedding = await self.builder.embedding_generator.generate(description) + else: + import numpy as np + query_embedding = np.random.rand(384).astype(np.float32) + + # 搜索相似节点 + similar_nodes = await self.vector_store.search_similar_nodes( + query_embedding=query_embedding, + limit=5, + ) + + if not similar_nodes: + return None + + # 获取最相似节点关联的记忆 + node_id, similarity, metadata = similar_nodes[0] + if "memory_ids" in metadata and metadata["memory_ids"]: + memory_id = metadata["memory_ids"][0] + return self.graph_store.get_memory_by_id(memory_id) + + return None + + def _summarize_memory(self, memory: Memory) -> str: + """生成记忆摘要""" + if not memory.metadata: + return "未知记忆" + + subject = memory.metadata.get("subject", "") + topic = memory.metadata.get("topic", "") + memory_type = memory.metadata.get("memory_type", "") + + return f"{subject} - {memory_type}: {topic}" + + @staticmethod + def get_all_tool_schemas() -> List[Dict[str, Any]]: + """ + 获取所有工具的 schema + + Returns: + 工具 schema 列表 + """ + return [ + MemoryTools.get_create_memory_schema(), + MemoryTools.get_link_memories_schema(), + MemoryTools.get_search_memories_schema(), + ] diff --git a/src/memory_graph/utils/__init__.py b/src/memory_graph/utils/__init__.py new file mode 100644 index 000000000..0b23863d0 --- /dev/null +++ b/src/memory_graph/utils/__init__.py @@ -0,0 +1,8 @@ +""" +工具模块 +""" + +from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator +from src.memory_graph.utils.time_parser import TimeParser + +__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"] diff --git a/src/memory_graph/utils/embeddings.py 
b/src/memory_graph/utils/embeddings.py new file mode 100644 index 000000000..016b7bda3 --- /dev/null +++ b/src/memory_graph/utils/embeddings.py @@ -0,0 +1,299 @@ +""" +嵌入向量生成器:优先使用配置的 embedding API,sentence-transformers 作为备选 +""" + +from __future__ import annotations + +import asyncio +from functools import lru_cache +from typing import List, Optional + +import numpy as np + +from src.common.logger import get_logger + +logger = get_logger(__name__) + + +class EmbeddingGenerator: + """ + 嵌入向量生成器 + + 策略: + 1. 优先使用配置的 embedding API(通过 LLMRequest) + 2. 如果 API 不可用,回退到本地 sentence-transformers + 3. 如果 sentence-transformers 未安装,使用随机向量(仅测试) + + 优点: + - 降低本地运算负载 + - 即使未安装 sentence-transformers 也可正常运行 + - 保持与现有系统的一致性 + """ + + def __init__( + self, + use_api: bool = True, + fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", + ): + """ + 初始化嵌入生成器 + + Args: + use_api: 是否优先使用 API(默认 True) + fallback_model_name: 回退本地模型名称 + """ + self.use_api = use_api + self.fallback_model_name = fallback_model_name + + # API 相关 + self._llm_request = None + self._api_available = False + self._api_dimension = None + + # 本地模型相关 + self._local_model = None + self._local_model_loaded = False + + async def _initialize_api(self): + """初始化 embedding API""" + if self._api_available: + return + + try: + from src.config.config import model_config + from src.llm_models.utils_model import LLMRequest + + embedding_config = model_config.model_task_config.embedding + self._llm_request = LLMRequest( + model_set=embedding_config, + request_type="memory_graph.embedding" + ) + + # 获取嵌入维度 + if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension: + self._api_dimension = embedding_config.embedding_dimension + + self._api_available = True + logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})") + + except Exception as e: + logger.warning(f"⚠️ Embedding API 初始化失败: {e}") + self._api_available = False + + def _load_local_model(self): + """延迟加载本地模型""" + if not self._local_model_loaded: + try: + from sentence_transformers import SentenceTransformer + + logger.info(f"📦 加载本地嵌入模型: {self.fallback_model_name}") + self._local_model = SentenceTransformer(self.fallback_model_name) + self._local_model_loaded = True + logger.info("✅ 本地嵌入模型加载成功") + except ImportError: + logger.warning( + "⚠️ sentence-transformers 未安装,将使用随机向量(仅测试用)\n" + " 安装方法: pip install sentence-transformers" + ) + self._local_model_loaded = False + except Exception as e: + logger.warning(f"⚠️ 本地模型加载失败: {e}") + self._local_model_loaded = False + + async def generate(self, text: str) -> np.ndarray: + """ + 生成单个文本的嵌入向量 + + 策略: + 1. 优先使用 API + 2. API 失败则使用本地模型 + 3. 
本地模型不可用则使用随机向量 + + Args: + text: 输入文本 + + Returns: + 嵌入向量 + """ + if not text or not text.strip(): + logger.warning("输入文本为空,返回零向量") + dim = self._get_dimension() + return np.zeros(dim, dtype=np.float32) + + try: + # 策略 1: 使用 API + if self.use_api: + embedding = await self._generate_with_api(text) + if embedding is not None: + return embedding + + # 策略 2: 使用本地模型 + embedding = await self._generate_with_local_model(text) + if embedding is not None: + return embedding + + # 策略 3: 随机向量(仅测试) + logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...") + dim = self._get_dimension() + return np.random.rand(dim).astype(np.float32) + + except Exception as e: + logger.error(f"❌ 嵌入生成失败: {e}", exc_info=True) + dim = self._get_dimension() + return np.random.rand(dim).astype(np.float32) + + async def _generate_with_api(self, text: str) -> Optional[np.ndarray]: + """使用 API 生成嵌入""" + try: + # 初始化 API + if not self._api_available: + await self._initialize_api() + + if not self._api_available or not self._llm_request: + return None + + # 调用 API + embedding_list, model_name = await self._llm_request.get_embedding(text) + + if embedding_list and len(embedding_list) > 0: + embedding = np.array(embedding_list, dtype=np.float32) + logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})") + return embedding + + return None + + except Exception as e: + logger.debug(f"API 嵌入生成失败: {e}") + return None + + async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]: + """使用本地模型生成嵌入""" + try: + # 加载本地模型 + if not self._local_model_loaded: + self._load_local_model() + + if not self._local_model_loaded or not self._local_model: + return None + + # 在线程池中运行 + loop = asyncio.get_event_loop() + embedding = await loop.run_in_executor(None, self._encode_single_local, text) + + logger.debug(f"💻 本地生成嵌入: {text[:30]}... 
-> {len(embedding)}维") + return embedding + + except Exception as e: + logger.debug(f"本地模型嵌入生成失败: {e}") + return None + + def _encode_single_local(self, text: str) -> np.ndarray: + """同步编码单个文本(本地模型)""" + if self._local_model is None: + raise RuntimeError("本地模型未加载") + embedding = self._local_model.encode(text, convert_to_numpy=True) # type: ignore + return embedding.astype(np.float32) + + def _get_dimension(self) -> int: + """获取嵌入维度""" + # 优先使用 API 维度 + if self._api_dimension: + return self._api_dimension + + # 其次使用本地模型维度 + if self._local_model_loaded and self._local_model: + try: + return self._local_model.get_sentence_embedding_dimension() + except: + pass + + # 默认 384(sentence-transformers 常用维度) + return 384 + + async def generate_batch(self, texts: List[str]) -> List[np.ndarray]: + """ + 批量生成嵌入向量 + + Args: + texts: 文本列表 + + Returns: + 嵌入向量列表 + """ + if not texts: + return [] + + try: + # 过滤空文本 + valid_texts = [t for t in texts if t and t.strip()] + if not valid_texts: + logger.warning("所有文本为空,返回零向量列表") + dim = self._get_dimension() + return [np.zeros(dim, dtype=np.float32) for _ in texts] + + # 使用 API 批量生成(如果可用) + if self.use_api: + results = await self._generate_batch_with_api(valid_texts) + if results: + return results + + # 回退到逐个生成 + results = [] + for text in valid_texts: + embedding = await self.generate(text) + results.append(embedding) + + logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本") + return results + + except Exception as e: + logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True) + dim = self._get_dimension() + return [np.random.rand(dim).astype(np.float32) for _ in texts] + + async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]: + """使用 API 批量生成""" + try: + # 对于大多数 API,批量调用就是多次单独调用 + # 这里保持简单,逐个调用 + results = [] + for text in texts: + embedding = await self._generate_with_api(text) + if embedding is None: + return None # 如果任何一个失败,返回 None 触发回退 + results.append(embedding) + return results + except Exception as e: + logger.debug(f"API 批量生成失败: {e}") + return None + + def get_embedding_dimension(self) -> int: + """获取嵌入向量维度""" + return self._get_dimension() + + +# 全局单例 +_global_generator: Optional[EmbeddingGenerator] = None + + +def get_embedding_generator( + use_api: bool = True, + fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", +) -> EmbeddingGenerator: + """ + 获取全局嵌入生成器单例 + + Args: + use_api: 是否优先使用 API + fallback_model_name: 回退本地模型名称 + + Returns: + EmbeddingGenerator 实例 + """ + global _global_generator + if _global_generator is None: + _global_generator = EmbeddingGenerator( + use_api=use_api, + fallback_model_name=fallback_model_name + ) + return _global_generator diff --git a/src/memory_graph/utils/time_parser.py b/src/memory_graph/utils/time_parser.py new file mode 100644 index 000000000..dbf71d9f9 --- /dev/null +++ b/src/memory_graph/utils/time_parser.py @@ -0,0 +1,391 @@ +""" +时间解析器:将相对时间转换为绝对时间 + +支持的时间表达: +- 今天、明天、昨天、前天、后天 +- X天前、X天后 +- X小时前、X小时后 +- 上周、上个月、去年 +- 具体日期:2025-11-05, 11月5日 +- 时间点:早上8点、下午3点、晚上9点 +""" + +from __future__ import annotations + +import re +from datetime import datetime, timedelta +from typing import Optional, Tuple + +from src.common.logger import get_logger + +logger = get_logger(__name__) + + +class TimeParser: + """ + 时间解析器 + + 负责将自然语言时间表达转换为标准化的绝对时间 + """ + + def __init__(self, reference_time: Optional[datetime] = None): + """ + 初始化时间解析器 + + Args: + reference_time: 参考时间(通常是当前时间) + """ + self.reference_time = reference_time or datetime.now() + + def parse(self, time_str: str) -> Optional[datetime]: 
+ """ + 解析时间字符串 + + Args: + time_str: 时间字符串 + + Returns: + 标准化的datetime对象,如果解析失败则返回None + """ + if not time_str or not isinstance(time_str, str): + return None + + time_str = time_str.strip() + + # 尝试各种解析方法 + parsers = [ + self._parse_relative_day, + self._parse_days_ago, + self._parse_hours_ago, + self._parse_week_month_year, + self._parse_specific_date, + self._parse_time_of_day, + ] + + for parser in parsers: + try: + result = parser(time_str) + if result: + logger.debug(f"时间解析: '{time_str}' → {result.isoformat()}") + return result + except Exception as e: + logger.debug(f"解析器 {parser.__name__} 失败: {e}") + continue + + logger.warning(f"无法解析时间: '{time_str}',使用当前时间") + return self.reference_time + + def _parse_relative_day(self, time_str: str) -> Optional[datetime]: + """ + 解析相对日期:今天、明天、昨天、前天、后天 + """ + relative_days = { + "今天": 0, + "今日": 0, + "明天": 1, + "明日": 1, + "昨天": -1, + "昨日": -1, + "前天": -2, + "前日": -2, + "后天": 2, + "后日": 2, + "大前天": -3, + "大后天": 3, + } + + for keyword, days in relative_days.items(): + if keyword in time_str: + result = self.reference_time + timedelta(days=days) + # 保留原有时间,只改变日期 + return result.replace(hour=0, minute=0, second=0, microsecond=0) + + return None + + def _parse_days_ago(self, time_str: str) -> Optional[datetime]: + """ + 解析 X天前/X天后 + """ + # 匹配:3天前、5天后、一天前 + pattern = r"([一二三四五六七八九十\d]+)天(前|后)" + match = re.search(pattern, time_str) + + if match: + num_str, direction = match.groups() + num = self._chinese_num_to_int(num_str) + + if direction == "前": + num = -num + + result = self.reference_time + timedelta(days=num) + return result.replace(hour=0, minute=0, second=0, microsecond=0) + + return None + + def _parse_hours_ago(self, time_str: str) -> Optional[datetime]: + """ + 解析 X小时前/X小时后、X分钟前/X分钟后 + """ + # 小时 + pattern_hour = r"([一二三四五六七八九十\d]+)小?时(前|后)" + match = re.search(pattern_hour, time_str) + + if match: + num_str, direction = match.groups() + num = self._chinese_num_to_int(num_str) + + if direction == "前": + num = -num + + return self.reference_time + timedelta(hours=num) + + # 分钟 + pattern_minute = r"([一二三四五六七八九十\d]+)分钟(前|后)" + match = re.search(pattern_minute, time_str) + + if match: + num_str, direction = match.groups() + num = self._chinese_num_to_int(num_str) + + if direction == "前": + num = -num + + return self.reference_time + timedelta(minutes=num) + + return None + + def _parse_week_month_year(self, time_str: str) -> Optional[datetime]: + """ + 解析:上周、上个月、去年、本周、本月、今年 + """ + now = self.reference_time + + if "上周" in time_str or "上星期" in time_str: + return now - timedelta(days=7) + + if "上个月" in time_str or "上月" in time_str: + # 简单处理:减30天 + return now - timedelta(days=30) + + if "去年" in time_str or "上年" in time_str: + return now.replace(year=now.year - 1) + + if "本周" in time_str or "这周" in time_str: + # 返回本周一 + return now - timedelta(days=now.weekday()) + + if "本月" in time_str or "这个月" in time_str: + return now.replace(day=1) + + if "今年" in time_str or "这年" in time_str: + return now.replace(month=1, day=1) + + return None + + def _parse_specific_date(self, time_str: str) -> Optional[datetime]: + """ + 解析具体日期: + - 2025-11-05 + - 2025/11/05 + - 11月5日 + - 11-05 + """ + # ISO 格式:2025-11-05 + pattern_iso = r"(\d{4})[-/](\d{1,2})[-/](\d{1,2})" + match = re.search(pattern_iso, time_str) + if match: + year, month, day = map(int, match.groups()) + return datetime(year, month, day) + + # 中文格式:11月5日、11月5号 + pattern_cn = r"(\d{1,2})月(\d{1,2})[日号]" + match = re.search(pattern_cn, time_str) + if match: + month, day = map(int, match.groups()) + # 
使用参考时间的年份 + year = self.reference_time.year + return datetime(year, month, day) + + # 短格式:11-05(使用当前年份) + pattern_short = r"(\d{1,2})[-/](\d{1,2})" + match = re.search(pattern_short, time_str) + if match: + month, day = map(int, match.groups()) + year = self.reference_time.year + return datetime(year, month, day) + + return None + + def _parse_time_of_day(self, time_str: str) -> Optional[datetime]: + """ + 解析一天中的时间: + - 早上、上午、中午、下午、晚上、深夜 + - 早上8点、下午3点 + - 8点、15点 + """ + now = self.reference_time + result = now.replace(minute=0, second=0, microsecond=0) + + # 时间段映射 + time_periods = { + "早上": 8, + "早晨": 8, + "上午": 10, + "中午": 12, + "下午": 15, + "傍晚": 18, + "晚上": 20, + "深夜": 23, + "凌晨": 2, + } + + # 先检查是否有具体时间点:早上8点、下午3点 + for period, default_hour in time_periods.items(): + pattern = rf"{period}(\d{{1,2}})点?" + match = re.search(pattern, time_str) + if match: + hour = int(match.group(1)) + # 下午时间需要+12 + if period in ["下午", "晚上"] and hour < 12: + hour += 12 + return result.replace(hour=hour) + + # 检查时间段关键词 + for period, hour in time_periods.items(): + if period in time_str: + return result.replace(hour=hour) + + # 直接的时间点:8点、15点 + pattern = r"(\d{1,2})点" + match = re.search(pattern, time_str) + if match: + hour = int(match.group(1)) + return result.replace(hour=hour) + + return None + + def _chinese_num_to_int(self, num_str: str) -> int: + """ + 将中文数字转换为阿拉伯数字 + + Args: + num_str: 中文数字字符串(如:"一"、"十"、"3") + + Returns: + 整数 + """ + # 如果已经是数字,直接返回 + if num_str.isdigit(): + return int(num_str) + + # 中文数字映射 + chinese_nums = { + "一": 1, + "二": 2, + "三": 3, + "四": 4, + "五": 5, + "六": 6, + "七": 7, + "八": 8, + "九": 9, + "十": 10, + "零": 0, + } + + if num_str in chinese_nums: + return chinese_nums[num_str] + + # 处理 "十X" 的情况(如"十五"=15) + if num_str.startswith("十"): + if len(num_str) == 1: + return 10 + return 10 + chinese_nums.get(num_str[1], 0) + + # 处理 "X十" 的情况(如"三十"=30) + if "十" in num_str: + parts = num_str.split("十") + tens = chinese_nums.get(parts[0], 1) * 10 + ones = chinese_nums.get(parts[1], 0) if len(parts) > 1 and parts[1] else 0 + return tens + ones + + # 默认返回1 + return 1 + + def format_time(self, dt: datetime, format_type: str = "iso") -> str: + """ + 格式化时间 + + Args: + dt: datetime对象 + format_type: 格式类型 ("iso", "cn", "relative") + + Returns: + 格式化的时间字符串 + """ + if format_type == "iso": + return dt.isoformat() + + elif format_type == "cn": + return dt.strftime("%Y年%m月%d日 %H:%M:%S") + + elif format_type == "relative": + # 相对时间表达 + diff = self.reference_time - dt + days = diff.days + + if days == 0: + hours = diff.seconds // 3600 + if hours == 0: + minutes = diff.seconds // 60 + return f"{minutes}分钟前" if minutes > 0 else "刚刚" + return f"{hours}小时前" + elif days == 1: + return "昨天" + elif days == 2: + return "前天" + elif days < 7: + return f"{days}天前" + elif days < 30: + weeks = days // 7 + return f"{weeks}周前" + elif days < 365: + months = days // 30 + return f"{months}个月前" + else: + years = days // 365 + return f"{years}年前" + + return str(dt) + + def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]: + """ + 解析时间范围:最近一周、最近3天 + + Args: + time_str: 时间范围字符串 + + Returns: + (start_time, end_time) + """ + pattern = r"最近(\d+)(天|周|月|年)" + match = re.search(pattern, time_str) + + if match: + num, unit = match.groups() + num = int(num) + + unit_map = {"天": "days", "周": "weeks", "月": "days", "年": "days"} + if unit == "周": + num *= 7 + elif unit == "月": + num *= 30 + elif unit == "年": + num *= 365 + + end_time = self.reference_time + start_time = end_time - 
timedelta(days=num)  # 修复:num 已统一折算为天数(周×7、月×30、年×365),固定用 days,避免经 unit_map 把周映射为 weeks 后二次放大
+
+                return (start_time, end_time)
+
+        return (None, None)
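
A quick self-check for the time parsing added in this diff: the sketch below exercises TimeParser.parse and parse_time_range with a pinned reference_time so the expected values are deterministic. It assumes the repository root is on PYTHONPATH (the module imports src.common.logger); the last assertion also assumes the days-based timedelta fix applied above.

from datetime import datetime, timedelta

from src.memory_graph.utils.time_parser import TimeParser

# Pin the reference time so relative expressions resolve deterministically.
parser = TimeParser(reference_time=datetime(2025, 11, 5, 15, 30))

# "3天前" matches _parse_days_ago: three days back, truncated to midnight.
assert parser.parse("3天前") == datetime(2025, 11, 2)

# "昨天" matches _parse_relative_day: previous day at midnight.
assert parser.parse("昨天") == datetime(2025, 11, 4)

# "下午3点" matches _parse_time_of_day: PM hours below 12 get +12.
assert parser.parse("下午3点") == datetime(2025, 11, 5, 15, 0)

# "最近2周" resolves to a 14-day window ending at the reference time
# (relies on the days-based conversion fix in parse_time_range).
start, end = parser.parse_time_range("最近2周")
assert end - start == timedelta(days=14)

Note that parse falls back to reference_time rather than None when no pattern matches, so callers that need to distinguish "unparsed" from "now" should compare against the reference time explicitly.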