fix: Fix code quality issues - correct exception handling and import statements

Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
copilot-swe-agent[bot]
2025-11-07 04:39:35 +00:00
parent 3bdcfa3dd4
commit 5caf630623
20 changed files with 893 additions and 910 deletions

View File

@@ -6,10 +6,12 @@ from typing import ClassVar
 from src.common.logger import get_logger
 from src.plugin_system import BasePlugin, register_plugin
 from src.plugin_system.base.component_types import ComponentInfo, ToolInfo

 logger = get_logger("memory_graph_plugin")

+# Holds references to background tasks so they are not garbage-collected
+_background_tasks = set()

 @register_plugin
 class MemoryGraphPlugin(BasePlugin):
@@ -60,6 +62,7 @@ class MemoryGraphPlugin(BasePlugin):
         """Callback invoked when the plugin is unloaded"""
         try:
             import asyncio
+            from src.memory_graph.manager_singleton import shutdown_memory_manager

             logger.info(f"{self.log_prefix} Shutting down the memory system...")
@@ -68,7 +71,10 @@ class MemoryGraphPlugin(BasePlugin):
             loop = asyncio.get_event_loop()
             if loop.is_running():
                 # If the loop is running, schedule a task
-                asyncio.create_task(shutdown_memory_manager())
+                task = asyncio.create_task(shutdown_memory_manager())
+                # Keep a reference so the task is not garbage-collected
+                _background_tasks.add(task)
+                task.add_done_callback(_background_tasks.discard)
             else:
                 # If the loop is not running, run the coroutine directly
                 loop.run_until_complete(shutdown_memory_manager())
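The fix above follows the pattern the asyncio documentation recommends for fire-and-forget tasks: the event loop keeps only a weak reference to a task created with `create_task`, so an otherwise-unreferenced task can be garbage-collected before it finishes. A minimal standalone sketch of the pattern (the `shutdown()` coroutine is a hypothetical stand-in for `shutdown_memory_manager`):

import asyncio

_background_tasks: set[asyncio.Task] = set()

async def shutdown() -> None:
    # Hypothetical cleanup work standing in for shutdown_memory_manager()
    await asyncio.sleep(0)

def schedule_shutdown() -> None:
    loop = asyncio.get_event_loop()
    if loop.is_running():
        task = loop.create_task(shutdown())
        _background_tasks.add(task)  # strong reference keeps the task alive
        task.add_done_callback(_background_tasks.discard)  # release it when done
    else:
        loop.run_until_complete(shutdown())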

View File

@@ -10,13 +10,13 @@
 Usage:
     # Preview mode (does not actually delete)
     python scripts/deduplicate_memories.py --dry-run

     # Run deduplication
     python scripts/deduplicate_memories.py

     # Set a similarity threshold
     python scripts/deduplicate_memories.py --threshold 0.9

     # Specify the data directory
     python scripts/deduplicate_memories.py --data-dir data/memory_graph
 """
@@ -25,27 +25,26 @@ import asyncio
 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional, Set, Tuple

 import numpy as np

 sys.path.insert(0, str(Path(__file__).parent.parent))

 from src.common.logger import get_logger
-from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager, shutdown_memory_manager
+from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager

 logger = get_logger(__name__)


 class MemoryDeduplicator:
     """Memory deduplicator"""

     def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85):
         self.data_dir = data_dir
         self.dry_run = dry_run
         self.threshold = threshold
         self.manager = None

         # Statistics
         self.stats = {
             "total_memories": 0,
@@ -54,34 +53,34 @@ class MemoryDeduplicator:
             "duplicates_removed": 0,
             "errors": 0,
         }

     async def initialize(self):
         """Initialize the memory manager"""
         logger.info(f"Initializing memory manager (data_dir={self.data_dir})...")
         self.manager = await initialize_memory_manager(data_dir=self.data_dir)

         if not self.manager:
             raise RuntimeError("Failed to initialize the memory manager")

         self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
         logger.info(f"✅ Memory manager initialized, {self.stats['total_memories']} memories in total")

-    async def find_similar_pairs(self) -> List[Tuple[str, str, float]]:
+    async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
         """
         Find all similar memory pairs (via vector similarity)

         Returns:
             [(memory_id_1, memory_id_2, similarity), ...]
         """
         logger.info("Scanning for similar memory pairs...")
         similar_pairs = []
         seen_pairs = set()  # avoid duplicates

         # Fetch all memories
         all_memories = self.manager.graph_store.get_all_memories()
         total_memories = len(all_memories)
         logger.info(f"Computing pairwise similarity for {total_memories} memories...")

         # Compare memories pairwise
         for i, memory_i in enumerate(all_memories):
             # Yield control every 10 memories processed
@@ -89,115 +88,115 @@ class MemoryDeduplicator:
             await asyncio.sleep(0)
             if i > 0:
                 logger.info(f"Progress: {i}/{total_memories} ({i*100//total_memories}%)")

             # Get memory i's vector (from its topic node)
             vector_i = None
             for node in memory_i.nodes:
                 if node.embedding is not None:
                     vector_i = node.embedding
                     break

             if vector_i is None:
                 continue

             # Compare with subsequent memories
             for j in range(i + 1, total_memories):
                 memory_j = all_memories[j]

                 # Get memory j's vector
                 vector_j = None
                 for node in memory_j.nodes:
                     if node.embedding is not None:
                         vector_j = node.embedding
                         break

                 if vector_j is None:
                     continue

                 # Compute cosine similarity
                 similarity = self._cosine_similarity(vector_i, vector_j)

                 # Keep only pairs above the threshold
                 if similarity >= self.threshold:
                     pair_key = tuple(sorted([memory_i.id, memory_j.id]))
                     if pair_key not in seen_pairs:
                         seen_pairs.add(pair_key)
                         similar_pairs.append((memory_i.id, memory_j.id, similarity))

         self.stats["similar_pairs"] = len(similar_pairs)
         logger.info(f"Found {len(similar_pairs)} similar memory pairs (threshold >= {self.threshold})")
         return similar_pairs

     def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
         """Compute cosine similarity"""
         try:
             vec1_norm = np.linalg.norm(vec1)
             vec2_norm = np.linalg.norm(vec2)
             if vec1_norm == 0 or vec2_norm == 0:
                 return 0.0
             similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
             return float(similarity)
         except Exception as e:
             logger.error(f"Failed to compute cosine similarity: {e}")
             return 0.0

-    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> Tuple[Optional[str], Optional[str]]:
+    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
         """
         Decide which memory to keep and which to remove

         Priority:
         1. Higher importance
         2. Higher activation
         3. Earlier creation time

         Returns:
             (keep_id, remove_id)
         """
         mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1)
         mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2)

         if not mem1 or not mem2:
             logger.warning(f"Memory not found: {mem_id_1} or {mem_id_2}")
             return None, None

         # Compare importance
         if mem1.importance > mem2.importance:
             return mem_id_1, mem_id_2
         elif mem1.importance < mem2.importance:
             return mem_id_2, mem_id_1

         # Same importance, compare activation
         if mem1.activation > mem2.activation:
             return mem_id_1, mem_id_2
         elif mem1.activation < mem2.activation:
             return mem_id_2, mem_id_1

         # Same activation too, keep the earlier one
         if mem1.created_at < mem2.created_at:
             return mem_id_1, mem_id_2
         else:
             return mem_id_2, mem_id_1

     async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool:
         """
         Deduplicate one pair of similar memories

         Returns:
             Whether deduplication succeeded
         """
         keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2)
         if not keep_id or not remove_id:
             self.stats["errors"] += 1
             return False

         keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
         remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)

-        logger.info(f"")
+        logger.info("")
         logger.info(f"{'[preview]' if self.dry_run else '[execute]'} Deduplicating similar pair (similarity={similarity:.3f}):")
         logger.info(f"  Keep: {keep_id}")
         logger.info(f"    - Topic: {keep_mem.metadata.get('topic', 'N/A')}")
@@ -209,41 +208,41 @@ class MemoryDeduplicator:
         logger.info(f"    - Importance: {remove_mem.importance:.2f}")
         logger.info(f"    - Activation: {remove_mem.activation:.2f}")
         logger.info(f"    - Created at: {remove_mem.created_at}")

         if self.dry_run:
             logger.info("  [preview mode] nothing is actually deleted")
             self.stats["duplicates_found"] += 1
             return True

         try:
             # Strengthen the kept memory's attributes
             keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
             keep_mem.activation = min(1.0, keep_mem.activation + 0.05)

             # Accumulate access counts
-            if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
+            if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
                 keep_mem.access_count += remove_mem.access_count

             # Delete the duplicate memory
             await self.manager.delete_memory(remove_id)
             self.stats["duplicates_removed"] += 1
-            logger.info(f"  ✅ Deleted successfully")
+            logger.info("  ✅ Deleted successfully")

             # Yield control
             await asyncio.sleep(0)

             return True
         except Exception as e:
             logger.error(f"  ❌ Deletion failed: {e}", exc_info=True)
             self.stats["errors"] += 1
             return False

     async def run(self):
         """Run the deduplication"""
         start_time = datetime.now()

         print("="*70)
         print("Memory deduplication tool")
         print("="*70)
@@ -252,13 +251,13 @@ class MemoryDeduplicator:
         print(f"Mode: {'preview (no deletion)' if self.dry_run else 'execute (memories will be deleted)'}")
         print("="*70)
         print()

         # Initialize
         await self.initialize()

         # Find similar pairs
         similar_pairs = await self.find_similar_pairs()

         if not similar_pairs:
             logger.info("No similar memory pairs found to deduplicate")
             print()
@@ -266,19 +265,19 @@ class MemoryDeduplicator:
             print("No memories need deduplication")
             print("="*70)
             return

         # Deduplicate
         logger.info(f"Starting deduplication ({'preview' if self.dry_run else 'execute'})...")
         print()

         processed_pairs = set()  # avoid handling a pair twice

         for mem_id_1, mem_id_2, similarity in similar_pairs:
             # Skip pairs already handled (one memory may have been deleted)
             pair_key = tuple(sorted([mem_id_1, mem_id_2]))
             if pair_key in processed_pairs:
                 continue

             # Check that the memories still exist
             if not self.manager.graph_store.get_memory_by_id(mem_id_1):
                 logger.debug(f"Memory {mem_id_1} no longer exists, skipping")
@@ -286,22 +285,22 @@ class MemoryDeduplicator:
             if not self.manager.graph_store.get_memory_by_id(mem_id_2):
                 logger.debug(f"Memory {mem_id_2} no longer exists, skipping")
                 continue

             # Deduplicate the pair
             success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity)
             if success:
                 processed_pairs.add(pair_key)

         # Save data (unless this is a dry run)
         if not self.dry_run:
             logger.info("Saving data...")
             await self.manager.persistence.save_graph_store(self.manager.graph_store)
             logger.info("✅ Data saved")

         # Summary report
         elapsed = (datetime.now() - start_time).total_seconds()

         print()
         print("="*70)
         print("Deduplication report")
@@ -312,7 +311,7 @@ class MemoryDeduplicator:
         print(f"{'Passed preview' if self.dry_run else 'Deleted'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
         print(f"Errors: {self.stats['errors']}")
         print(f"Elapsed: {elapsed:.2f}s")

         if self.dry_run:
             print()
             print("⚠️ This was a preview run; no memories were actually deleted")
@@ -322,9 +321,9 @@ class MemoryDeduplicator:
             print("✅ Deduplication complete!")
             final_count = len(self.manager.graph_store.get_all_memories())
             print(f"📊 Final memory count: {final_count} (reduced by {self.stats['total_memories'] - final_count})")
         print("="*70)

     async def cleanup(self):
         """Release resources"""
         if self.manager:
@@ -340,50 +339,50 @@ async def main():
        Examples:
            # Preview mode (recommended to run first)
            python scripts/deduplicate_memories.py --dry-run

            # Run deduplication
            python scripts/deduplicate_memories.py

            # Set the similarity threshold (only pairs with similarity >= 0.9)
            python scripts/deduplicate_memories.py --threshold 0.9

            # Specify the data directory
            python scripts/deduplicate_memories.py --data-dir data/memory_graph

            # Combine options
            python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test
        """
     )

     parser.add_argument(
         "--dry-run",
         action="store_true",
         help="Preview mode; do not actually delete memories (recommended to run first)"
     )

     parser.add_argument(
         "--threshold",
         type=float,
         default=0.85,
         help="Similarity threshold; only pairs with similarity >= this value are processed (default: 0.85)"
     )

     parser.add_argument(
         "--data-dir",
         type=str,
         default="data/memory_graph",
         help="Memory data directory (default: data/memory_graph)"
     )

     args = parser.parse_args()

     # Create the deduplicator
     deduplicator = MemoryDeduplicator(
         data_dir=args.data_dir,
         dry_run=args.dry_run,
         threshold=args.threshold
     )

     try:
         # Run deduplication
         await deduplicator.run()
@@ -396,7 +395,7 @@ async def main():
     finally:
         # Release resources
         await deduplicator.cleanup()

     return 0
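As a quick sanity check of the cosine-similarity arithmetic this script relies on, the same computation on toy vectors (a standalone sketch; `cosine_similarity` mirrors the body of `MemoryDeduplicator._cosine_similarity`):

import numpy as np

def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    n1, n2 = np.linalg.norm(vec1), np.linalg.norm(vec2)
    if n1 == 0 or n2 == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / (n1 * n2))

a = np.array([1.0, 0.0, 1.0])
b = np.array([1.0, 0.1, 0.9])
print(cosine_similarity(a, a))  # 1.0 for identical vectors
print(cosine_similarity(a, b))  # ~0.996, well above the default 0.85 threshold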

View File

@@ -6,24 +6,24 @@
 from src.memory_graph.manager import MemoryManager
 from src.memory_graph.models import (
+    EdgeType,
     Memory,
     MemoryEdge,
     MemoryNode,
     MemoryStatus,
     MemoryType,
     NodeType,
-    EdgeType,
 )

 __all__ = [
-    "MemoryManager",
+    "EdgeType",
     "Memory",
-    "MemoryNode",
     "MemoryEdge",
+    "MemoryManager",
+    "MemoryNode",
+    "MemoryStatus",
     "MemoryType",
     "NodeType",
-    "EdgeType",
-    "MemoryStatus",
 ]

 __version__ = "0.1.0"

View File

@@ -6,4 +6,4 @@ from src.memory_graph.core.builder import MemoryBuilder
 from src.memory_graph.core.extractor import MemoryExtractor
 from src.memory_graph.core.node_merger import NodeMerger

-__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"]
+__all__ = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"]

View File

@@ -5,7 +5,7 @@
 from __future__ import annotations

 from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any

 import numpy as np
@@ -16,7 +16,6 @@ from src.memory_graph.models import (
     MemoryEdge,
     MemoryNode,
     MemoryStatus,
-    MemoryType,
     NodeType,
 )
 from src.memory_graph.storage.graph_store import GraphStore
@@ -28,7 +27,7 @@ logger = get_logger(__name__)
 class MemoryBuilder:
     """
     Memory builder

     Responsible for:
     1. Automatically constructing a memory subgraph from extracted elements
     2. Creating the complete node and edge structure
@@ -41,11 +40,11 @@ class MemoryBuilder:
         self,
         vector_store: VectorStore,
         graph_store: GraphStore,
-        embedding_generator: Optional[Any] = None,
+        embedding_generator: Any | None = None,
     ):
         """
         Initialize the memory builder

         Args:
             vector_store: vector store
             graph_store: graph store
@@ -55,13 +54,13 @@ class MemoryBuilder:
         self.graph_store = graph_store
         self.embedding_generator = embedding_generator

-    async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory:
+    async def build_memory(self, extracted_params: dict[str, Any]) -> Memory:
         """
         Build a complete memory object

         Args:
             extracted_params: normalized parameters returned by the extractor

         Returns:
             Memory object (status STAGED)
         """
@@ -97,7 +96,7 @@ class MemoryBuilder:
             edges.append(memory_type_edge)

         # 4. If there is an object, create the object node and connect it
-        if "object" in extracted_params and extracted_params["object"]:
+        if extracted_params.get("object"):
             object_node = await self._create_object_node(
                 content=extracted_params["object"], memory_id=memory_id
             )
@@ -158,14 +157,14 @@ class MemoryBuilder:
     ) -> MemoryNode:
         """
         Create a new node, or reuse an existing similar one

         For subjects (SUBJECT) and attributes (ATTRIBUTE), check whether a node with identical content already exists

         Args:
             content: node content
             node_type: node type
             memory_id: owning memory ID

         Returns:
             MemoryNode object
         """
@@ -190,11 +189,11 @@ class MemoryBuilder:
     async def _create_topic_node(self, content: str, memory_id: str) -> MemoryNode:
         """
         Create a topic node (requires generating an embedding)

         Args:
             content: node content
             memory_id: owning memory ID

         Returns:
             MemoryNode object
         """
@@ -225,11 +224,11 @@ class MemoryBuilder:
     async def _create_object_node(self, content: str, memory_id: str) -> MemoryNode:
         """
         Create an object node (requires generating an embedding)

         Args:
             content: node content
             memory_id: owning memory ID

         Returns:
             MemoryNode object
         """
@@ -258,22 +257,22 @@ class MemoryBuilder:
     async def _process_attributes(
         self,
-        attributes: Dict[str, Any],
+        attributes: dict[str, Any],
         parent_id: str,
         memory_id: str,
         importance: float,
-    ) -> tuple[List[MemoryNode], List[MemoryEdge]]:
+    ) -> tuple[list[MemoryNode], list[MemoryEdge]]:
         """
         Process attributes and build the attribute subgraph

         Structure: TOPIC -> ATTRIBUTE -> VALUE

         Args:
             attributes: attribute dict
             parent_id: parent node ID (usually the TOPIC)
             memory_id: owning memory ID
             importance: importance

         Returns:
             (attribute nodes, attribute edges)
         """
@@ -322,10 +321,10 @@ class MemoryBuilder:
     async def _generate_embedding(self, text: str) -> np.ndarray:
         """
         Generate an embedding vector for a piece of text

         Args:
             text: text content

         Returns:
             Embedding vector
         """
@@ -341,14 +340,14 @@ class MemoryBuilder:
     async def _find_existing_node(
         self, content: str, node_type: NodeType
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         Find an existing exact-match node (used for subjects and attributes)

         Args:
             content: node content
             node_type: node type

         Returns:
             The existing node, or None if there is none
         """
@@ -369,14 +368,14 @@ class MemoryBuilder:
     async def _find_similar_topic(
         self, content: str, embedding: np.ndarray
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         Find a similar topic node (by semantic similarity)

         Args:
             content: content
             embedding: embedding vector

         Returns:
             The similar node, or None if there is none
         """
@@ -414,14 +413,14 @@ class MemoryBuilder:
     async def _find_similar_object(
         self, content: str, embedding: np.ndarray
-    ) -> Optional[MemoryNode]:
+    ) -> MemoryNode | None:
         """
         Find a similar object node (by semantic similarity)

         Args:
             content: content
             embedding: embedding vector

         Returns:
             The similar node, or None if there is none
         """
@@ -480,13 +479,13 @@ class MemoryBuilder:
     ) -> MemoryEdge:
         """
         Link two memories (create a causal or reference edge)

         Args:
             source_memory: source memory
             target_memory: target memory
             relation_type: relation type (e.g. "导致", "引用")
             importance: importance

         Returns:
             The created edge
         """
@@ -525,7 +524,7 @@ class MemoryBuilder:
             logger.error(f"Failed to link memories: {e}", exc_info=True)
             raise RuntimeError(f"Failed to link memories: {e}")

-    def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]:
+    def _find_topic_node(self, memory: Memory) -> MemoryNode | None:
         """Find the topic node within a memory"""
         for node in memory.nodes:
             if node.node_type == NodeType.TOPIC:
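Most of the mechanical churn in this file is the move from `typing` generics to PEP 585 built-in generics and PEP 604 union syntax; with `from __future__ import annotations` at the top of each module, the new annotations are safe even on interpreters older than 3.10. An illustrative before/after sketch (hypothetical function, not from the codebase):

from __future__ import annotations

# Before:
#   from typing import Dict, List, Optional
#   def lookup(index: Dict[str, List[int]], key: str) -> Optional[List[int]]: ...

# After (PEP 585 built-in generics, PEP 604 unions):
def lookup(index: dict[str, list[int]], key: str) -> list[int] | None:
    """Return the posting list for key, or None when absent."""
    return index.get(key)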

View File

@@ -5,7 +5,7 @@
 from __future__ import annotations

 from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any

 from src.common.logger import get_logger
 from src.memory_graph.models import MemoryType
@@ -17,7 +17,7 @@ logger = get_logger(__name__)
 class MemoryExtractor:
     """
     Memory extractor

     Responsible for:
     1. Extracting memory elements from tool-call parameters
     2. Validating parameter completeness and validity
@@ -25,19 +25,19 @@ class MemoryExtractor:
     4. Cleaning and formatting data
     """

-    def __init__(self, time_parser: Optional[TimeParser] = None):
+    def __init__(self, time_parser: TimeParser | None = None):
         """
         Initialize the memory extractor

         Args:
             time_parser: time parser (optional)
         """
         self.time_parser = time_parser or TimeParser()

-    def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+    def extract_from_tool_params(self, params: dict[str, Any]) -> dict[str, Any]:
         """
         Extract memory elements from tool parameters

         Args:
             params: tool-call parameters, e.g.:
                 {
@@ -48,7 +48,7 @@ class MemoryExtractor:
                     "attributes": {"时间": "今天", "地点": "家里"},
                     "importance": 0.3
                 }

         Returns:
             Extracted and normalized parameter dict
         """
@@ -64,11 +64,11 @@ class MemoryExtractor:
             }

             # 3. Extract the optional object
-            if "object" in params and params["object"]:
+            if params.get("object"):
                 extracted["object"] = self._clean_text(params["object"])

             # 4. Extract and normalize attributes
-            if "attributes" in params and params["attributes"]:
+            if params.get("attributes"):
                 extracted["attributes"] = self._process_attributes(params["attributes"])
             else:
                 extracted["attributes"] = {}
@@ -86,13 +86,13 @@ class MemoryExtractor:
             logger.error(f"Memory extraction failed: {e}", exc_info=True)
             raise ValueError(f"Memory extraction failed: {e}")

-    def _validate_required_params(self, params: Dict[str, Any]) -> None:
+    def _validate_required_params(self, params: dict[str, Any]) -> None:
         """
         Validate required parameters

         Args:
             params: parameter dict

         Raises:
             ValueError: if a required parameter is missing
         """
@@ -105,10 +105,10 @@ class MemoryExtractor:
     def _clean_text(self, text: Any) -> str:
         """
         Clean text

         Args:
             text: input text

         Returns:
             Cleaned text
         """
@@ -128,13 +128,13 @@ class MemoryExtractor:
     def _parse_memory_type(self, type_str: str) -> MemoryType:
         """
         Parse the memory type

         Args:
             type_str: type string

         Returns:
             MemoryType enum

         Raises:
             ValueError: if the type is invalid
         """
@@ -166,10 +166,10 @@ class MemoryExtractor:
     def _parse_importance(self, importance: Any) -> float:
         """
         Parse the importance value

         Args:
             importance: importance value (number, string, etc.)

         Returns:
             A float between 0 and 1
         """
@@ -181,13 +181,13 @@ class MemoryExtractor:
             logger.warning(f"Invalid importance value: {importance}, using default 0.5")
             return 0.5

-    def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]:
+    def _process_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]:
         """
         Process the attribute dict

         Args:
             attributes: raw attribute dict

         Returns:
             Processed attribute dict
         """
@@ -222,10 +222,10 @@ class MemoryExtractor:
         return processed

-    def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+    def extract_link_params(self, params: dict[str, Any]) -> dict[str, Any]:
         """
         Extract memory-link parameters (for the link_memories tool)

         Args:
             params: tool parameters, e.g.:
                 {
@@ -234,7 +234,7 @@ class MemoryExtractor:
                     "relation_type": "导致",
                     "importance": 0.6
                 }

         Returns:
             Extracted parameters
         """
@@ -266,10 +266,10 @@ class MemoryExtractor:
     def validate_relation_type(self, relation_type: str) -> str:
         """
         Validate the relation type

         Args:
             relation_type: relation type string

         Returns:
             Normalized relation type
         """

View File

@@ -4,11 +4,6 @@
 from __future__ import annotations

-from dataclasses import dataclass
-from typing import List, Optional, Tuple
-
-import numpy as np
-
 from src.common.logger import get_logger
 from src.config.official_configs import MemoryConfig
 from src.memory_graph.models import MemoryNode, NodeType
@@ -21,7 +16,7 @@ logger = get_logger(__name__)
 class NodeMerger:
     """
     Node merger

     Responsible for:
     1. Finding duplicate nodes by semantic similarity
     2. Verifying context matches
@@ -36,7 +31,7 @@ class NodeMerger:
     ):
         """
         Initialize the node merger

         Args:
             vector_store: vector store
             graph_store: graph store
@@ -54,17 +49,17 @@ class NodeMerger:
     async def find_similar_nodes(
         self,
         node: MemoryNode,
-        threshold: Optional[float] = None,
+        threshold: float | None = None,
         limit: int = 5,
-    ) -> List[Tuple[MemoryNode, float]]:
+    ) -> list[tuple[MemoryNode, float]]:
         """
         Find nodes similar to the given node

         Args:
             node: query node
             threshold: similarity threshold (optional, defaults to the configured value)
             limit: number of results to return

         Returns:
             List of (similar_node, similarity)
         """
@@ -112,12 +107,12 @@ class NodeMerger:
     ) -> bool:
         """
         Decide whether two nodes should be merged

         Args:
             source_node: source node
             target_node: target node
             similarity: semantic similarity

         Returns:
             Whether they should be merged
         """
@@ -157,16 +152,16 @@ class NodeMerger:
     ) -> bool:
         """
         Check whether the contexts of two nodes match

         Context-match criteria:
         1. Same node type
         2. Overlapping neighbor nodes
         3. Similar neighbor content

         Args:
             source_node: source node
             target_node: target node

         Returns:
             Whether they match
         """
@@ -207,7 +202,7 @@ class NodeMerger:
         # Contexts are considered matching when more than 30% of neighbors overlap
         return overlap_ratio > 0.3

-    def _get_node_content(self, node_id: str) -> Optional[str]:
+    def _get_node_content(self, node_id: str) -> str | None:
         """Get a node's content"""
         memories = self.graph_store.get_memories_by_node(node_id)
         if memories:
@@ -223,13 +218,13 @@ class NodeMerger:
     ) -> bool:
         """
         Merge two nodes

         Transfers all of source's edges to target, then deletes source

         Args:
             source: source node (will be deleted)
             target: target node (kept)

         Returns:
             Whether it succeeded
         """
@@ -255,7 +250,7 @@ class NodeMerger:
     def _update_memory_references(self, old_node_id: str, new_node_id: str) -> None:
         """
         Update node references inside memories

         Args:
             old_node_id: old node ID
             new_node_id: new node ID
@@ -280,16 +275,16 @@ class NodeMerger:
     async def batch_merge_similar_nodes(
         self,
-        nodes: List[MemoryNode],
-        progress_callback: Optional[callable] = None,
+        nodes: list[MemoryNode],
+        progress_callback: callable | None = None,
     ) -> dict:
         """
         Batch-merge similar nodes

         Args:
             nodes: nodes to process
             progress_callback: progress callback

         Returns:
             Statistics dict
         """
@@ -344,14 +339,14 @@ class NodeMerger:
         self,
         min_similarity: float = 0.85,
         limit: int = 100,
-    ) -> List[Tuple[str, str, float]]:
+    ) -> list[tuple[str, str, float]]:
         """
         Get candidate node pairs for merging

         Args:
             min_similarity: minimum similarity
             limit: maximum number of results

         Returns:
             List of (node_id_1, node_id_2, similarity)
         """

File diff suppressed because it is too large

View File

@@ -7,7 +7,6 @@
 from __future__ import annotations

 from pathlib import Path
-from typing import Optional

 from src.common.logger import get_logger
 from src.memory_graph.manager import MemoryManager
@@ -15,56 +14,56 @@ from src.memory_graph.manager import MemoryManager
 logger = get_logger(__name__)

 # Global MemoryManager instance
-_memory_manager: Optional[MemoryManager] = None
+_memory_manager: MemoryManager | None = None
 _initialized: bool = False


 async def initialize_memory_manager(
-    data_dir: Optional[Path | str] = None,
-) -> Optional[MemoryManager]:
+    data_dir: Path | str | None = None,
+) -> MemoryManager | None:
     """
     Initialize the global MemoryManager

     Reads its configuration directly from global_config.memory

     Args:
         data_dir: data directory (optional, defaults to the configured value)

     Returns:
         MemoryManager instance, or None if disabled
     """
     global _memory_manager, _initialized

     if _initialized and _memory_manager:
         logger.info("MemoryManager already initialized, returning the existing instance")
         return _memory_manager

     try:
         from src.config.config import global_config

         # Check whether the system is enabled
-        if not global_config.memory or not getattr(global_config.memory, 'enable', False):
+        if not global_config.memory or not getattr(global_config.memory, "enable", False):
             logger.info("The memory graph system is disabled in the configuration")
             _initialized = False
             _memory_manager = None
             return None

         # Resolve the data directory
         if data_dir is None:
-            data_dir = getattr(global_config.memory, 'data_dir', 'data/memory_graph')
+            data_dir = getattr(global_config.memory, "data_dir", "data/memory_graph")
         if isinstance(data_dir, str):
             data_dir = Path(data_dir)

         logger.info(f"Initializing global MemoryManager (data_dir={data_dir})...")
         _memory_manager = MemoryManager(data_dir=data_dir)
         await _memory_manager.initialize()

         _initialized = True
         logger.info("✅ Global MemoryManager initialized")
         return _memory_manager

     except Exception as e:
         logger.error(f"Failed to initialize MemoryManager: {e}", exc_info=True)
         _initialized = False
@@ -72,24 +71,24 @@ async def initialize_memory_manager(
         raise


-def get_memory_manager() -> Optional[MemoryManager]:
+def get_memory_manager() -> MemoryManager | None:
     """
     Get the global MemoryManager instance

     Returns:
         MemoryManager instance, or None if not yet initialized
     """
     if not _initialized or _memory_manager is None:
         logger.warning("MemoryManager is not initialized; call initialize_memory_manager() first")
         return None
     return _memory_manager


 async def shutdown_memory_manager():
     """Shut down the global MemoryManager"""
     global _memory_manager, _initialized

     if _memory_manager:
         try:
             logger.info("Shutting down the global MemoryManager...")

View File

@@ -10,7 +10,7 @@ import uuid
 from dataclasses import dataclass, field
 from datetime import datetime
 from enum import Enum
-from typing import Any, Dict, List, Optional
+from typing import Any

 import numpy as np
@@ -60,8 +60,8 @@ class MemoryNode:
     id: str  # unique node ID
     content: str  # node content (e.g. "我", "吃饭", "白米饭")
     node_type: NodeType  # node type
-    embedding: Optional[np.ndarray] = None  # semantic vector (only topics/objects need one)
-    metadata: Dict[str, Any] = field(default_factory=dict)  # extension metadata
+    embedding: np.ndarray | None = None  # semantic vector (only topics/objects need one)
+    metadata: dict[str, Any] = field(default_factory=dict)  # extension metadata
     created_at: datetime = field(default_factory=datetime.now)

     def __post_init__(self):
@@ -69,7 +69,7 @@ class MemoryNode:
         if not self.id:
             self.id = str(uuid.uuid4())

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Convert to a dict (for serialization)"""
         return {
             "id": self.id,
@@ -81,7 +81,7 @@ class MemoryNode:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> MemoryNode:
+    def from_dict(cls, data: dict[str, Any]) -> MemoryNode:
         """Create a node from a dict"""
         embedding = None
         if data.get("embedding") is not None:
@@ -114,7 +114,7 @@ class MemoryEdge:
     relation: str  # relation name (e.g. "是", "做", "时间", "因为")
     edge_type: EdgeType  # edge type
     importance: float = 0.5  # importance [0-1]
-    metadata: Dict[str, Any] = field(default_factory=dict)  # extension metadata
+    metadata: dict[str, Any] = field(default_factory=dict)  # extension metadata
     created_at: datetime = field(default_factory=datetime.now)

     def __post_init__(self):
@@ -124,7 +124,7 @@ class MemoryEdge:
         # Clamp importance to the valid range
         self.importance = max(0.0, min(1.0, self.importance))

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Convert to a dict (for serialization)"""
         return {
             "id": self.id,
@@ -138,7 +138,7 @@ class MemoryEdge:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge:
+    def from_dict(cls, data: dict[str, Any]) -> MemoryEdge:
         """Create an edge from a dict"""
         return cls(
             id=data["id"],
@@ -162,8 +162,8 @@ class Memory:
     id: str  # unique memory ID
     subject_id: str  # subject node ID
     memory_type: MemoryType  # memory type
-    nodes: List[MemoryNode]  # all nodes contained in this memory
-    edges: List[MemoryEdge]  # all edges contained in this memory
+    nodes: list[MemoryNode]  # all nodes contained in this memory
+    edges: list[MemoryEdge]  # all edges contained in this memory
     importance: float = 0.5  # overall importance [0-1]
     activation: float = 0.0  # activation [0-1], used for consolidation and forgetting
     status: MemoryStatus = MemoryStatus.STAGED  # memory status
@@ -171,7 +171,7 @@ class Memory:
     last_accessed: datetime = field(default_factory=datetime.now)  # last access time
     access_count: int = 0  # access count
     decay_factor: float = 1.0  # decay factor (changes over time)
-    metadata: Dict[str, Any] = field(default_factory=dict)  # extension metadata
+    metadata: dict[str, Any] = field(default_factory=dict)  # extension metadata

     def __post_init__(self):
         """Post-initialization handling"""
@@ -181,7 +181,7 @@ class Memory:
         self.importance = max(0.0, min(1.0, self.importance))
         self.activation = max(0.0, min(1.0, self.activation))

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Convert to a dict (for serialization)"""
         return {
             "id": self.id,
@@ -200,7 +200,7 @@ class Memory:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> Memory:
+    def from_dict(cls, data: dict[str, Any]) -> Memory:
         """Create a memory from a dict"""
         return cls(
             id=data["id"],
@@ -223,14 +223,14 @@ class Memory:
         self.last_accessed = datetime.now()
         self.access_count += 1

-    def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]:
+    def get_node_by_id(self, node_id: str) -> MemoryNode | None:
         """Get a node by ID"""
         for node in self.nodes:
             if node.id == node_id:
                 return node
         return None

-    def get_subject_node(self) -> Optional[MemoryNode]:
+    def get_subject_node(self) -> MemoryNode | None:
         """Get the subject node"""
         return self.get_node_by_id(self.subject_id)
@@ -274,10 +274,10 @@ class StagedMemory:
     memory: Memory  # the original memory object
     status: MemoryStatus = MemoryStatus.STAGED  # status
     created_at: datetime = field(default_factory=datetime.now)
-    consolidated_at: Optional[datetime] = None  # consolidation time
-    merge_history: List[str] = field(default_factory=list)  # IDs of merged nodes
+    consolidated_at: datetime | None = None  # consolidation time
+    merge_history: list[str] = field(default_factory=list)  # IDs of merged nodes

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """Convert to a dict"""
         return {
             "memory": self.memory.to_dict(),
@@ -288,7 +288,7 @@ class StagedMemory:
         }

     @classmethod
-    def from_dict(cls, data: Dict[str, Any]) -> StagedMemory:
+    def from_dict(cls, data: dict[str, Any]) -> StagedMemory:
         """Create a staged memory from a dict"""
         return cls(
             memory=Memory.from_dict(data["memory"]),
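The paired to_dict/from_dict methods define a serialization round-trip. A sketch of that contract for a node (field values are illustrative, and the exact to_dict payload is assumed from the fields shown above):

import numpy as np

from src.memory_graph.models import MemoryNode, NodeType

node = MemoryNode(
    id="",  # __post_init__ assigns a uuid4 when the ID is empty
    content="学习",
    node_type=NodeType.TOPIC,
    embedding=np.array([0.1, 0.2, 0.3]),
)

restored = MemoryNode.from_dict(node.to_dict())
assert restored.id == node.id and restored.content == node.content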

View File

@@ -52,16 +52,16 @@ class CreateMemoryTool(BaseTool):
    示例:"我最近在学Python想找数据分析的工作"
    → 调用1{{subject:"[从历史提取真实名字]", memory_type:"事实", topic:"学习", object:"Python", attributes:{{时间:"最近", 状态:"进行中"}}, importance:0.7}}
    → 调用2{{subject:"[从历史提取真实名字]", memory_type:"目标", topic:"求职", object:"数据分析岗位", attributes:{{状态:"计划中"}}, importance:0.8}}"""

     parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
         ("subject", ToolParamType.STRING, "记忆主体(重要!)。从对话历史中提取真实发送人名字。示例:如果看到'Prou(12345678): 我喜欢...'subject应填'Prou';如果看到'张三: 我在...'subject应填'张三'。❌禁止使用'用户'这种泛指,必须用具体名字!", True, None),
         ("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]),
         ("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None),
         ("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整可以不填如果有明确对象建议填写", False, None),
-        ("attributes", ToolParamType.STRING, "详细属性JSON格式字符串。强烈建议包含时间具体到日期和小时分钟、地点、状态、原因等上下文信息。例{\"时间\":\"2025-11-06 12:00\",\"地点\":\"公司\",\"状态\":\"进行中\",\"原因\":\"项目需要\"}", False, None),
+        ("attributes", ToolParamType.STRING, '详细属性JSON格式字符串。强烈建议包含时间具体到日期和小时分钟、地点、状态、原因等上下文信息。例{"时间":"2025-11-06 12:00","地点":"公司","状态":"进行中","原因":"项目需要"}', False, None),
         ("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考日常琐事0.3-0.4一般对话0.5-0.6重要信息0.7-0.8核心记忆0.9-1.0。不确定时用0.5", False, None),
     ]

     available_for_llm = True

     async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
@@ -69,20 +69,20 @@ class CreateMemoryTool(BaseTool):
         try:
             # Get the global memory_manager
             from src.memory_graph.manager_singleton import get_memory_manager

             manager = get_memory_manager()
             if not manager:
                 return {
                     "name": self.name,
                     "content": "Memory system not initialized"
                 }

             # Extract parameters
             subject = function_args.get("subject", "")
             memory_type = function_args.get("memory_type", "")
             topic = function_args.get("topic", "")
             obj = function_args.get("object")

             # Handle attributes (may be a string or a dict)
             attributes_raw = function_args.get("attributes", {})
             if isinstance(attributes_raw, str):
@@ -93,9 +93,9 @@ class CreateMemoryTool(BaseTool):
                 attributes = {}
             else:
                 attributes = attributes_raw

             importance = function_args.get("importance", 0.5)

             # Create the memory
             memory = await manager.create_memory(
                 subject=subject,
@@ -105,7 +105,7 @@ class CreateMemoryTool(BaseTool):
                 attributes=attributes,
                 importance=importance,
             )

             if memory:
                 logger.info(f"[CreateMemoryTool] Memory created: {memory.id}")
                 return {
@@ -119,12 +119,12 @@ class CreateMemoryTool(BaseTool):
                     "content": "Failed to create memory",
                     "memory_id": None,
                 }

         except Exception as e:
             logger.error(f"[CreateMemoryTool] Execution failed: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"Error while creating memory: {str(e)}"
+                "content": f"Error while creating memory: {e!s}"
             }
@@ -133,33 +133,33 @@ class LinkMemoriesTool(BaseTool):
     name = "link_memories"
     description = "在两个记忆之间建立关联关系。用于连接相关的记忆,形成知识网络。"
     parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
         ("source_query", ToolParamType.STRING, "源记忆的搜索查询(如记忆的主题关键词)", True, None),
         ("target_query", ToolParamType.STRING, "目标记忆的搜索查询", True, None),
         ("relation", ToolParamType.STRING, "关系类型", True, ["导致", "引用", "相似", "相反", "部分"]),
         ("strength", ToolParamType.FLOAT, "关系强度0.0-1.0默认0.7", False, None),
     ]

     available_for_llm = False  # Not yet exposed to the LLM

     async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
         """Link two memories"""
         try:
             from src.memory_graph.manager_singleton import get_memory_manager

             manager = get_memory_manager()
             if not manager:
                 return {
                     "name": self.name,
                     "content": "Memory system not initialized"
                 }

             source_query = function_args.get("source_query", "")
             target_query = function_args.get("target_query", "")
             relation = function_args.get("relation", "引用")
             strength = function_args.get("strength", 0.7)

             # Link the memories
             success = await manager.link_memories(
                 source_description=source_query,
@@ -167,7 +167,7 @@ class LinkMemoriesTool(BaseTool):
                 relation_type=relation,
                 importance=strength,
             )

             if success:
                 logger.info(f"[LinkMemoriesTool] Memories linked: {source_query} -> {target_query}")
                 return {
@@ -179,12 +179,12 @@ class LinkMemoriesTool(BaseTool):
                     "name": self.name,
                     "content": "Failed to link memories; no matching memories may exist"
                 }

         except Exception as e:
             logger.error(f"[LinkMemoriesTool] Execution failed: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"Error while linking memories: {str(e)}"
+                "content": f"Error while linking memories: {e!s}"
             }
@@ -193,39 +193,39 @@ class SearchMemoriesTool(BaseTool):
     name = "search_memories"
     description = "搜索相关的记忆。根据查询词搜索记忆库,返回最相关的记忆。"
     parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
         ("query", ToolParamType.STRING, "搜索查询词,描述想要找什么样的记忆", True, None),
         ("top_k", ToolParamType.INTEGER, "返回的记忆数量默认5", False, None),
         ("min_importance", ToolParamType.FLOAT, "最低重要性阈值0.0-1.0),只返回重要性不低于此值的记忆", False, None),
     ]

     available_for_llm = False  # Not yet exposed to the LLM; memory retrieval runs automatically during prompt building

     async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
         """Search memories"""
         try:
             from src.memory_graph.manager_singleton import get_memory_manager

             manager = get_memory_manager()
             if not manager:
                 return {
                     "name": self.name,
                     "content": "Memory system not initialized"
                 }

             query = function_args.get("query", "")
             top_k = function_args.get("top_k", 5)
             min_importance_raw = function_args.get("min_importance")
             min_importance = float(min_importance_raw) if min_importance_raw is not None else 0.0

             # Search memories
             memories = await manager.search_memories(
                 query=query,
                 top_k=top_k,
                 min_importance=min_importance,
             )

             if memories:
                 # Format the results
                 result_lines = [f"Found {len(memories)} related memories:\n"]
@@ -236,10 +236,10 @@ class SearchMemoriesTool(BaseTool):
                     result_lines.append(
                         f"{i}. [{mem_type}] {topic} (importance: {importance:.2f})"
                     )

                 result_text = "\n".join(result_lines)

                 logger.info(f"[SearchMemoriesTool] Search succeeded: query='{query}', results={len(memories)}")
                 return {
                     "name": self.name,
                     "content": result_text
@@ -249,10 +249,10 @@ class SearchMemoriesTool(BaseTool):
                     "name": self.name,
                     "content": f"No memories related to '{query}' were found"
                 }

         except Exception as e:
             logger.error(f"[SearchMemoriesTool] Execution failed: {e}", exc_info=True)
             return {
                 "name": self.name,
-                "content": f"Error while searching memories: {str(e)}"
+                "content": f"Error while searching memories: {e!s}"
             }
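The recurring `{str(e)}` → `{e!s}` change in this file is ruff's RUF010: inside an f-string, the `!s` conversion produces the same text as an explicit str() call without the extra function call. A one-line check:

e = ValueError("boom")
assert f"error: {str(e)}" == f"error: {e!s}" == "error: boom"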

View File

@@ -5,4 +5,4 @@
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.vector_store import VectorStore

-__all__ = ["VectorStore", "GraphStore"]
+__all__ = ["GraphStore", "VectorStore"]

View File

@@ -4,12 +4,10 @@
 from __future__ import annotations

-from typing import Dict, List, Optional, Set, Tuple
-
 import networkx as nx

 from src.common.logger import get_logger
-from src.memory_graph.models import Memory, MemoryEdge, MemoryNode
+from src.memory_graph.models import Memory, MemoryEdge

 logger = get_logger(__name__)

@@ -17,7 +15,7 @@ logger = get_logger(__name__)
 class GraphStore:
     """
     Graph-store wrapper

     Responsible for:
     1. Building and maintaining the memory graph
     2. Fast lookup of nodes and edges
@@ -31,17 +29,17 @@ class GraphStore:
         self.graph = nx.DiGraph()

         # Index: memory ID -> memory object
-        self.memory_index: Dict[str, Memory] = {}
+        self.memory_index: dict[str, Memory] = {}
         # Index: node ID -> set of owning memory IDs
-        self.node_to_memories: Dict[str, Set[str]] = {}
+        self.node_to_memories: dict[str, set[str]] = {}

         logger.info("Graph store initialized")

     def add_memory(self, memory: Memory) -> None:
         """
         Add a memory to the graph

         Args:
             memory: the memory to add
         """
@@ -84,34 +82,34 @@ class GraphStore:
             logger.error(f"Failed to add memory: {e}", exc_info=True)
             raise

-    def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
+    def get_memory_by_id(self, memory_id: str) -> Memory | None:
         """
         Get a memory by ID

         Args:
             memory_id: memory ID

         Returns:
             The memory object, or None
         """
         return self.memory_index.get(memory_id)

-    def get_all_memories(self) -> List[Memory]:
+    def get_all_memories(self) -> list[Memory]:
         """
         Get all memories

         Returns:
             A list of all memories
         """
         return list(self.memory_index.values())

-    def get_memories_by_node(self, node_id: str) -> List[Memory]:
+    def get_memories_by_node(self, node_id: str) -> list[Memory]:
         """
         Get all memories that contain the given node

         Args:
             node_id: node ID

         Returns:
             A list of memories
         """
@@ -121,14 +119,14 @@ class GraphStore:
         memory_ids = self.node_to_memories[node_id]
         return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]

-    def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
+    def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
         """
         Get all edges leaving the given node

         Args:
             node_id: source node ID
             relation_types: relation-type filter (optional)

         Returns:
             A list of edge records
         """
@@ -155,16 +153,16 @@ class GraphStore:
return edges return edges
def get_neighbors( def get_neighbors(
self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None self, node_id: str, direction: str = "out", relation_types: list[str] | None = None
) -> List[Tuple[str, Dict]]: ) -> list[tuple[str, dict]]:
""" """
获取节点的邻居节点 获取节点的邻居节点
Args: Args:
node_id: 节点ID node_id: 节点ID
direction: 方向 ("out"=出边, "in"=入边, "both"=双向) direction: 方向 ("out"=出边, "in"=入边, "both"=双向)
relation_types: 关系类型过滤 relation_types: 关系类型过滤
Returns: Returns:
List of (neighbor_id, edge_data) List of (neighbor_id, edge_data)
""" """
@@ -187,15 +185,15 @@ class GraphStore:
return neighbors return neighbors
def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]: def find_path(self, source_id: str, target_id: str, max_length: int | None = None) -> list[str] | None:
""" """
查找两个节点之间的最短路径 查找两个节点之间的最短路径
Args: Args:
source_id: 源节点ID source_id: 源节点ID
target_id: 目标节点ID target_id: 目标节点ID
max_length: 最大路径长度(可选) max_length: 最大路径长度(可选)
Returns: Returns:
路径节点ID列表或 None如果不存在路径 路径节点ID列表或 None如果不存在路径
""" """
@@ -220,18 +218,18 @@ class GraphStore:
def bfs_expand( def bfs_expand(
self, self,
start_nodes: List[str], start_nodes: list[str],
depth: int = 1, depth: int = 1,
relation_types: Optional[List[str]] = None, relation_types: list[str] | None = None,
) -> Set[str]: ) -> set[str]:
""" """
从起始节点进行广度优先搜索扩展 从起始节点进行广度优先搜索扩展
Args: Args:
start_nodes: 起始节点ID列表 start_nodes: 起始节点ID列表
depth: 扩展深度 depth: 扩展深度
relation_types: 关系类型过滤 relation_types: 关系类型过滤
Returns: Returns:
扩展到的所有节点ID集合 扩展到的所有节点ID集合
""" """
@@ -256,13 +254,13 @@ class GraphStore:
return visited return visited
def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph: def get_subgraph(self, node_ids: list[str]) -> nx.DiGraph:
""" """
获取包含指定节点的子图 获取包含指定节点的子图
Args: Args:
node_ids: 节点ID列表 node_ids: 节点ID列表
Returns: Returns:
NetworkX 子图 NetworkX 子图
""" """
@@ -271,7 +269,7 @@ class GraphStore:
def merge_nodes(self, source_id: str, target_id: str) -> None: def merge_nodes(self, source_id: str, target_id: str) -> None:
""" """
合并两个节点将source的所有边转移到target然后删除source 合并两个节点将source的所有边转移到target然后删除source
Args: Args:
source_id: 源节点ID将被删除 source_id: 源节点ID将被删除
target_id: 目标节点ID保留 target_id: 目标节点ID保留
@@ -308,13 +306,13 @@ class GraphStore:
logger.error(f"合并节点失败: {e}", exc_info=True) logger.error(f"合并节点失败: {e}", exc_info=True)
raise raise
def get_node_degree(self, node_id: str) -> Tuple[int, int]: def get_node_degree(self, node_id: str) -> tuple[int, int]:
""" """
获取节点的度数 获取节点的度数
Args: Args:
node_id: 节点ID node_id: 节点ID
Returns: Returns:
(in_degree, out_degree) (in_degree, out_degree)
""" """
@@ -323,7 +321,7 @@ class GraphStore:
return (self.graph.in_degree(node_id), self.graph.out_degree(node_id)) return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))
def get_statistics(self) -> Dict[str, int]: def get_statistics(self) -> dict[str, int]:
"""获取图的统计信息""" """获取图的统计信息"""
return { return {
"total_nodes": self.graph.number_of_nodes(), "total_nodes": self.graph.number_of_nodes(),
@@ -332,10 +330,10 @@ class GraphStore:
"connected_components": nx.number_weakly_connected_components(self.graph), "connected_components": nx.number_weakly_connected_components(self.graph),
} }
def to_dict(self) -> Dict: def to_dict(self) -> dict:
""" """
将图转换为字典(用于持久化) 将图转换为字典(用于持久化)
Returns: Returns:
图的字典表示 图的字典表示
""" """
@@ -356,13 +354,13 @@ class GraphStore:
} }
@classmethod @classmethod
def from_dict(cls, data: Dict) -> GraphStore: def from_dict(cls, data: dict) -> GraphStore:
""" """
从字典加载图 从字典加载图
Args: Args:
data: 图的字典表示 data: 图的字典表示
Returns: Returns:
GraphStore 实例 GraphStore 实例
""" """
@@ -406,7 +404,6 @@ class GraphStore:
规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。 规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。
已存在的边(通过 edge.id 检查)将不会重复添加。 已存在的边(通过 edge.id 检查)将不会重复添加。
""" """
from src.memory_graph.models import MemoryEdge
# 构建快速查重索引memory_id -> set(edge_id) # 构建快速查重索引memory_id -> set(edge_id)
existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()} existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()}
@@ -465,10 +462,10 @@ class GraphStore:
def remove_memory(self, memory_id: str) -> bool: def remove_memory(self, memory_id: str) -> bool:
""" """
从图中删除指定记忆 从图中删除指定记忆
Args: Args:
memory_id: 要删除的记忆ID memory_id: 要删除的记忆ID
Returns: Returns:
是否删除成功 是否删除成功
""" """
@@ -477,9 +474,9 @@ class GraphStore:
if memory_id not in self.memory_index: if memory_id not in self.memory_index:
logger.warning(f"记忆不存在,无法删除: {memory_id}") logger.warning(f"记忆不存在,无法删除: {memory_id}")
return False return False
memory = self.memory_index[memory_id] memory = self.memory_index[memory_id]
# 2. 从节点映射中移除此记忆 # 2. 从节点映射中移除此记忆
for node in memory.nodes: for node in memory.nodes:
if node.id in self.node_to_memories: if node.id in self.node_to_memories:
@@ -489,13 +486,13 @@ class GraphStore:
if self.graph.has_node(node.id): if self.graph.has_node(node.id):
self.graph.remove_node(node.id) self.graph.remove_node(node.id)
del self.node_to_memories[node.id] del self.node_to_memories[node.id]
# 3. 从记忆索引中移除 # 3. 从记忆索引中移除
del self.memory_index[memory_id] del self.memory_index[memory_id]
logger.info(f"成功删除记忆: {memory_id}") logger.info(f"成功删除记忆: {memory_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"删除记忆失败 {memory_id}: {e}", exc_info=True) logger.error(f"删除记忆失败 {memory_id}: {e}", exc_info=True)
return False return False
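
Nearly every hunk in this file is the same mechanical modernization: PEP 585 builtin generics (`dict`, `list`, `set`, `tuple`) replace `typing.Dict` and friends, and PEP 604 unions (`X | None`) replace `Optional[X]`, which is why the `typing` import disappears entirely. A minimal before/after sketch of the pattern (illustrative code, not from this repo):

```python
from __future__ import annotations  # lets Python 3.7+ accept the new syntax in annotations

# Old style, which required `from typing import Dict, List, Optional`:
#   def find(index: Dict[str, int], keys: List[str]) -> Optional[int]: ...

# New style: builtin generics (PEP 585) plus union syntax (PEP 604),
# with no typing import needed at all.
def find(index: dict[str, int], keys: list[str]) -> int | None:
    for key in keys:
        if key in index:
            return index[key]
    return None

assert find({"a": 1}, ["b", "a"]) == 1
```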

View File

@@ -8,14 +8,12 @@ import asyncio
import json
from datetime import datetime
from pathlib import Path

-from typing import Optional

import orjson

from src.common.logger import get_logger
-from src.memory_graph.models import Memory, StagedMemory
+from src.memory_graph.models import StagedMemory
from src.memory_graph.storage.graph_store import GraphStore
-from src.memory_graph.storage.vector_store import VectorStore

logger = get_logger(__name__)
@@ -23,7 +21,7 @@ logger = get_logger(__name__)
class PersistenceManager:
    """
    持久化管理器

    负责:
    1. 图数据的保存和加载
    2. 定期自动保存
@@ -39,7 +37,7 @@ class PersistenceManager:
    ):
        """
        初始化持久化管理器

        Args:
            data_dir: 数据存储目录
            graph_file_name: 图数据文件名
@@ -55,7 +53,7 @@ class PersistenceManager:
        self.backup_dir.mkdir(parents=True, exist_ok=True)

        self.auto_save_interval = auto_save_interval
-        self._auto_save_task: Optional[asyncio.Task] = None
+        self._auto_save_task: asyncio.Task | None = None
        self._running = False

        logger.info(f"初始化持久化管理器: data_dir={data_dir}")
@@ -63,7 +61,7 @@ class PersistenceManager:
    async def save_graph_store(self, graph_store: GraphStore) -> None:
        """
        保存图存储到文件

        Args:
            graph_store: 图存储对象
        """
@@ -95,10 +93,10 @@ class PersistenceManager:
            logger.error(f"保存图数据失败: {e}", exc_info=True)
            raise

-    async def load_graph_store(self) -> Optional[GraphStore]:
+    async def load_graph_store(self) -> GraphStore | None:
        """
        从文件加载图存储

        Returns:
            GraphStore 对象,如果文件不存在则返回 None
        """
@@ -129,7 +127,7 @@ class PersistenceManager:
    async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None:
        """
        保存临时记忆列表

        Args:
            staged_memories: 临时记忆列表
        """
@@ -158,7 +156,7 @@ class PersistenceManager:
    async def load_staged_memories(self) -> list[StagedMemory]:
        """
        加载临时记忆列表

        Returns:
            临时记忆列表
        """
@@ -179,10 +177,10 @@ class PersistenceManager:
            logger.error(f"加载临时记忆失败: {e}", exc_info=True)
            return []

-    async def create_backup(self) -> Optional[Path]:
+    async def create_backup(self) -> Path | None:
        """
        创建当前数据的备份

        Returns:
            备份文件路径,如果失败则返回 None
        """
@@ -208,7 +206,7 @@ class PersistenceManager:
            logger.error(f"创建备份失败: {e}", exc_info=True)
            return None

-    async def _load_from_backup(self) -> Optional[GraphStore]:
+    async def _load_from_backup(self) -> GraphStore | None:
        """从最新的备份加载数据"""
        try:
            # 查找最新的备份文件
@@ -236,7 +234,7 @@ class PersistenceManager:
    async def _cleanup_old_backups(self, keep: int = 10) -> None:
        """
        清理旧备份,只保留最近的几个

        Args:
            keep: 保留的备份数量
        """
@@ -254,11 +252,11 @@ class PersistenceManager:
    async def start_auto_save(
        self,
        graph_store: GraphStore,
-        staged_memories_getter: callable = None,
+        staged_memories_getter: callable | None = None,
    ) -> None:
        """
        启动自动保存任务

        Args:
            graph_store: 图存储对象
            staged_memories_getter: 获取临时记忆的回调函数
@@ -310,7 +308,7 @@ class PersistenceManager:
    async def export_to_json(self, output_file: Path, graph_store: GraphStore) -> None:
        """
        导出图数据到指定的 JSON 文件(用于数据迁移或分析)

        Args:
            output_file: 输出文件路径
            graph_store: 图存储对象
@@ -334,13 +332,13 @@ class PersistenceManager:
            logger.error(f"导出图数据失败: {e}", exc_info=True)
            raise

-    async def import_from_json(self, input_file: Path) -> Optional[GraphStore]:
+    async def import_from_json(self, input_file: Path) -> GraphStore | None:
        """
        从 JSON 文件导入图数据

        Args:
            input_file: 输入文件路径

        Returns:
            GraphStore 对象
        """
@@ -360,7 +358,7 @@ class PersistenceManager:
    def get_data_size(self) -> dict[str, int]:
        """
        获取数据文件的大小信息

        Returns:
            文件大小字典(字节)
        """

View File

@@ -4,9 +4,8 @@
from __future__ import annotations

-import uuid
from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

import numpy as np
@@ -19,7 +18,7 @@ logger = get_logger(__name__)
class VectorStore:
    """
    向量存储封装类

    负责:
    1. 节点的语义向量存储和检索
    2. 基于相似度的向量搜索
@@ -29,12 +28,12 @@ class VectorStore:
    def __init__(
        self,
        collection_name: str = "memory_nodes",
-        data_dir: Optional[Path] = None,
-        embedding_function: Optional[Any] = None,
+        data_dir: Path | None = None,
+        embedding_function: Any | None = None,
    ):
        """
        初始化向量存储

        Args:
            collection_name: ChromaDB 集合名称
            data_dir: 数据存储目录
@@ -80,7 +79,7 @@ class VectorStore:
    async def add_node(self, node: MemoryNode) -> None:
        """
        添加节点到向量存储

        Args:
            node: 要添加的节点
        """
@@ -98,17 +97,17 @@ class VectorStore:
                "node_type": node.node_type.value,
                "created_at": node.created_at.isoformat(),
            }

            # 处理额外的元数据,将 list 转换为 JSON 字符串
            for key, value in node.metadata.items():
                if isinstance(value, (list, dict)):
                    import orjson
-                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
+                    metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
                elif isinstance(value, (str, int, float, bool)) or value is None:
                    metadata[key] = value
                else:
                    metadata[key] = str(value)

            self.collection.add(
                ids=[node.id],
                embeddings=[node.embedding.tolist()],
@@ -122,10 +121,10 @@ class VectorStore:
            logger.error(f"添加节点失败: {e}", exc_info=True)
            raise

-    async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
+    async def add_nodes_batch(self, nodes: list[MemoryNode]) -> None:
        """
        批量添加节点

        Args:
            nodes: 节点列表
        """
@@ -151,13 +150,13 @@ class VectorStore:
                }
                for key, value in n.metadata.items():
                    if isinstance(value, (list, dict)):
-                        metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
+                        metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
                    elif isinstance(value, (str, int, float, bool)) or value is None:
                        metadata[key] = value  # type: ignore
                    else:
                        metadata[key] = str(value)
                metadatas.append(metadata)

            self.collection.add(
                ids=[n.id for n in valid_nodes],
                embeddings=[n.embedding.tolist() for n in valid_nodes],  # type: ignore
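
Both `add_node` and `add_nodes_batch` JSON-encode list/dict metadata before handing it to ChromaDB, because Chroma metadata values must be scalars (str, int, float, bool, or None). A small round-trip sketch of that convention (assumes `orjson` is installed):

```python
import orjson

def encode_metadata(raw: dict) -> dict:
    # Lists/dicts become JSON strings; scalars pass through; anything else
    # is stringified. The read path decodes values that look like JSON.
    out = {}
    for key, value in raw.items():
        if isinstance(value, (list, dict)):
            out[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
        elif isinstance(value, (str, int, float, bool)) or value is None:
            out[key] = value
        else:
            out[key] = str(value)
    return out

meta = encode_metadata({"memory_ids": ["m1", "m2"], "weight": 0.7})
assert orjson.loads(meta["memory_ids"]) == ["m1", "m2"]
```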
@@ -175,18 +174,18 @@ class VectorStore:
        self,
        query_embedding: np.ndarray,
        limit: int = 10,
-        node_types: Optional[List[NodeType]] = None,
+        node_types: list[NodeType] | None = None,
        min_similarity: float = 0.0,
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
        """
        搜索相似节点

        Args:
            query_embedding: 查询向量
            limit: 返回结果数量
            node_types: 限制节点类型(可选)
            min_similarity: 最小相似度阈值

        Returns:
            List of (node_id, similarity, metadata)
        """
@@ -214,7 +213,7 @@ class VectorStore:
            if ids is not None and len(ids) > 0 and len(ids[0]) > 0:
                distances = results.get("distances")
                metadatas = results.get("metadatas")

                for i, node_id in enumerate(ids[0]):
                    # ChromaDB 返回的是距离,需要转换为相似度
                    # 余弦距离: distance = 1 - similarity
@@ -223,15 +222,15 @@ class VectorStore:
                    if similarity >= min_similarity:
                        metadata = metadatas[0][i] if metadatas is not None and len(metadatas) > 0 else {}  # type: ignore

                        # 解析 JSON 字符串回列表/字典
                        for key, value in list(metadata.items()):
-                            if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
+                            if isinstance(value, str) and (value.startswith("[") or value.startswith("{")):
                                try:
                                    metadata[key] = orjson.loads(value)
-                                except:
+                                except Exception:
                                    pass  # 保持原值

                        similar_nodes.append((node_id, similarity, metadata))

            logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果")
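
Two things happen in this hunk: the comments document that ChromaDB returns cosine distance, so similarity is recovered as `1 - distance`, and the bare `except:` becomes `except Exception:` (Ruff E722; a bare except would also swallow `KeyboardInterrupt` and `SystemExit`). A standalone sketch of both points:

```python
import orjson

# ChromaDB's cosine distance is 1 - cosine_similarity, so converting
# back to a similarity is a single subtraction.
distance = 0.25
similarity = 1.0 - distance
assert abs(similarity - 0.75) < 1e-9

# `except Exception` (instead of a bare `except:`) still tolerates bad
# values but lets KeyboardInterrupt/SystemExit propagate.
try:
    value = orjson.loads("not json")
except Exception:
    value = None
assert value is None
```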
@@ -243,19 +242,19 @@ class VectorStore:
    async def search_with_multiple_queries(
        self,
-        query_embeddings: List[np.ndarray],
-        query_weights: Optional[List[float]] = None,
+        query_embeddings: list[np.ndarray],
+        query_weights: list[float] | None = None,
        limit: int = 10,
-        node_types: Optional[List[NodeType]] = None,
+        node_types: list[NodeType] | None = None,
        min_similarity: float = 0.0,
        fusion_strategy: str = "weighted_max",
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
        """
        多查询融合搜索

        使用多个查询向量进行搜索,然后融合结果。
        这能解决单一查询向量无法同时关注多个关键概念的问题。

        Args:
            query_embeddings: 查询向量列表
            query_weights: 每个查询的权重(可选,默认均等)
@@ -266,7 +265,7 @@ class VectorStore:
            - "weighted_max": 加权最大值(推荐)
            - "weighted_sum": 加权求和
            - "rrf": Reciprocal Rank Fusion

        Returns:
            融合后的节点列表 [(node_id, fused_score, metadata), ...]
        """
@@ -279,7 +278,7 @@ class VectorStore:
        # 默认权重均等
        if query_weights is None:
            query_weights = [1.0 / len(query_embeddings)] * len(query_embeddings)

        # 归一化权重
        total_weight = sum(query_weights)
        if total_weight > 0:
@@ -287,7 +286,7 @@ class VectorStore:
        try:
            # 1. 对每个查询执行搜索
-            all_results: Dict[str, Dict[str, Any]] = {}  # node_id -> {scores, metadata}
+            all_results: dict[str, dict[str, Any]] = {}  # node_id -> {scores, metadata}

            for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
                # 搜索更多结果以提高融合质量
@@ -307,13 +306,13 @@ class VectorStore:
                            "ranks": [],
                            "metadata": metadata,
                        }

                    all_results[node_id]["scores"].append((similarity, weight))
                    all_results[node_id]["ranks"].append((rank, weight))

            # 2. 融合分数
            fused_results = []
            for node_id, data in all_results.items():
                scores = data["scores"]
                ranks = data["ranks"]
@@ -356,13 +355,13 @@ class VectorStore:
            logger.error(f"多查询融合搜索失败: {e}", exc_info=True)
            raise

-    async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
+    async def get_node_by_id(self, node_id: str) -> dict[str, Any] | None:
        """
        根据ID获取节点元数据

        Args:
            node_id: 节点ID

        Returns:
            节点元数据或 None
        """
@@ -378,7 +377,7 @@ class VectorStore:
            if ids is not None and len(ids) > 0:
                metadatas = result.get("metadatas")
                embeddings = result.get("embeddings")

                return {
                    "id": ids[0],
                    "metadata": metadatas[0] if metadatas is not None and len(metadatas) > 0 else {},
@@ -394,7 +393,7 @@ class VectorStore:
    async def delete_node(self, node_id: str) -> None:
        """
        删除节点

        Args:
            node_id: 节点ID
        """
@@ -412,7 +411,7 @@ class VectorStore:
    async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None:
        """
        更新节点的 embedding

        Args:
            node_id: 节点ID
            embedding: 新的向量

View File

@@ -4,12 +4,12 @@ LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑
from __future__ import annotations

-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any

from src.common.logger import get_logger
from src.memory_graph.core.builder import MemoryBuilder
from src.memory_graph.core.extractor import MemoryExtractor
-from src.memory_graph.models import Memory, MemoryStatus
+from src.memory_graph.models import Memory
from src.memory_graph.storage.graph_store import GraphStore
from src.memory_graph.storage.persistence import PersistenceManager
from src.memory_graph.storage.vector_store import VectorStore
@@ -21,7 +21,7 @@ logger = get_logger(__name__)
class MemoryTools:
    """
    记忆系统工具集

    提供给 LLM 使用的工具接口:
    1. create_memory: 创建新记忆
    2. link_memories: 关联两个记忆
@@ -33,7 +33,7 @@ class MemoryTools:
        vector_store: VectorStore,
        graph_store: GraphStore,
        persistence_manager: PersistenceManager,
-        embedding_generator: Optional[EmbeddingGenerator] = None,
+        embedding_generator: EmbeddingGenerator | None = None,
        max_expand_depth: int = 1,
        expand_semantic_threshold: float = 0.3,
    ):
@@ -72,10 +72,10 @@ class MemoryTools:
        self._initialized = True

    @staticmethod
-    def get_create_memory_schema() -> Dict[str, Any]:
+    def get_create_memory_schema() -> dict[str, Any]:
        """
        获取 create_memory 工具的 JSON schema

        Returns:
            工具 schema 定义
        """
@@ -145,15 +145,15 @@ class MemoryTools:
                        "description": "时间信息(强烈建议填写):\n- 具体日期:'2025-11-05'、'2025年11月'\n- 相对时间:'今天'、'昨天'、'上周'、'最近'、'3天前'\n- 时间段:'今天下午'、'上个月'、'这学期'",
                    },
                    "地点": {
                        "type": "string",
                        "description": "地点信息(如涉及):\n- 具体地址、城市名、国家\n- 场所类型:'在家'、'公司'、'学校'、'咖啡店'"
                    },
                    "原因": {
                        "type": "string",
                        "description": "为什么这样做/这样想(如明确提到)"
                    },
                    "方式": {
                        "type": "string",
                        "description": "怎么做的/通过什么方式(如明确提到)"
                    },
                    "结果": {
@@ -183,10 +183,10 @@ class MemoryTools:
        }

    @staticmethod
-    def get_link_memories_schema() -> Dict[str, Any]:
+    def get_link_memories_schema() -> dict[str, Any]:
        """
        获取 link_memories 工具的 JSON schema

        Returns:
            工具 schema 定义
        """
@@ -239,10 +239,10 @@ class MemoryTools:
        }

    @staticmethod
-    def get_search_memories_schema() -> Dict[str, Any]:
+    def get_search_memories_schema() -> dict[str, Any]:
        """
        获取 search_memories 工具的 JSON schema

        Returns:
            工具 schema 定义
        """
@@ -307,13 +307,13 @@ class MemoryTools:
            },
        }

-    async def create_memory(self, **params) -> Dict[str, Any]:
+    async def create_memory(self, **params) -> dict[str, Any]:
        """
        执行 create_memory 工具

        Args:
            **params: 工具参数

        Returns:
            执行结果
        """
@@ -353,13 +353,13 @@ class MemoryTools:
                "message": "记忆创建失败",
            }

-    async def link_memories(self, **params) -> Dict[str, Any]:
+    async def link_memories(self, **params) -> dict[str, Any]:
        """
        执行 link_memories 工具

        Args:
            **params: 工具参数

        Returns:
            执行结果
        """
@@ -433,15 +433,15 @@ class MemoryTools:
                "message": "记忆关联失败",
            }

-    async def search_memories(self, **params) -> Dict[str, Any]:
+    async def search_memories(self, **params) -> dict[str, Any]:
        """
        执行 search_memories 工具

        使用多策略检索优化:
        1. 查询分解(识别主要实体和概念)
        2. 多查询并行检索
        3. 结果融合和重排

        Args:
            **params: 工具参数
                - query: 查询字符串
@@ -449,7 +449,7 @@ class MemoryTools:
                - expand_depth: 扩展深度(暂未使用)
                - use_multi_query: 是否使用多查询策略(默认True)
                - context: 查询上下文(可选)

        Returns:
            搜索结果
        """
@@ -477,7 +477,7 @@ class MemoryTools:
            # 2. 提取初始记忆ID(来自向量搜索)
            initial_memory_ids = set()
            memory_scores = {}  # 记录每个记忆的初始分数

            for node_id, similarity, metadata in similar_nodes:
                if "memory_ids" in metadata:
                    ids = metadata["memory_ids"]
@@ -486,7 +486,7 @@ class MemoryTools:
                        import orjson
                        try:
                            ids = orjson.loads(ids)
-                        except:
+                        except Exception:
                            ids = [ids]
                    if isinstance(ids, list):
                        for mem_id in ids:
@@ -499,12 +499,12 @@ class MemoryTools:
            expanded_memory_scores = {}
            if expand_depth > 0 and initial_memory_ids:
                logger.info(f"开始图扩展: 初始记忆{len(initial_memory_ids)}个, 深度={expand_depth}")

                # 获取查询的embedding(用于语义过滤)
                if self.builder.embedding_generator:
                    try:
                        query_embedding = await self.builder.embedding_generator.generate(query)

                        # 直接使用图扩展逻辑(避免循环依赖)
                        expanded_results = await self._expand_with_semantic_filter(
                            initial_memory_ids=list(initial_memory_ids),
@@ -513,7 +513,7 @@ class MemoryTools:
                            semantic_threshold=self.expand_semantic_threshold,  # 使用配置的阈值
                            max_expanded=top_k * 2
                        )

                        # 旧代码(如果需要使用Manager)
                        # from src.memory_graph.manager import MemoryManager
                        # manager = MemoryManager.get_instance()
@@ -524,19 +524,18 @@ class MemoryTools:
                        #     semantic_threshold=0.5,
                        #     max_expanded=top_k * 2
                        # )

                        # 合并扩展结果
-                        for mem_id, score in expanded_results:
-                            expanded_memory_scores[mem_id] = score
+                        expanded_memory_scores.update(dict(expanded_results))

                        logger.info(f"图扩展完成: 新增{len(expanded_memory_scores)}个相关记忆")
                    except Exception as e:
                        logger.warning(f"图扩展失败: {e}")

            # 4. 合并初始记忆和扩展记忆
            all_memory_ids = set(initial_memory_ids) | set(expanded_memory_scores.keys())

            # 计算最终分数:初始记忆保持原分数,扩展记忆使用扩展分数
            final_scores = {}
            for mem_id in all_memory_ids:
@@ -546,7 +545,7 @@ class MemoryTools:
                elif mem_id in expanded_memory_scores:
                    # 扩展记忆:使用图扩展分数(稍微降权)
                    final_scores[mem_id] = expanded_memory_scores[mem_id] * 0.8

            # 按分数排序
            sorted_memory_ids = sorted(
                final_scores.keys(),
@@ -562,7 +561,7 @@ class MemoryTools:
                # 综合评分:相似度(60%) + 重要性(30%) + 时效性(10%)
                similarity_score = final_scores[memory_id]
                importance_score = memory.importance

                # 计算时效性分数(最近的记忆得分更高)
                from datetime import datetime, timezone
                now = datetime.now(timezone.utc)
@@ -573,16 +572,16 @@ class MemoryTools:
                    memory_time = memory.created_at
                age_days = (now - memory_time).total_seconds() / 86400
                recency_score = 1.0 / (1.0 + age_days / 30)  # 30天半衰期

                # 综合分数
                final_score = (
                    similarity_score * 0.6 +
                    importance_score * 0.3 +
                    recency_score * 0.1
                )

                memories_with_scores.append((memory, final_score))

            # 按综合分数排序
            memories_with_scores.sort(key=lambda x: x[1], reverse=True)
            memories = [mem for mem, _ in memories_with_scores[:top_k]]
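
The final ranking mixes three signals: 60% vector similarity, 30% stored importance, and 10% recency, where recency follows `1 / (1 + age_days / 30)` (so a 30-day-old memory scores exactly half on that axis). A worked example with made-up numbers:

```python
similarity_score = 0.82
importance_score = 0.50
age_days = 30.0
recency_score = 1.0 / (1.0 + age_days / 30)  # -> 0.5 at 30 days old

final_score = similarity_score * 0.6 + importance_score * 0.3 + recency_score * 0.1
assert abs(final_score - 0.692) < 1e-9  # 0.492 + 0.15 + 0.05
```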
@@ -624,16 +623,16 @@ class MemoryTools:
        }

    async def _generate_multi_queries_simple(
-        self, query: str, context: Optional[Dict[str, Any]] = None
-    ) -> List[Tuple[str, float]]:
+        self, query: str, context: dict[str, Any] | None = None
+    ) -> list[tuple[str, float]]:
        """
        简化版多查询生成(直接在 Tools 层实现,避免循环依赖)

        让小模型直接生成3-5个不同角度的查询语句。
        """
        try:
-            from src.llm_models.utils_model import LLMRequest
            from src.config.config import model_config
+            from src.llm_models.utils_model import LLMRequest

            llm = LLMRequest(
                model_set=model_config.model_task_config.utils_small,
@@ -648,10 +647,10 @@ class MemoryTools:
            # 处理聊天历史(提取最近5条左右的对话)
            recent_chat = ""
            if chat_history:
-                lines = chat_history.strip().split('\n')
+                lines = chat_history.strip().split("\n")
                # 取最近5条消息
                recent_lines = lines[-5:] if len(lines) > 5 else lines
-                recent_chat = '\n'.join(recent_lines)
+                recent_chat = "\n".join(recent_lines)

            prompt = f"""基于聊天上下文,为查询生成3-5个不同角度的搜索语句(JSON格式)
@@ -685,36 +684,38 @@ class MemoryTools:
            """
            response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)

-            import orjson, re
-            response = re.sub(r'```json\s*', '', response)
-            response = re.sub(r'```\s*$', '', response).strip()
+            import re
+
+            import orjson
+
+            response = re.sub(r"```json\s*", "", response)
+            response = re.sub(r"```\s*$", "", response).strip()

            data = orjson.loads(response)
            queries = data.get("queries", [])

            result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
                      for item in queries if item.get("text", "").strip()]

            if result:
                logger.info(f"生成查询: {[q for q, _ in result]}")
                return result

        except Exception as e:
            logger.warning(f"多查询生成失败: {e}")

        return [(query, 1.0)]

    async def _single_query_search(
        self, query: str, top_k: int
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+    ) -> list[tuple[str, float, dict[str, Any]]]:
        """
        传统的单查询搜索

        Args:
            query: 查询字符串
            top_k: 返回结果数

        Returns:
            相似节点列表 [(node_id, similarity, metadata), ...]
        """
@@ -735,30 +736,30 @@ class MemoryTools:
        return similar_nodes

    async def _multi_query_search(
-        self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None
-    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+        self, query: str, top_k: int, context: dict[str, Any] | None = None
+    ) -> list[tuple[str, float, dict[str, Any]]]:
        """
        多查询策略搜索(简化版)

        直接使用小模型生成多个查询,无需复杂的分解和组合。

        步骤:
        1. 让小模型生成3-5个不同角度的查询
        2. 为每个查询生成嵌入
        3. 并行搜索并融合结果

        Args:
            query: 查询字符串
            top_k: 返回结果数
            context: 查询上下文

        Returns:
            融合后的相似节点列表
        """
        try:
            # 1. 使用小模型生成多个查询
            multi_queries = await self._generate_multi_queries_simple(query, context)
            logger.debug(f"生成 {len(multi_queries)} 个查询: {multi_queries}")

            # 2. 生成所有查询的嵌入
@@ -800,13 +801,13 @@ class MemoryTools:
            if node.embedding is not None:
                await self.vector_store.add_node(node)

-    async def _find_memory_by_description(self, description: str) -> Optional[Memory]:
+    async def _find_memory_by_description(self, description: str) -> Memory | None:
        """
        通过描述查找记忆

        Args:
            description: 记忆描述

        Returns:
            找到的记忆,如果没有则返回 None
        """
@@ -827,13 +828,13 @@ class MemoryTools:
            return None

        # 获取最相似节点关联的记忆
-        node_id, similarity, metadata = similar_nodes[0]
+        _node_id, _similarity, metadata = similar_nodes[0]
        if "memory_ids" not in metadata or not metadata["memory_ids"]:
            return None

        ids = metadata["memory_ids"]

        # 确保是列表
        if isinstance(ids, str):
            import orjson
@@ -842,11 +843,11 @@ class MemoryTools:
            except Exception as e:
                logger.warning(f"JSON 解析失败: {e}")
                ids = [ids]

        if isinstance(ids, list) and ids:
            memory_id = ids[0]
            return self.graph_store.get_memory_by_id(memory_id)

        return None

    def _summarize_memory(self, memory: Memory) -> str:
@@ -862,103 +863,102 @@ class MemoryTools:
    async def _expand_with_semantic_filter(
        self,
-        initial_memory_ids: List[str],
+        initial_memory_ids: list[str],
        query_embedding,
        max_depth: int = 2,
        semantic_threshold: float = 0.5,
        max_expanded: int = 20
-    ) -> List[Tuple[str, float]]:
+    ) -> list[tuple[str, float]]:
        """
        从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤

        Args:
            initial_memory_ids: 初始记忆ID集合
            query_embedding: 查询向量
            max_depth: 最大扩展深度
            semantic_threshold: 语义相似度阈值
            max_expanded: 最多扩展多少个记忆

        Returns:
            List[(memory_id, relevance_score)]
        """
        if not initial_memory_ids or query_embedding is None:
            return []

        try:
-            import numpy as np

            visited_memories = set(initial_memory_ids)
-            expanded_memories: Dict[str, float] = {}
+            expanded_memories: dict[str, float] = {}
            current_level = initial_memory_ids

            for depth in range(max_depth):
                next_level = []

                for memory_id in current_level:
                    memory = self.graph_store.get_memory_by_id(memory_id)
                    if not memory:
                        continue

                    for node in memory.nodes:
                        if not node.has_embedding():
                            continue

                        try:
                            neighbors = list(self.graph_store.graph.neighbors(node.id))
-                        except:
+                        except Exception:
                            continue

                        for neighbor_id in neighbors:
                            neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id)
                            if not neighbor_node_data:
                                continue

                            neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id)
                            if neighbor_vector_data is None:
                                continue

                            neighbor_embedding = neighbor_vector_data.get("embedding")
                            if neighbor_embedding is None:
                                continue

                            # 计算语义相似度
                            semantic_sim = self._cosine_similarity(
                                query_embedding,
                                neighbor_embedding
                            )

                            # 获取边权重
                            try:
                                edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
                                edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
-                            except:
+                            except Exception:
                                edge_importance = 0.5

                            # 综合评分
                            depth_decay = 1.0 / (depth + 1)
                            relevance_score = (
                                semantic_sim * 0.7 +
                                edge_importance * 0.2 +
                                depth_decay * 0.1
                            )

                            if relevance_score < semantic_threshold:
                                continue

                            # 提取记忆ID
                            neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
                            if isinstance(neighbor_memory_ids, str):
                                import orjson
                                try:
                                    neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
-                                except:
+                                except Exception:
                                    neighbor_memory_ids = [neighbor_memory_ids]

                            for neighbor_mem_id in neighbor_memory_ids:
                                if neighbor_mem_id in visited_memories:
                                    continue

                                if neighbor_mem_id not in expanded_memories:
                                    expanded_memories[neighbor_mem_id] = relevance_score
                                    visited_memories.add(neighbor_mem_id)
@@ -968,52 +968,52 @@ class MemoryTools:
                                        expanded_memories[neighbor_mem_id],
                                        relevance_score
                                    )

                if not next_level or len(expanded_memories) >= max_expanded:
                    break
                current_level = next_level[:max_expanded]

            sorted_results = sorted(
                expanded_memories.items(),
                key=lambda x: x[1],
                reverse=True
            )[:max_expanded]

            return sorted_results

        except Exception as e:
            logger.error(f"图扩展失败: {e}", exc_info=True)
            return []

    def _cosine_similarity(self, vec1, vec2) -> float:
        """计算余弦相似度"""
        try:
            import numpy as np

            if not isinstance(vec1, np.ndarray):
                vec1 = np.array(vec1)
            if not isinstance(vec2, np.ndarray):
                vec2 = np.array(vec2)

            vec1_norm = np.linalg.norm(vec1)
            vec2_norm = np.linalg.norm(vec2)

            if vec1_norm == 0 or vec2_norm == 0:
                return 0.0

            similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
            return float(similarity)

        except Exception as e:
            logger.warning(f"计算余弦相似度失败: {e}")
            return 0.0

    @staticmethod
-    def get_all_tool_schemas() -> List[Dict[str, Any]]:
+    def get_all_tool_schemas() -> list[dict[str, Any]]:
        """
        获取所有工具的 schema

        Returns:
            工具 schema 列表
        """

View File

@@ -5,4 +5,4 @@
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
from src.memory_graph.utils.time_parser import TimeParser

-__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"]
+__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"]

View File

@@ -5,8 +5,6 @@
from __future__ import annotations

import asyncio
-from functools import lru_cache
-from typing import List, Optional

import numpy as np
@@ -18,12 +16,12 @@ logger = get_logger(__name__)
class EmbeddingGenerator:
    """
    嵌入向量生成器

    策略:
    1. 优先使用配置的 embedding API(通过 LLMRequest)
    2. 如果 API 不可用,回退到本地 sentence-transformers
    3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)

    优点:
    - 降低本地运算负载
    - 即使未安装 sentence-transformers 也可正常运行
@@ -37,19 +35,19 @@ class EmbeddingGenerator:
    ):
        """
        初始化嵌入生成器

        Args:
            use_api: 是否优先使用 API(默认 True)
            fallback_model_name: 回退本地模型名称
        """
        self.use_api = use_api
        self.fallback_model_name = fallback_model_name

        # API 相关
        self._llm_request = None
        self._api_available = False
        self._api_dimension = None

        # 本地模型相关
        self._local_model = None
        self._local_model_loaded = False
@@ -58,24 +56,24 @@ class EmbeddingGenerator:
        """初始化 embedding API"""
        if self._api_available:
            return

        try:
            from src.config.config import model_config
            from src.llm_models.utils_model import LLMRequest

            embedding_config = model_config.model_task_config.embedding

            self._llm_request = LLMRequest(
                model_set=embedding_config,
                request_type="memory_graph.embedding"
            )

            # 获取嵌入维度
            if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension:
                self._api_dimension = embedding_config.embedding_dimension

            self._api_available = True
            logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")

        except Exception as e:
            logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
            self._api_available = False
@@ -103,15 +101,15 @@ class EmbeddingGenerator:
    async def generate(self, text: str) -> np.ndarray:
        """
        生成单个文本的嵌入向量

        策略:
        1. 优先使用 API
        2. API 失败则使用本地模型
        3. 本地模型不可用则使用随机向量

        Args:
            text: 输入文本

        Returns:
            嵌入向量
        """
@@ -126,12 +124,12 @@ class EmbeddingGenerator:
        embedding = await self._generate_with_api(text)
        if embedding is not None:
            return embedding

        # 策略 2: 使用本地模型
        embedding = await self._generate_with_local_model(text)
        if embedding is not None:
            return embedding

        # 策略 3: 随机向量(仅测试)
        logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
        dim = self._get_dimension()
@@ -142,47 +140,47 @@ class EmbeddingGenerator:
        dim = self._get_dimension()
        return np.random.rand(dim).astype(np.float32)

-    async def _generate_with_api(self, text: str) -> Optional[np.ndarray]:
+    async def _generate_with_api(self, text: str) -> np.ndarray | None:
        """使用 API 生成嵌入"""
        try:
            # 初始化 API
            if not self._api_available:
                await self._initialize_api()

            if not self._api_available or not self._llm_request:
                return None

            # 调用 API
            embedding_list, model_name = await self._llm_request.get_embedding(text)

            if embedding_list and len(embedding_list) > 0:
                embedding = np.array(embedding_list, dtype=np.float32)
                logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
                return embedding

            return None

        except Exception as e:
            logger.debug(f"API 嵌入生成失败: {e}")
            return None

-    async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]:
+    async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
        """使用本地模型生成嵌入"""
        try:
            # 加载本地模型
            if not self._local_model_loaded:
                self._load_local_model()

            if not self._local_model_loaded or not self._local_model:
                return None

            # 在线程池中运行
            loop = asyncio.get_event_loop()
            embedding = await loop.run_in_executor(None, self._encode_single_local, text)

            logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维")
            return embedding

        except Exception as e:
            logger.debug(f"本地模型嵌入生成失败: {e}")
            return None
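
The local fallback wraps the blocking sentence-transformers encode call in `run_in_executor`, so the CPU-bound work happens in a worker thread while the event loop keeps serving other coroutines. A minimal self-contained sketch of the same pattern (the stand-in encoder below is fake):

```python
import asyncio

def blocking_encode(text: str) -> list[float]:
    # Stand-in for a CPU-bound model.encode() call.
    return [float(len(text))]

async def encode_async(text: str) -> list[float]:
    loop = asyncio.get_event_loop()
    # None selects the default ThreadPoolExecutor.
    return await loop.run_in_executor(None, blocking_encode, text)

print(asyncio.run(encode_async("你好")))  # -> [2.0]
```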
@@ -199,24 +197,24 @@ class EmbeddingGenerator:
        # 优先使用 API 维度
        if self._api_dimension:
            return self._api_dimension

        # 其次使用本地模型维度
        if self._local_model_loaded and self._local_model:
            try:
                return self._local_model.get_sentence_embedding_dimension()
-            except:
+            except Exception:
                pass

        # 默认 384(sentence-transformers 常用维度)
        return 384

-    async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
+    async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
        """
        批量生成嵌入向量

        Args:
            texts: 文本列表

        Returns:
            嵌入向量列表
        """
@@ -236,13 +234,13 @@ class EmbeddingGenerator:
            results = await self._generate_batch_with_api(valid_texts)
            if results:
                return results

        # 回退到逐个生成
        results = []
        for text in valid_texts:
            embedding = await self.generate(text)
            results.append(embedding)

        logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
        return results
@@ -251,7 +249,7 @@ class EmbeddingGenerator:
        dim = self._get_dimension()
        return [np.random.rand(dim).astype(np.float32) for _ in texts]

-    async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
+    async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
        """使用 API 批量生成"""
        try:
            # 对于大多数 API,批量调用就是多次单独调用
@@ -273,7 +271,7 @@ class EmbeddingGenerator:
# 全局单例
-_global_generator: Optional[EmbeddingGenerator] = None
+_global_generator: EmbeddingGenerator | None = None


def get_embedding_generator(
@@ -282,11 +280,11 @@ def get_embedding_generator(
) -> EmbeddingGenerator:
    """
    获取全局嵌入生成器单例

    Args:
        use_api: 是否优先使用 API
        fallback_model_name: 回退本地模型名称

    Returns:
        EmbeddingGenerator 实例
    """

View File

@@ -5,10 +5,9 @@
""" """
import logging import logging
from typing import Optional, List, Dict, Any
from datetime import datetime from datetime import datetime
from src.memory_graph.models import Memory, MemoryNode, NodeType, EdgeType, MemoryType from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -16,18 +15,18 @@ logger = logging.getLogger(__name__)
def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> str: def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> str:
""" """
将记忆对象格式化为适合提示词的自然语言描述 将记忆对象格式化为适合提示词的自然语言描述
根据记忆的图结构,构建完整的主谓宾描述,包含: 根据记忆的图结构,构建完整的主谓宾描述,包含:
- 主语subject node - 主语subject node
- 谓语/动作topic node - 谓语/动作topic node
- 宾语/对象object node如果存在 - 宾语/对象object node如果存在
- 属性信息attributes如时间、地点等 - 属性信息attributes如时间、地点等
- 关系信息(记忆之间的关系) - 关系信息(记忆之间的关系)
Args: Args:
memory: 记忆对象 memory: 记忆对象
include_metadata: 是否包含元数据(时间、重要性等) include_metadata: 是否包含元数据(时间、重要性等)
Returns: Returns:
格式化后的自然语言描述 格式化后的自然语言描述
""" """
@@ -37,24 +36,22 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
if not subject_node: if not subject_node:
logger.warning(f"记忆 {memory.id} 缺少主体节点") logger.warning(f"记忆 {memory.id} 缺少主体节点")
return "(记忆格式错误:缺少主体)" return "(记忆格式错误:缺少主体)"
subject_text = subject_node.content subject_text = subject_node.content
# 2. 查找主题节点(谓语/动作) # 2. 查找主题节点(谓语/动作)
topic_node = None topic_node = None
memory_type_relation = None
for edge in memory.edges: for edge in memory.edges:
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id: if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
topic_node = memory.get_node_by_id(edge.target_id) topic_node = memory.get_node_by_id(edge.target_id)
memory_type_relation = edge.relation
break break
if not topic_node: if not topic_node:
logger.warning(f"记忆 {memory.id} 缺少主题节点") logger.warning(f"记忆 {memory.id} 缺少主题节点")
return f"{subject_text}(记忆格式错误:缺少主题)" return f"{subject_text}(记忆格式错误:缺少主题)"
topic_text = topic_node.content topic_text = topic_node.content
# 3. 查找客体节点(宾语)和核心关系 # 3. 查找客体节点(宾语)和核心关系
object_node = None object_node = None
core_relation = None core_relation = None
@@ -63,9 +60,9 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
object_node = memory.get_node_by_id(edge.target_id) object_node = memory.get_node_by_id(edge.target_id)
core_relation = edge.relation if edge.relation else "" core_relation = edge.relation if edge.relation else ""
break break
# 4. 收集属性节点 # 4. 收集属性节点
attributes: Dict[str, str] = {} attributes: dict[str, str] = {}
for edge in memory.edges: for edge in memory.edges:
if edge.edge_type == EdgeType.ATTRIBUTE: if edge.edge_type == EdgeType.ATTRIBUTE:
# 查找属性节点和值节点 # 查找属性节点和值节点
@@ -73,16 +70,16 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
if attr_node and attr_node.node_type == NodeType.ATTRIBUTE: if attr_node and attr_node.node_type == NodeType.ATTRIBUTE:
# 查找这个属性的值 # 查找这个属性的值
for value_edge in memory.edges: for value_edge in memory.edges:
if (value_edge.edge_type == EdgeType.ATTRIBUTE if (value_edge.edge_type == EdgeType.ATTRIBUTE
and value_edge.source_id == attr_node.id): and value_edge.source_id == attr_node.id):
value_node = memory.get_node_by_id(value_edge.target_id) value_node = memory.get_node_by_id(value_edge.target_id)
if value_node and value_node.node_type == NodeType.VALUE: if value_node and value_node.node_type == NodeType.VALUE:
attributes[attr_node.content] = value_node.content attributes[attr_node.content] = value_node.content
break break
# 5. 构建自然语言描述 # 5. 构建自然语言描述
parts = [] parts = []
# 主谓宾结构 # 主谓宾结构
if object_node is not None: if object_node is not None:
# 有完整的主谓宾 # 有完整的主谓宾
@@ -93,7 +90,7 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
else: else:
# 只有主谓 # 只有主谓
parts.append(f"{subject_text}{topic_text}") parts.append(f"{subject_text}{topic_text}")
# 添加属性信息 # 添加属性信息
if attributes: if attributes:
attr_parts = [] attr_parts = []
@@ -106,78 +103,78 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
for key, value in attributes.items(): for key, value in attributes.items():
if key not in ["时间", "地点"]: if key not in ["时间", "地点"]:
attr_parts.append(f"{key}{value}") attr_parts.append(f"{key}{value}")
if attr_parts: if attr_parts:
parts.append(f"{' '.join(attr_parts)}") parts.append(f"{' '.join(attr_parts)}")
description = "".join(parts) description = "".join(parts)
# 6. 添加元数据(可选) # 6. 添加元数据(可选)
if include_metadata: if include_metadata:
metadata_parts = [] metadata_parts = []
# 记忆类型 # 记忆类型
if memory.memory_type: if memory.memory_type:
metadata_parts.append(f"类型:{memory.memory_type.value}") metadata_parts.append(f"类型:{memory.memory_type.value}")
# 重要性 # 重要性
if memory.importance >= 0.8: if memory.importance >= 0.8:
metadata_parts.append("重要") metadata_parts.append("重要")
elif memory.importance >= 0.6: elif memory.importance >= 0.6:
metadata_parts.append("一般") metadata_parts.append("一般")
# 时间(如果没有在属性中) # 时间(如果没有在属性中)
if "时间" not in attributes: if "时间" not in attributes:
time_str = _format_relative_time(memory.created_at) time_str = _format_relative_time(memory.created_at)
if time_str: if time_str:
metadata_parts.append(time_str) metadata_parts.append(time_str)
if metadata_parts: if metadata_parts:
description += f" [{', '.join(metadata_parts)}]" description += f" [{', '.join(metadata_parts)}]"
return description return description
except Exception as e: except Exception as e:
logger.error(f"格式化记忆失败: {e}", exc_info=True) logger.error(f"格式化记忆失败: {e}", exc_info=True)
return f"(记忆格式化错误: {str(e)[:50]}" return f"(记忆格式化错误: {str(e)[:50]}"
def format_memories_for_prompt(
-    memories: List[Memory],
-    max_count: Optional[int] = None,
+    memories: list[Memory],
+    max_count: int | None = None,
    include_metadata: bool = False,
    group_by_type: bool = False
) -> str:
    """
    Format multiple memories into prompt text in one batch

    Args:
        memories: List of memories
        max_count: Maximum number of memories (optional)
        include_metadata: Whether to include metadata
        group_by_type: Whether to group by memory type

    Returns:
        Formatted text with a heading and a list
    """
    if not memories:
        return ""

    # Limit the count
    if max_count:
        memories = memories[:max_count]

    # Group by type
    if group_by_type:
-        type_groups: Dict[MemoryType, List[Memory]] = {}
+        type_groups: dict[MemoryType, list[Memory]] = {}
        for memory in memories:
            if memory.memory_type not in type_groups:
                type_groups[memory.memory_type] = []
            type_groups[memory.memory_type].append(memory)

        # Build the grouped text
        parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]

        type_order = [MemoryType.FACT, MemoryType.EVENT, MemoryType.RELATION, MemoryType.OPINION]
        for mem_type in type_order:
            if mem_type in type_groups:
@@ -186,33 +183,33 @@ def format_memories_for_prompt(
                desc = format_memory_for_prompt(memory, include_metadata)
                parts.append(f"- {desc}")
                parts.append("")

        return "\n".join(parts)
    else:
        # No grouping; list directly
        parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
        for memory in memories:
            # Get the type label
            type_label = memory.memory_type.value if memory.memory_type else "未知"
            # Format the memory content
            desc = format_memory_for_prompt(memory, include_metadata)
            # Prepend the type label
            parts.append(f"- **[{type_label}]** {desc}")

        return "\n".join(parts)
def get_memory_type_label(memory_type: str) -> str:
    """
    Get the Chinese label for a memory type

    Args:
        memory_type: Memory type (may be English or Chinese)

    Returns:
        Chinese label
    """
@@ -243,27 +240,27 @@ def get_memory_type_label(memory_type: str) -> str:
"经历": "经历", "经历": "经历",
"情境": "情境", "情境": "情境",
} }
# 转换为小写进行匹配 # 转换为小写进行匹配
memory_type_lower = memory_type.lower() if memory_type else "" memory_type_lower = memory_type.lower() if memory_type else ""
return type_mapping.get(memory_type_lower, "未知") return type_mapping.get(memory_type_lower, "未知")
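A lookup sketch (the English keys are assumed to be present in the full mapping, which the hunk elides):

    label = get_memory_type_label("FACT")  # lowercased to "fact" before the lookup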
-def _format_relative_time(timestamp: datetime) -> Optional[str]:
+def _format_relative_time(timestamp: datetime) -> str | None:
    """
    Format a relative time (e.g. "2天前", "刚才")

    Args:
        timestamp: Timestamp

    Returns:
        Relative-time description, or None if the time is too far in the past
    """
    try:
        now = datetime.now()
        delta = now - timestamp

        if delta.total_seconds() < 60:
            return "刚才"
        elif delta.total_seconds() < 3600:
@@ -290,17 +287,17 @@ def _format_relative_time(timestamp: datetime) -> Optional[str]:
def format_memory_summary(memory: Memory) -> str:
    """
    Build a short summary of a memory (for logging and debugging)

    Args:
        memory: Memory object

    Returns:
        Short summary
    """
    try:
        subject_node = memory.get_subject_node()
        subject_text = subject_node.content if subject_node else "?"

        topic_text = "?"
        for edge in memory.edges:
            if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
@@ -308,7 +305,7 @@ def format_memory_summary(memory: Memory) -> str:
                if topic_node:
                    topic_text = topic_node.content
                break

        return f"{subject_text} - {memory.memory_type.value if memory.memory_type else '?'}: {topic_text}"

    except Exception:
        return f"记忆 {memory.id[:8]}"
@@ -316,8 +313,8 @@ def format_memory_summary(memory: Memory) -> str:
# Export the main functions
__all__ = [
-    'format_memory_for_prompt',
-    'format_memories_for_prompt',
-    'get_memory_type_label',
-    'format_memory_summary',
+    "format_memories_for_prompt",
+    "format_memory_for_prompt",
+    "format_memory_summary",
+    "get_memory_type_label",
]
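Together with the `Dict`/`Optional` replacements above, the re-sorted, double-quoted `__all__` tracks the PEP 585/604 style this commit applies throughout. A minimal before/after sketch (the function `f` is hypothetical):

    # Before: from typing import Dict, List, Optional
    #   def f(items: List[str]) -> Optional[Dict[str, int]]: ...
    # After: built-in generics and union syntax (Python 3.10+,
    # or earlier with `from __future__ import annotations`):
    def f(items: list[str]) -> dict[str, int] | None:
        return {s: len(s) for s in items} or None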

View File

@@ -14,7 +14,6 @@ from __future__ import annotations
import re
from datetime import datetime, timedelta
-from typing import Optional, Tuple

from src.common.logger import get_logger
@@ -24,26 +23,26 @@ logger = get_logger(__name__)
class TimeParser:
    """
    Time parser

    Converts natural-language time expressions into normalized absolute times
    """

-    def __init__(self, reference_time: Optional[datetime] = None):
+    def __init__(self, reference_time: datetime | None = None):
        """
        Initialize the time parser

        Args:
            reference_time: Reference time (usually the current time)
        """
        self.reference_time = reference_time or datetime.now()
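A construction sketch with a pinned reference time, which keeps relative expressions deterministic in tests (the values are illustrative):

    from datetime import datetime

    parser = TimeParser(reference_time=datetime(2025, 11, 7, 12, 0))
    dt = parser.parse("昨天下午3点")  # expected 2025-11-06 15:00 under this reference time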
-    def parse(self, time_str: str) -> Optional[datetime]:
+    def parse(self, time_str: str) -> datetime | None:
        """
        Parse a time string

        Args:
            time_str: Time string

        Returns:
            Normalized datetime object, or None if parsing fails
        """
@@ -81,7 +80,7 @@ class TimeParser:
        logger.warning(f"无法解析时间: '{time_str}',使用当前时间")
        return self.reference_time

-    def _parse_relative_day(self, time_str: str) -> Optional[datetime]:
+    def _parse_relative_day(self, time_str: str) -> datetime | None:
        """
        Parse relative days: 今天, 明天, 昨天, 前天, 后天
        """
@@ -108,7 +107,7 @@ class TimeParser:
        return None

-    def _parse_days_ago(self, time_str: str) -> Optional[datetime]:
+    def _parse_days_ago(self, time_str: str) -> datetime | None:
        """
        Parse X天前/X天后, X周前/X周后, X个月前/X个月后
        """
@@ -172,7 +171,7 @@ class TimeParser:
        return None

-    def _parse_hours_ago(self, time_str: str) -> Optional[datetime]:
+    def _parse_hours_ago(self, time_str: str) -> datetime | None:
        """
        Parse X小时前/X小时后, X分钟前/X分钟后
        """
@@ -204,7 +203,7 @@ class TimeParser:
        return None

-    def _parse_week_month_year(self, time_str: str) -> Optional[datetime]:
+    def _parse_week_month_year(self, time_str: str) -> datetime | None:
        """
        Parse: 上周, 上个月, 去年, 本周, 本月, 今年
        """
@@ -232,7 +231,7 @@ class TimeParser:
        return None

-    def _parse_specific_date(self, time_str: str) -> Optional[datetime]:
+    def _parse_specific_date(self, time_str: str) -> datetime | None:
        """
        Parse specific dates:
        - 2025-11-05
@@ -266,7 +265,7 @@ class TimeParser:
        return None

-    def _parse_time_of_day(self, time_str: str) -> Optional[datetime]:
+    def _parse_time_of_day(self, time_str: str) -> datetime | None:
        """
        Parse times of day:
        - 早上, 上午, 中午, 下午, 晚上, 深夜
@@ -290,7 +289,7 @@ class TimeParser:
        }

        # First check for an explicit hour (早上8点, 下午3点)
-        for period, default_hour in time_periods.items():
+        for period in time_periods.keys():
            pattern = rf"{period}(\d{{1,2}})点?"
            match = re.search(pattern, time_str)
            if match:
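Dropping the unused `default_hour` binding silences the linter warning; since only the keys are needed here, iterating the dict directly would be equivalent (dicts iterate over their keys by default):

    for period in time_periods:
        ...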
@@ -314,13 +313,13 @@ class TimeParser:
        return None

-    def _parse_combined_time(self, time_str: str) -> Optional[datetime]:
+    def _parse_combined_time(self, time_str: str) -> datetime | None:
        """
        Parse combined time expressions: 今天下午, 昨天晚上, 明天早上
        """
        # Parse the date part first
        date_result = None

        # Relative-day keywords
        relative_days = {
            "今天": 0, "今日": 0,
@@ -330,16 +329,16 @@ class TimeParser:
"后天": 2, "后日": 2, "后天": 2, "后日": 2,
"大前天": -3, "大后天": 3, "大前天": -3, "大后天": 3,
} }
for keyword, days in relative_days.items(): for keyword, days in relative_days.items():
if keyword in time_str: if keyword in time_str:
date_result = self.reference_time + timedelta(days=days) date_result = self.reference_time + timedelta(days=days)
date_result = date_result.replace(hour=0, minute=0, second=0, microsecond=0) date_result = date_result.replace(hour=0, minute=0, second=0, microsecond=0)
break break
if not date_result: if not date_result:
return None return None
# 再解析时间段部分 # 再解析时间段部分
time_periods = { time_periods = {
"早上": 8, "早晨": 8, "早上": 8, "早晨": 8,
@@ -351,7 +350,7 @@ class TimeParser:
"深夜": 23, "深夜": 23,
"凌晨": 2, "凌晨": 2,
} }
for period, hour in time_periods.items(): for period, hour in time_periods.items():
if period in time_str: if period in time_str:
# 检查是否有具体时间点 # 检查是否有具体时间点
@@ -363,17 +362,17 @@ class TimeParser:
                if period in ["下午", "晚上"] and hour < 12:
                    hour += 12
                return date_result.replace(hour=hour)

        # No time-of-day part; return the date (defaults to 00:00)
        return date_result

    def _chinese_num_to_int(self, num_str: str) -> int:
        """
        Convert a Chinese numeral to an integer

        Args:
            num_str: Chinese numeral string (e.g. "一", "二", "3")

        Returns:
            Integer
        """
@@ -418,11 +417,11 @@ class TimeParser:
    def format_time(self, dt: datetime, format_type: str = "iso") -> str:
        """
        Format a time

        Args:
            dt: datetime object
            format_type: Format type ("iso", "cn", "relative")

        Returns:
            Formatted time string
        """
@@ -461,13 +460,13 @@ class TimeParser:
        return str(dt)

-    def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]:
+    def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]:
        """
        Parse a time range (最近一周, 最近3天)

        Args:
            time_str: Time-range string

        Returns:
            (start_time, end_time)
        """