fix: 修复代码质量问题 - 更正异常处理和导入语句
Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
This commit is contained in:
@@ -5,4 +5,4 @@
|
||||
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
|
||||
from src.memory_graph.utils.time_parser import TimeParser
|
||||
|
||||
__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"]
|
||||
__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"]
|
||||
|
||||
@@ -5,8 +5,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -18,12 +16,12 @@ logger = get_logger(__name__)
|
||||
class EmbeddingGenerator:
|
||||
"""
|
||||
嵌入向量生成器
|
||||
|
||||
|
||||
策略:
|
||||
1. 优先使用配置的 embedding API(通过 LLMRequest)
|
||||
2. 如果 API 不可用,回退到本地 sentence-transformers
|
||||
3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)
|
||||
|
||||
|
||||
优点:
|
||||
- 降低本地运算负载
|
||||
- 即使未安装 sentence-transformers 也可正常运行
|
||||
@@ -37,19 +35,19 @@ class EmbeddingGenerator:
|
||||
):
|
||||
"""
|
||||
初始化嵌入生成器
|
||||
|
||||
|
||||
Args:
|
||||
use_api: 是否优先使用 API(默认 True)
|
||||
fallback_model_name: 回退本地模型名称
|
||||
"""
|
||||
self.use_api = use_api
|
||||
self.fallback_model_name = fallback_model_name
|
||||
|
||||
|
||||
# API 相关
|
||||
self._llm_request = None
|
||||
self._api_available = False
|
||||
self._api_dimension = None
|
||||
|
||||
|
||||
# 本地模型相关
|
||||
self._local_model = None
|
||||
self._local_model_loaded = False
|
||||
@@ -58,24 +56,24 @@ class EmbeddingGenerator:
|
||||
"""初始化 embedding API"""
|
||||
if self._api_available:
|
||||
return
|
||||
|
||||
|
||||
try:
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
|
||||
embedding_config = model_config.model_task_config.embedding
|
||||
self._llm_request = LLMRequest(
|
||||
model_set=embedding_config,
|
||||
request_type="memory_graph.embedding"
|
||||
)
|
||||
|
||||
|
||||
# 获取嵌入维度
|
||||
if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension:
|
||||
self._api_dimension = embedding_config.embedding_dimension
|
||||
|
||||
|
||||
self._api_available = True
|
||||
logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
|
||||
self._api_available = False
|
||||
@@ -103,15 +101,15 @@ class EmbeddingGenerator:
|
||||
async def generate(self, text: str) -> np.ndarray:
|
||||
"""
|
||||
生成单个文本的嵌入向量
|
||||
|
||||
|
||||
策略:
|
||||
1. 优先使用 API
|
||||
2. API 失败则使用本地模型
|
||||
3. 本地模型不可用则使用随机向量
|
||||
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
|
||||
|
||||
Returns:
|
||||
嵌入向量
|
||||
"""
|
||||
@@ -126,12 +124,12 @@ class EmbeddingGenerator:
|
||||
embedding = await self._generate_with_api(text)
|
||||
if embedding is not None:
|
||||
return embedding
|
||||
|
||||
|
||||
# 策略 2: 使用本地模型
|
||||
embedding = await self._generate_with_local_model(text)
|
||||
if embedding is not None:
|
||||
return embedding
|
||||
|
||||
|
||||
# 策略 3: 随机向量(仅测试)
|
||||
logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
|
||||
dim = self._get_dimension()
|
||||
@@ -142,47 +140,47 @@ class EmbeddingGenerator:
|
||||
dim = self._get_dimension()
|
||||
return np.random.rand(dim).astype(np.float32)
|
||||
|
||||
async def _generate_with_api(self, text: str) -> np.ndarray | None:
    """Generate an embedding for ``text`` through the embedding API.

    Initialization of the API client is attempted whenever the API is not
    yet marked available. Every failure mode (client still unavailable,
    empty response, or any exception raised along the way) results in
    ``None`` rather than an exception.

    Args:
        text: Input text to embed.

    Returns:
        A float32 array built from the API response, or ``None`` when no
        embedding could be obtained.
    """
    try:
        # Try to bring the API up before the first real request.
        if not self._api_available:
            await self._initialize_api()

        api_ready = self._api_available and self._llm_request
        if not api_ready:
            return None

        # The request object returns the raw vector plus the model that produced it.
        embedding_list, model_name = await self._llm_request.get_embedding(text)

        if not embedding_list or len(embedding_list) <= 0:
            return None

        vector = np.array(embedding_list, dtype=np.float32)
        logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(vector)}维 (模型: {model_name})")
        return vector

    except Exception as e:
        # Best-effort path: report the failure at debug level and signal it with None.
        logger.debug(f"API 嵌入生成失败: {e}")
        return None
|
||||
|
||||
async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
    """Generate an embedding for ``text`` with the local fallback model.

    The local model is loaded on demand. If it is still not loaded after
    the load attempt, or if encoding raises, ``None`` is returned instead
    of propagating the error.

    Args:
        text: Input text to embed.

    Returns:
        The embedding produced by ``_encode_single_local``, or ``None``
        when the local model is unavailable or encoding fails.
    """
    try:
        # Load the local model only when it has not been loaded yet.
        if not self._local_model_loaded:
            self._load_local_model()

        if not self._local_model_loaded or not self._local_model:
            return None

        # Run the encode call in the default executor so it does not block
        # the event loop. Inside a coroutine a loop is always running, so
        # get_running_loop() is the documented way to obtain it.
        loop = asyncio.get_running_loop()
        embedding = await loop.run_in_executor(None, self._encode_single_local, text)

        logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维")
        return embedding

    except Exception as e:
        # Best-effort path: report the failure at debug level and signal it with None.
        logger.debug(f"本地模型嵌入生成失败: {e}")
        return None
|
||||
@@ -199,24 +197,24 @@ class EmbeddingGenerator:
|
||||
# 优先使用 API 维度
|
||||
if self._api_dimension:
|
||||
return self._api_dimension
|
||||
|
||||
|
||||
# 其次使用本地模型维度
|
||||
if self._local_model_loaded and self._local_model:
|
||||
try:
|
||||
return self._local_model.get_sentence_embedding_dimension()
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# 默认 384(sentence-transformers 常用维度)
|
||||
return 384
|
||||
|
||||
async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
|
||||
async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||
"""
|
||||
批量生成嵌入向量
|
||||
|
||||
|
||||
Args:
|
||||
texts: 文本列表
|
||||
|
||||
|
||||
Returns:
|
||||
嵌入向量列表
|
||||
"""
|
||||
@@ -236,13 +234,13 @@ class EmbeddingGenerator:
|
||||
results = await self._generate_batch_with_api(valid_texts)
|
||||
if results:
|
||||
return results
|
||||
|
||||
|
||||
# 回退到逐个生成
|
||||
results = []
|
||||
for text in valid_texts:
|
||||
embedding = await self.generate(text)
|
||||
results.append(embedding)
|
||||
|
||||
|
||||
logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
|
||||
return results
|
||||
|
||||
@@ -251,7 +249,7 @@ class EmbeddingGenerator:
|
||||
dim = self._get_dimension()
|
||||
return [np.random.rand(dim).astype(np.float32) for _ in texts]
|
||||
|
||||
async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
|
||||
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
|
||||
"""使用 API 批量生成"""
|
||||
try:
|
||||
# 对于大多数 API,批量调用就是多次单独调用
|
||||
@@ -273,7 +271,7 @@ class EmbeddingGenerator:
|
||||
|
||||
|
||||
# 全局单例
|
||||
_global_generator: Optional[EmbeddingGenerator] = None
|
||||
_global_generator: EmbeddingGenerator | None = None
|
||||
|
||||
|
||||
def get_embedding_generator(
|
||||
@@ -282,11 +280,11 @@ def get_embedding_generator(
|
||||
) -> EmbeddingGenerator:
|
||||
"""
|
||||
获取全局嵌入生成器单例
|
||||
|
||||
|
||||
Args:
|
||||
use_api: 是否优先使用 API
|
||||
fallback_model_name: 回退本地模型名称
|
||||
|
||||
|
||||
Returns:
|
||||
EmbeddingGenerator 实例
|
||||
"""
|
||||
|
||||
@@ -5,10 +5,9 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from src.memory_graph.models import Memory, MemoryNode, NodeType, EdgeType, MemoryType
|
||||
from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -16,18 +15,18 @@ logger = logging.getLogger(__name__)
|
||||
def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> str:
|
||||
"""
|
||||
将记忆对象格式化为适合提示词的自然语言描述
|
||||
|
||||
|
||||
根据记忆的图结构,构建完整的主谓宾描述,包含:
|
||||
- 主语(subject node)
|
||||
- 谓语/动作(topic node)
|
||||
- 宾语/对象(object node,如果存在)
|
||||
- 属性信息(attributes,如时间、地点等)
|
||||
- 关系信息(记忆之间的关系)
|
||||
|
||||
|
||||
Args:
|
||||
memory: 记忆对象
|
||||
include_metadata: 是否包含元数据(时间、重要性等)
|
||||
|
||||
|
||||
Returns:
|
||||
格式化后的自然语言描述
|
||||
"""
|
||||
@@ -37,24 +36,22 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
||||
if not subject_node:
|
||||
logger.warning(f"记忆 {memory.id} 缺少主体节点")
|
||||
return "(记忆格式错误:缺少主体)"
|
||||
|
||||
|
||||
subject_text = subject_node.content
|
||||
|
||||
|
||||
# 2. 查找主题节点(谓语/动作)
|
||||
topic_node = None
|
||||
memory_type_relation = None
|
||||
for edge in memory.edges:
|
||||
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
|
||||
topic_node = memory.get_node_by_id(edge.target_id)
|
||||
memory_type_relation = edge.relation
|
||||
break
|
||||
|
||||
|
||||
if not topic_node:
|
||||
logger.warning(f"记忆 {memory.id} 缺少主题节点")
|
||||
return f"{subject_text}(记忆格式错误:缺少主题)"
|
||||
|
||||
|
||||
topic_text = topic_node.content
|
||||
|
||||
|
||||
# 3. 查找客体节点(宾语)和核心关系
|
||||
object_node = None
|
||||
core_relation = None
|
||||
@@ -63,9 +60,9 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
||||
object_node = memory.get_node_by_id(edge.target_id)
|
||||
core_relation = edge.relation if edge.relation else ""
|
||||
break
|
||||
|
||||
|
||||
# 4. 收集属性节点
|
||||
attributes: Dict[str, str] = {}
|
||||
attributes: dict[str, str] = {}
|
||||
for edge in memory.edges:
|
||||
if edge.edge_type == EdgeType.ATTRIBUTE:
|
||||
# 查找属性节点和值节点
|
||||
@@ -73,16 +70,16 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
||||
if attr_node and attr_node.node_type == NodeType.ATTRIBUTE:
|
||||
# 查找这个属性的值
|
||||
for value_edge in memory.edges:
|
||||
if (value_edge.edge_type == EdgeType.ATTRIBUTE
|
||||
if (value_edge.edge_type == EdgeType.ATTRIBUTE
|
||||
and value_edge.source_id == attr_node.id):
|
||||
value_node = memory.get_node_by_id(value_edge.target_id)
|
||||
if value_node and value_node.node_type == NodeType.VALUE:
|
||||
attributes[attr_node.content] = value_node.content
|
||||
break
|
||||
|
||||
|
||||
# 5. 构建自然语言描述
|
||||
parts = []
|
||||
|
||||
|
||||
# 主谓宾结构
|
||||
if object_node is not None:
|
||||
# 有完整的主谓宾
|
||||
@@ -93,7 +90,7 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
||||
else:
|
||||
# 只有主谓
|
||||
parts.append(f"{subject_text}{topic_text}")
|
||||
|
||||
|
||||
# 添加属性信息
|
||||
if attributes:
|
||||
attr_parts = []
|
||||
@@ -106,78 +103,78 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
||||
for key, value in attributes.items():
|
||||
if key not in ["时间", "地点"]:
|
||||
attr_parts.append(f"{key}:{value}")
|
||||
|
||||
|
||||
if attr_parts:
|
||||
parts.append(f"({' '.join(attr_parts)})")
|
||||
|
||||
|
||||
description = "".join(parts)
|
||||
|
||||
|
||||
# 6. 添加元数据(可选)
|
||||
if include_metadata:
|
||||
metadata_parts = []
|
||||
|
||||
|
||||
# 记忆类型
|
||||
if memory.memory_type:
|
||||
metadata_parts.append(f"类型:{memory.memory_type.value}")
|
||||
|
||||
|
||||
# 重要性
|
||||
if memory.importance >= 0.8:
|
||||
metadata_parts.append("重要")
|
||||
elif memory.importance >= 0.6:
|
||||
metadata_parts.append("一般")
|
||||
|
||||
|
||||
# 时间(如果没有在属性中)
|
||||
if "时间" not in attributes:
|
||||
time_str = _format_relative_time(memory.created_at)
|
||||
if time_str:
|
||||
metadata_parts.append(time_str)
|
||||
|
||||
|
||||
if metadata_parts:
|
||||
description += f" [{', '.join(metadata_parts)}]"
|
||||
|
||||
|
||||
return description
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"格式化记忆失败: {e}", exc_info=True)
|
||||
return f"(记忆格式化错误: {str(e)[:50]})"
|
||||
|
||||
|
||||
def format_memories_for_prompt(
|
||||
memories: List[Memory],
|
||||
max_count: Optional[int] = None,
|
||||
memories: list[Memory],
|
||||
max_count: int | None = None,
|
||||
include_metadata: bool = False,
|
||||
group_by_type: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
批量格式化多条记忆为提示词文本
|
||||
|
||||
|
||||
Args:
|
||||
memories: 记忆列表
|
||||
max_count: 最大记忆数量(可选)
|
||||
include_metadata: 是否包含元数据
|
||||
group_by_type: 是否按类型分组
|
||||
|
||||
|
||||
Returns:
|
||||
格式化后的文本,包含标题和列表
|
||||
"""
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
|
||||
# 限制数量
|
||||
if max_count:
|
||||
memories = memories[:max_count]
|
||||
|
||||
|
||||
# 按类型分组
|
||||
if group_by_type:
|
||||
type_groups: Dict[MemoryType, List[Memory]] = {}
|
||||
type_groups: dict[MemoryType, list[Memory]] = {}
|
||||
for memory in memories:
|
||||
if memory.memory_type not in type_groups:
|
||||
type_groups[memory.memory_type] = []
|
||||
type_groups[memory.memory_type].append(memory)
|
||||
|
||||
|
||||
# 构建分组文本
|
||||
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
||||
|
||||
|
||||
type_order = [MemoryType.FACT, MemoryType.EVENT, MemoryType.RELATION, MemoryType.OPINION]
|
||||
for mem_type in type_order:
|
||||
if mem_type in type_groups:
|
||||
@@ -186,33 +183,33 @@ def format_memories_for_prompt(
|
||||
desc = format_memory_for_prompt(memory, include_metadata)
|
||||
parts.append(f"- {desc}")
|
||||
parts.append("")
|
||||
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
else:
|
||||
# 不分组,直接列出
|
||||
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
||||
|
||||
|
||||
for memory in memories:
|
||||
# 获取类型标签
|
||||
type_label = memory.memory_type.value if memory.memory_type else "未知"
|
||||
|
||||
|
||||
# 格式化记忆内容
|
||||
desc = format_memory_for_prompt(memory, include_metadata)
|
||||
|
||||
|
||||
# 添加类型标签
|
||||
parts.append(f"- **[{type_label}]** {desc}")
|
||||
|
||||
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def get_memory_type_label(memory_type: str) -> str:
|
||||
"""
|
||||
获取记忆类型的中文标签
|
||||
|
||||
|
||||
Args:
|
||||
memory_type: 记忆类型(可能是英文或中文)
|
||||
|
||||
|
||||
Returns:
|
||||
中文标签
|
||||
"""
|
||||
@@ -243,27 +240,27 @@ def get_memory_type_label(memory_type: str) -> str:
|
||||
"经历": "经历",
|
||||
"情境": "情境",
|
||||
}
|
||||
|
||||
|
||||
# 转换为小写进行匹配
|
||||
memory_type_lower = memory_type.lower() if memory_type else ""
|
||||
|
||||
|
||||
return type_mapping.get(memory_type_lower, "未知")
|
||||
|
||||
|
||||
def _format_relative_time(timestamp: datetime) -> Optional[str]:
|
||||
def _format_relative_time(timestamp: datetime) -> str | None:
|
||||
"""
|
||||
格式化相对时间(如"2天前"、"刚才")
|
||||
|
||||
|
||||
Args:
|
||||
timestamp: 时间戳
|
||||
|
||||
|
||||
Returns:
|
||||
相对时间描述,如果太久远则返回None
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
delta = now - timestamp
|
||||
|
||||
|
||||
if delta.total_seconds() < 60:
|
||||
return "刚才"
|
||||
elif delta.total_seconds() < 3600:
|
||||
@@ -290,17 +287,17 @@ def _format_relative_time(timestamp: datetime) -> Optional[str]:
|
||||
def format_memory_summary(memory: Memory) -> str:
|
||||
"""
|
||||
生成记忆的简短摘要(用于日志和调试)
|
||||
|
||||
|
||||
Args:
|
||||
memory: 记忆对象
|
||||
|
||||
|
||||
Returns:
|
||||
简短摘要
|
||||
"""
|
||||
try:
|
||||
subject_node = memory.get_subject_node()
|
||||
subject_text = subject_node.content if subject_node else "?"
|
||||
|
||||
|
||||
topic_text = "?"
|
||||
for edge in memory.edges:
|
||||
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
|
||||
@@ -308,7 +305,7 @@ def format_memory_summary(memory: Memory) -> str:
|
||||
if topic_node:
|
||||
topic_text = topic_node.content
|
||||
break
|
||||
|
||||
|
||||
return f"{subject_text} - {memory.memory_type.value if memory.memory_type else '?'}: {topic_text}"
|
||||
except Exception:
|
||||
return f"记忆 {memory.id[:8]}"
|
||||
@@ -316,8 +313,8 @@ def format_memory_summary(memory: Memory) -> str:
|
||||
|
||||
# 导出主要函数
|
||||
__all__ = [
|
||||
'format_memory_for_prompt',
|
||||
'format_memories_for_prompt',
|
||||
'get_memory_type_label',
|
||||
'format_memory_summary',
|
||||
"format_memories_for_prompt",
|
||||
"format_memory_for_prompt",
|
||||
"format_memory_summary",
|
||||
"get_memory_type_label",
|
||||
]
|
||||
|
||||
@@ -14,7 +14,6 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -24,26 +23,26 @@ logger = get_logger(__name__)
|
||||
class TimeParser:
|
||||
"""
|
||||
时间解析器
|
||||
|
||||
|
||||
负责将自然语言时间表达转换为标准化的绝对时间
|
||||
"""
|
||||
|
||||
def __init__(self, reference_time: Optional[datetime] = None):
|
||||
def __init__(self, reference_time: datetime | None = None):
|
||||
"""
|
||||
初始化时间解析器
|
||||
|
||||
|
||||
Args:
|
||||
reference_time: 参考时间(通常是当前时间)
|
||||
"""
|
||||
self.reference_time = reference_time or datetime.now()
|
||||
|
||||
def parse(self, time_str: str) -> Optional[datetime]:
|
||||
def parse(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析时间字符串
|
||||
|
||||
|
||||
Args:
|
||||
time_str: 时间字符串
|
||||
|
||||
|
||||
Returns:
|
||||
标准化的datetime对象,如果解析失败则返回None
|
||||
"""
|
||||
@@ -81,7 +80,7 @@ class TimeParser:
|
||||
logger.warning(f"无法解析时间: '{time_str}',使用当前时间")
|
||||
return self.reference_time
|
||||
|
||||
def _parse_relative_day(self, time_str: str) -> Optional[datetime]:
|
||||
def _parse_relative_day(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析相对日期:今天、明天、昨天、前天、后天
|
||||
"""
|
||||
@@ -108,7 +107,7 @@ class TimeParser:
|
||||
|
||||
return None
|
||||
|
||||
def _parse_days_ago(self, time_str: str) -> Optional[datetime]:
|
||||
def _parse_days_ago(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析 X天前/X天后、X周前/X周后、X个月前/X个月后
|
||||
"""
|
||||
@@ -172,7 +171,7 @@ class TimeParser:
|
||||
|
||||
return None
|
||||
|
||||
def _parse_hours_ago(self, time_str: str) -> Optional[datetime]:
|
||||
def _parse_hours_ago(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析 X小时前/X小时后、X分钟前/X分钟后
|
||||
"""
|
||||
@@ -204,7 +203,7 @@ class TimeParser:
|
||||
|
||||
return None
|
||||
|
||||
def _parse_week_month_year(self, time_str: str) -> Optional[datetime]:
|
||||
def _parse_week_month_year(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析:上周、上个月、去年、本周、本月、今年
|
||||
"""
|
||||
@@ -232,7 +231,7 @@ class TimeParser:
|
||||
|
||||
return None
|
||||
|
||||
def _parse_specific_date(self, time_str: str) -> Optional[datetime]:
|
||||
def _parse_specific_date(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析具体日期:
|
||||
- 2025-11-05
|
||||
@@ -266,7 +265,7 @@ class TimeParser:
|
||||
|
||||
return None
|
||||
|
||||
def _parse_time_of_day(self, time_str: str) -> Optional[datetime]:
|
||||
def _parse_time_of_day(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析一天中的时间:
|
||||
- 早上、上午、中午、下午、晚上、深夜
|
||||
@@ -290,7 +289,7 @@ class TimeParser:
|
||||
}
|
||||
|
||||
# 先检查是否有具体时间点:早上8点、下午3点
|
||||
for period, default_hour in time_periods.items():
|
||||
for period in time_periods.keys():
|
||||
pattern = rf"{period}(\d{{1,2}})点?"
|
||||
match = re.search(pattern, time_str)
|
||||
if match:
|
||||
@@ -314,13 +313,13 @@ class TimeParser:
|
||||
|
||||
return None
|
||||
|
||||
def _parse_combined_time(self, time_str: str) -> Optional[datetime]:
|
||||
def _parse_combined_time(self, time_str: str) -> datetime | None:
|
||||
"""
|
||||
解析组合时间表达:今天下午、昨天晚上、明天早上
|
||||
"""
|
||||
# 先解析日期部分
|
||||
date_result = None
|
||||
|
||||
|
||||
# 相对日期关键词
|
||||
relative_days = {
|
||||
"今天": 0, "今日": 0,
|
||||
@@ -330,16 +329,16 @@ class TimeParser:
|
||||
"后天": 2, "后日": 2,
|
||||
"大前天": -3, "大后天": 3,
|
||||
}
|
||||
|
||||
|
||||
for keyword, days in relative_days.items():
|
||||
if keyword in time_str:
|
||||
date_result = self.reference_time + timedelta(days=days)
|
||||
date_result = date_result.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
break
|
||||
|
||||
|
||||
if not date_result:
|
||||
return None
|
||||
|
||||
|
||||
# 再解析时间段部分
|
||||
time_periods = {
|
||||
"早上": 8, "早晨": 8,
|
||||
@@ -351,7 +350,7 @@ class TimeParser:
|
||||
"深夜": 23,
|
||||
"凌晨": 2,
|
||||
}
|
||||
|
||||
|
||||
for period, hour in time_periods.items():
|
||||
if period in time_str:
|
||||
# 检查是否有具体时间点
|
||||
@@ -363,17 +362,17 @@ class TimeParser:
|
||||
if period in ["下午", "晚上"] and hour < 12:
|
||||
hour += 12
|
||||
return date_result.replace(hour=hour)
|
||||
|
||||
|
||||
# 如果没有时间段,返回日期(默认0点)
|
||||
return date_result
|
||||
|
||||
def _chinese_num_to_int(self, num_str: str) -> int:
|
||||
"""
|
||||
将中文数字转换为阿拉伯数字
|
||||
|
||||
|
||||
Args:
|
||||
num_str: 中文数字字符串(如:"一"、"十"、"3")
|
||||
|
||||
|
||||
Returns:
|
||||
整数
|
||||
"""
|
||||
@@ -418,11 +417,11 @@ class TimeParser:
|
||||
def format_time(self, dt: datetime, format_type: str = "iso") -> str:
|
||||
"""
|
||||
格式化时间
|
||||
|
||||
|
||||
Args:
|
||||
dt: datetime对象
|
||||
format_type: 格式类型 ("iso", "cn", "relative")
|
||||
|
||||
|
||||
Returns:
|
||||
格式化的时间字符串
|
||||
"""
|
||||
@@ -461,13 +460,13 @@ class TimeParser:
|
||||
|
||||
return str(dt)
|
||||
|
||||
def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]:
|
||||
def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]:
|
||||
"""
|
||||
解析时间范围:最近一周、最近3天
|
||||
|
||||
|
||||
Args:
|
||||
time_str: 时间范围字符串
|
||||
|
||||
|
||||
Returns:
|
||||
(start_time, end_time)
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user