fix: 修复代码质量问题 - 更正异常处理和导入语句
Co-authored-by: Windpicker-owo <221029311+Windpicker-owo@users.noreply.github.com>
This commit is contained in:
@@ -6,10 +6,12 @@ from typing import ClassVar
|
|||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system import BasePlugin, register_plugin
|
from src.plugin_system import BasePlugin, register_plugin
|
||||||
from src.plugin_system.base.component_types import ComponentInfo, ToolInfo
|
|
||||||
|
|
||||||
logger = get_logger("memory_graph_plugin")
|
logger = get_logger("memory_graph_plugin")
|
||||||
|
|
||||||
|
# 用于存储后台任务引用
|
||||||
|
_background_tasks = set()
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
@register_plugin
|
||||||
class MemoryGraphPlugin(BasePlugin):
|
class MemoryGraphPlugin(BasePlugin):
|
||||||
@@ -60,6 +62,7 @@ class MemoryGraphPlugin(BasePlugin):
|
|||||||
"""插件卸载时的回调"""
|
"""插件卸载时的回调"""
|
||||||
try:
|
try:
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from src.memory_graph.manager_singleton import shutdown_memory_manager
|
from src.memory_graph.manager_singleton import shutdown_memory_manager
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 正在关闭记忆系统...")
|
logger.info(f"{self.log_prefix} 正在关闭记忆系统...")
|
||||||
@@ -68,7 +71,10 @@ class MemoryGraphPlugin(BasePlugin):
|
|||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if loop.is_running():
|
if loop.is_running():
|
||||||
# 如果循环正在运行,创建任务
|
# 如果循环正在运行,创建任务
|
||||||
asyncio.create_task(shutdown_memory_manager())
|
task = asyncio.create_task(shutdown_memory_manager())
|
||||||
|
# 存储引用以防止任务被垃圾回收
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
else:
|
else:
|
||||||
# 如果循环未运行,直接运行
|
# 如果循环未运行,直接运行
|
||||||
loop.run_until_complete(shutdown_memory_manager())
|
loop.run_until_complete(shutdown_memory_manager())
|
||||||
|
|||||||
@@ -10,13 +10,13 @@
|
|||||||
使用方法:
|
使用方法:
|
||||||
# 预览模式(不实际删除)
|
# 预览模式(不实际删除)
|
||||||
python scripts/deduplicate_memories.py --dry-run
|
python scripts/deduplicate_memories.py --dry-run
|
||||||
|
|
||||||
# 执行去重
|
# 执行去重
|
||||||
python scripts/deduplicate_memories.py
|
python scripts/deduplicate_memories.py
|
||||||
|
|
||||||
# 指定相似度阈值
|
# 指定相似度阈值
|
||||||
python scripts/deduplicate_memories.py --threshold 0.9
|
python scripts/deduplicate_memories.py --threshold 0.9
|
||||||
|
|
||||||
# 指定数据目录
|
# 指定数据目录
|
||||||
python scripts/deduplicate_memories.py --data-dir data/memory_graph
|
python scripts/deduplicate_memories.py --data-dir data/memory_graph
|
||||||
"""
|
"""
|
||||||
@@ -25,27 +25,26 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager, shutdown_memory_manager
|
from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class MemoryDeduplicator:
|
class MemoryDeduplicator:
|
||||||
"""记忆去重器"""
|
"""记忆去重器"""
|
||||||
|
|
||||||
def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85):
|
def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85):
|
||||||
self.data_dir = data_dir
|
self.data_dir = data_dir
|
||||||
self.dry_run = dry_run
|
self.dry_run = dry_run
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.manager = None
|
self.manager = None
|
||||||
|
|
||||||
# 统计信息
|
# 统计信息
|
||||||
self.stats = {
|
self.stats = {
|
||||||
"total_memories": 0,
|
"total_memories": 0,
|
||||||
@@ -54,34 +53,34 @@ class MemoryDeduplicator:
|
|||||||
"duplicates_removed": 0,
|
"duplicates_removed": 0,
|
||||||
"errors": 0,
|
"errors": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化记忆管理器"""
|
"""初始化记忆管理器"""
|
||||||
logger.info(f"正在初始化记忆管理器 (data_dir={self.data_dir})...")
|
logger.info(f"正在初始化记忆管理器 (data_dir={self.data_dir})...")
|
||||||
self.manager = await initialize_memory_manager(data_dir=self.data_dir)
|
self.manager = await initialize_memory_manager(data_dir=self.data_dir)
|
||||||
if not self.manager:
|
if not self.manager:
|
||||||
raise RuntimeError("记忆管理器初始化失败")
|
raise RuntimeError("记忆管理器初始化失败")
|
||||||
|
|
||||||
self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
|
self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
|
||||||
logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆")
|
logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆")
|
||||||
|
|
||||||
async def find_similar_pairs(self) -> List[Tuple[str, str, float]]:
|
async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
|
||||||
"""
|
"""
|
||||||
查找所有相似的记忆对(通过向量相似度计算)
|
查找所有相似的记忆对(通过向量相似度计算)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[(memory_id_1, memory_id_2, similarity), ...]
|
[(memory_id_1, memory_id_2, similarity), ...]
|
||||||
"""
|
"""
|
||||||
logger.info("正在扫描相似记忆对...")
|
logger.info("正在扫描相似记忆对...")
|
||||||
similar_pairs = []
|
similar_pairs = []
|
||||||
seen_pairs = set() # 避免重复
|
seen_pairs = set() # 避免重复
|
||||||
|
|
||||||
# 获取所有记忆
|
# 获取所有记忆
|
||||||
all_memories = self.manager.graph_store.get_all_memories()
|
all_memories = self.manager.graph_store.get_all_memories()
|
||||||
total_memories = len(all_memories)
|
total_memories = len(all_memories)
|
||||||
|
|
||||||
logger.info(f"开始计算 {total_memories} 条记忆的相似度...")
|
logger.info(f"开始计算 {total_memories} 条记忆的相似度...")
|
||||||
|
|
||||||
# 两两比较记忆的相似度
|
# 两两比较记忆的相似度
|
||||||
for i, memory_i in enumerate(all_memories):
|
for i, memory_i in enumerate(all_memories):
|
||||||
# 每处理10条记忆让出控制权
|
# 每处理10条记忆让出控制权
|
||||||
@@ -89,115 +88,115 @@ class MemoryDeduplicator:
|
|||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
if i > 0:
|
if i > 0:
|
||||||
logger.info(f"进度: {i}/{total_memories} ({i*100//total_memories}%)")
|
logger.info(f"进度: {i}/{total_memories} ({i*100//total_memories}%)")
|
||||||
|
|
||||||
# 获取记忆i的向量(从主题节点)
|
# 获取记忆i的向量(从主题节点)
|
||||||
vector_i = None
|
vector_i = None
|
||||||
for node in memory_i.nodes:
|
for node in memory_i.nodes:
|
||||||
if node.embedding is not None:
|
if node.embedding is not None:
|
||||||
vector_i = node.embedding
|
vector_i = node.embedding
|
||||||
break
|
break
|
||||||
|
|
||||||
if vector_i is None:
|
if vector_i is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 与后续记忆比较
|
# 与后续记忆比较
|
||||||
for j in range(i + 1, total_memories):
|
for j in range(i + 1, total_memories):
|
||||||
memory_j = all_memories[j]
|
memory_j = all_memories[j]
|
||||||
|
|
||||||
# 获取记忆j的向量
|
# 获取记忆j的向量
|
||||||
vector_j = None
|
vector_j = None
|
||||||
for node in memory_j.nodes:
|
for node in memory_j.nodes:
|
||||||
if node.embedding is not None:
|
if node.embedding is not None:
|
||||||
vector_j = node.embedding
|
vector_j = node.embedding
|
||||||
break
|
break
|
||||||
|
|
||||||
if vector_j is None:
|
if vector_j is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算余弦相似度
|
# 计算余弦相似度
|
||||||
similarity = self._cosine_similarity(vector_i, vector_j)
|
similarity = self._cosine_similarity(vector_i, vector_j)
|
||||||
|
|
||||||
# 只保存满足阈值的相似对
|
# 只保存满足阈值的相似对
|
||||||
if similarity >= self.threshold:
|
if similarity >= self.threshold:
|
||||||
pair_key = tuple(sorted([memory_i.id, memory_j.id]))
|
pair_key = tuple(sorted([memory_i.id, memory_j.id]))
|
||||||
if pair_key not in seen_pairs:
|
if pair_key not in seen_pairs:
|
||||||
seen_pairs.add(pair_key)
|
seen_pairs.add(pair_key)
|
||||||
similar_pairs.append((memory_i.id, memory_j.id, similarity))
|
similar_pairs.append((memory_i.id, memory_j.id, similarity))
|
||||||
|
|
||||||
self.stats["similar_pairs"] = len(similar_pairs)
|
self.stats["similar_pairs"] = len(similar_pairs)
|
||||||
logger.info(f"找到 {len(similar_pairs)} 对相似记忆(阈值>={self.threshold})")
|
logger.info(f"找到 {len(similar_pairs)} 对相似记忆(阈值>={self.threshold})")
|
||||||
|
|
||||||
return similar_pairs
|
return similar_pairs
|
||||||
|
|
||||||
def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
|
||||||
"""计算余弦相似度"""
|
"""计算余弦相似度"""
|
||||||
try:
|
try:
|
||||||
vec1_norm = np.linalg.norm(vec1)
|
vec1_norm = np.linalg.norm(vec1)
|
||||||
vec2_norm = np.linalg.norm(vec2)
|
vec2_norm = np.linalg.norm(vec2)
|
||||||
|
|
||||||
if vec1_norm == 0 or vec2_norm == 0:
|
if vec1_norm == 0 or vec2_norm == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
|
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
|
||||||
return float(similarity)
|
return float(similarity)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"计算余弦相似度失败: {e}")
|
logger.error(f"计算余弦相似度失败: {e}")
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> Tuple[Optional[str], Optional[str]]:
|
def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
|
||||||
"""
|
"""
|
||||||
决定保留哪个记忆,删除哪个
|
决定保留哪个记忆,删除哪个
|
||||||
|
|
||||||
优先级:
|
优先级:
|
||||||
1. 重要性更高的
|
1. 重要性更高的
|
||||||
2. 激活度更高的
|
2. 激活度更高的
|
||||||
3. 创建时间更早的
|
3. 创建时间更早的
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(keep_id, remove_id)
|
(keep_id, remove_id)
|
||||||
"""
|
"""
|
||||||
mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1)
|
mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1)
|
||||||
mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2)
|
mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2)
|
||||||
|
|
||||||
if not mem1 or not mem2:
|
if not mem1 or not mem2:
|
||||||
logger.warning(f"记忆不存在: {mem_id_1} or {mem_id_2}")
|
logger.warning(f"记忆不存在: {mem_id_1} or {mem_id_2}")
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
# 比较重要性
|
# 比较重要性
|
||||||
if mem1.importance > mem2.importance:
|
if mem1.importance > mem2.importance:
|
||||||
return mem_id_1, mem_id_2
|
return mem_id_1, mem_id_2
|
||||||
elif mem1.importance < mem2.importance:
|
elif mem1.importance < mem2.importance:
|
||||||
return mem_id_2, mem_id_1
|
return mem_id_2, mem_id_1
|
||||||
|
|
||||||
# 重要性相同,比较激活度
|
# 重要性相同,比较激活度
|
||||||
if mem1.activation > mem2.activation:
|
if mem1.activation > mem2.activation:
|
||||||
return mem_id_1, mem_id_2
|
return mem_id_1, mem_id_2
|
||||||
elif mem1.activation < mem2.activation:
|
elif mem1.activation < mem2.activation:
|
||||||
return mem_id_2, mem_id_1
|
return mem_id_2, mem_id_1
|
||||||
|
|
||||||
# 激活度也相同,保留更早创建的
|
# 激活度也相同,保留更早创建的
|
||||||
if mem1.created_at < mem2.created_at:
|
if mem1.created_at < mem2.created_at:
|
||||||
return mem_id_1, mem_id_2
|
return mem_id_1, mem_id_2
|
||||||
else:
|
else:
|
||||||
return mem_id_2, mem_id_1
|
return mem_id_2, mem_id_1
|
||||||
|
|
||||||
async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool:
|
async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool:
|
||||||
"""
|
"""
|
||||||
去重一对相似记忆
|
去重一对相似记忆
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
是否成功去重
|
是否成功去重
|
||||||
"""
|
"""
|
||||||
keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2)
|
keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2)
|
||||||
|
|
||||||
if not keep_id or not remove_id:
|
if not keep_id or not remove_id:
|
||||||
self.stats["errors"] += 1
|
self.stats["errors"] += 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
|
keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
|
||||||
remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)
|
remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)
|
||||||
|
|
||||||
logger.info(f"")
|
logger.info("")
|
||||||
logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):")
|
logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):")
|
||||||
logger.info(f" 保留: {keep_id}")
|
logger.info(f" 保留: {keep_id}")
|
||||||
logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}")
|
logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}")
|
||||||
@@ -209,41 +208,41 @@ class MemoryDeduplicator:
|
|||||||
logger.info(f" - 重要性: {remove_mem.importance:.2f}")
|
logger.info(f" - 重要性: {remove_mem.importance:.2f}")
|
||||||
logger.info(f" - 激活度: {remove_mem.activation:.2f}")
|
logger.info(f" - 激活度: {remove_mem.activation:.2f}")
|
||||||
logger.info(f" - 创建时间: {remove_mem.created_at}")
|
logger.info(f" - 创建时间: {remove_mem.created_at}")
|
||||||
|
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
logger.info(" [预览模式] 不执行实际删除")
|
logger.info(" [预览模式] 不执行实际删除")
|
||||||
self.stats["duplicates_found"] += 1
|
self.stats["duplicates_found"] += 1
|
||||||
return True
|
return True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 增强保留记忆的属性
|
# 增强保留记忆的属性
|
||||||
keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
|
keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
|
||||||
keep_mem.activation = min(1.0, keep_mem.activation + 0.05)
|
keep_mem.activation = min(1.0, keep_mem.activation + 0.05)
|
||||||
|
|
||||||
# 累加访问次数
|
# 累加访问次数
|
||||||
if hasattr(keep_mem, 'access_count') and hasattr(remove_mem, 'access_count'):
|
if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
|
||||||
keep_mem.access_count += remove_mem.access_count
|
keep_mem.access_count += remove_mem.access_count
|
||||||
|
|
||||||
# 删除相似记忆
|
# 删除相似记忆
|
||||||
await self.manager.delete_memory(remove_id)
|
await self.manager.delete_memory(remove_id)
|
||||||
|
|
||||||
self.stats["duplicates_removed"] += 1
|
self.stats["duplicates_removed"] += 1
|
||||||
logger.info(f" ✅ 删除成功")
|
logger.info(" ✅ 删除成功")
|
||||||
|
|
||||||
# 让出控制权
|
# 让出控制权
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f" ❌ 删除失败: {e}", exc_info=True)
|
logger.error(f" ❌ 删除失败: {e}", exc_info=True)
|
||||||
self.stats["errors"] += 1
|
self.stats["errors"] += 1
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""执行去重"""
|
"""执行去重"""
|
||||||
start_time = datetime.now()
|
start_time = datetime.now()
|
||||||
|
|
||||||
print("="*70)
|
print("="*70)
|
||||||
print("记忆去重工具")
|
print("记忆去重工具")
|
||||||
print("="*70)
|
print("="*70)
|
||||||
@@ -252,13 +251,13 @@ class MemoryDeduplicator:
|
|||||||
print(f"模式: {'预览模式(不实际删除)' if self.dry_run else '执行模式(会实际删除)'}")
|
print(f"模式: {'预览模式(不实际删除)' if self.dry_run else '执行模式(会实际删除)'}")
|
||||||
print("="*70)
|
print("="*70)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# 初始化
|
# 初始化
|
||||||
await self.initialize()
|
await self.initialize()
|
||||||
|
|
||||||
# 查找相似对
|
# 查找相似对
|
||||||
similar_pairs = await self.find_similar_pairs()
|
similar_pairs = await self.find_similar_pairs()
|
||||||
|
|
||||||
if not similar_pairs:
|
if not similar_pairs:
|
||||||
logger.info("未找到需要去重的相似记忆对")
|
logger.info("未找到需要去重的相似记忆对")
|
||||||
print()
|
print()
|
||||||
@@ -266,19 +265,19 @@ class MemoryDeduplicator:
|
|||||||
print("未找到需要去重的记忆")
|
print("未找到需要去重的记忆")
|
||||||
print("="*70)
|
print("="*70)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 去重处理
|
# 去重处理
|
||||||
logger.info(f"开始{'预览' if self.dry_run else '执行'}去重...")
|
logger.info(f"开始{'预览' if self.dry_run else '执行'}去重...")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
processed_pairs = set() # 避免重复处理
|
processed_pairs = set() # 避免重复处理
|
||||||
|
|
||||||
for mem_id_1, mem_id_2, similarity in similar_pairs:
|
for mem_id_1, mem_id_2, similarity in similar_pairs:
|
||||||
# 检查是否已处理(可能一个记忆已被删除)
|
# 检查是否已处理(可能一个记忆已被删除)
|
||||||
pair_key = tuple(sorted([mem_id_1, mem_id_2]))
|
pair_key = tuple(sorted([mem_id_1, mem_id_2]))
|
||||||
if pair_key in processed_pairs:
|
if pair_key in processed_pairs:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查记忆是否仍存在
|
# 检查记忆是否仍存在
|
||||||
if not self.manager.graph_store.get_memory_by_id(mem_id_1):
|
if not self.manager.graph_store.get_memory_by_id(mem_id_1):
|
||||||
logger.debug(f"记忆 {mem_id_1} 已不存在,跳过")
|
logger.debug(f"记忆 {mem_id_1} 已不存在,跳过")
|
||||||
@@ -286,22 +285,22 @@ class MemoryDeduplicator:
|
|||||||
if not self.manager.graph_store.get_memory_by_id(mem_id_2):
|
if not self.manager.graph_store.get_memory_by_id(mem_id_2):
|
||||||
logger.debug(f"记忆 {mem_id_2} 已不存在,跳过")
|
logger.debug(f"记忆 {mem_id_2} 已不存在,跳过")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 执行去重
|
# 执行去重
|
||||||
success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity)
|
success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
processed_pairs.add(pair_key)
|
processed_pairs.add(pair_key)
|
||||||
|
|
||||||
# 保存数据(如果不是干运行)
|
# 保存数据(如果不是干运行)
|
||||||
if not self.dry_run:
|
if not self.dry_run:
|
||||||
logger.info("正在保存数据...")
|
logger.info("正在保存数据...")
|
||||||
await self.manager.persistence.save_graph_store(self.manager.graph_store)
|
await self.manager.persistence.save_graph_store(self.manager.graph_store)
|
||||||
logger.info("✅ 数据已保存")
|
logger.info("✅ 数据已保存")
|
||||||
|
|
||||||
# 统计报告
|
# 统计报告
|
||||||
elapsed = (datetime.now() - start_time).total_seconds()
|
elapsed = (datetime.now() - start_time).total_seconds()
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("="*70)
|
print("="*70)
|
||||||
print("去重报告")
|
print("去重报告")
|
||||||
@@ -312,7 +311,7 @@ class MemoryDeduplicator:
|
|||||||
print(f"{'预览通过' if self.dry_run else '成功删除'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
|
print(f"{'预览通过' if self.dry_run else '成功删除'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
|
||||||
print(f"错误数: {self.stats['errors']}")
|
print(f"错误数: {self.stats['errors']}")
|
||||||
print(f"耗时: {elapsed:.2f}秒")
|
print(f"耗时: {elapsed:.2f}秒")
|
||||||
|
|
||||||
if self.dry_run:
|
if self.dry_run:
|
||||||
print()
|
print()
|
||||||
print("⚠️ 这是预览模式,未实际删除任何记忆")
|
print("⚠️ 这是预览模式,未实际删除任何记忆")
|
||||||
@@ -322,9 +321,9 @@ class MemoryDeduplicator:
|
|||||||
print("✅ 去重完成!")
|
print("✅ 去重完成!")
|
||||||
final_count = len(self.manager.graph_store.get_all_memories())
|
final_count = len(self.manager.graph_store.get_all_memories())
|
||||||
print(f"📊 最终记忆数: {final_count} (减少 {self.stats['total_memories'] - final_count} 条)")
|
print(f"📊 最终记忆数: {final_count} (减少 {self.stats['total_memories'] - final_count} 条)")
|
||||||
|
|
||||||
print("="*70)
|
print("="*70)
|
||||||
|
|
||||||
async def cleanup(self):
|
async def cleanup(self):
|
||||||
"""清理资源"""
|
"""清理资源"""
|
||||||
if self.manager:
|
if self.manager:
|
||||||
@@ -340,50 +339,50 @@ async def main():
|
|||||||
示例:
|
示例:
|
||||||
# 预览模式(推荐先运行)
|
# 预览模式(推荐先运行)
|
||||||
python scripts/deduplicate_memories.py --dry-run
|
python scripts/deduplicate_memories.py --dry-run
|
||||||
|
|
||||||
# 执行去重
|
# 执行去重
|
||||||
python scripts/deduplicate_memories.py
|
python scripts/deduplicate_memories.py
|
||||||
|
|
||||||
# 指定相似度阈值(只处理相似度>=0.9的记忆对)
|
# 指定相似度阈值(只处理相似度>=0.9的记忆对)
|
||||||
python scripts/deduplicate_memories.py --threshold 0.9
|
python scripts/deduplicate_memories.py --threshold 0.9
|
||||||
|
|
||||||
# 指定数据目录
|
# 指定数据目录
|
||||||
python scripts/deduplicate_memories.py --data-dir data/memory_graph
|
python scripts/deduplicate_memories.py --data-dir data/memory_graph
|
||||||
|
|
||||||
# 组合使用
|
# 组合使用
|
||||||
python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test
|
python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--dry-run",
|
"--dry-run",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="预览模式,不实际删除记忆(推荐先运行此模式)"
|
help="预览模式,不实际删除记忆(推荐先运行此模式)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--threshold",
|
"--threshold",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.85,
|
default=0.85,
|
||||||
help="相似度阈值,只处理相似度>=此值的记忆对(默认: 0.85)"
|
help="相似度阈值,只处理相似度>=此值的记忆对(默认: 0.85)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data-dir",
|
"--data-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="data/memory_graph",
|
default="data/memory_graph",
|
||||||
help="记忆数据目录(默认: data/memory_graph)"
|
help="记忆数据目录(默认: data/memory_graph)"
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# 创建去重器
|
# 创建去重器
|
||||||
deduplicator = MemoryDeduplicator(
|
deduplicator = MemoryDeduplicator(
|
||||||
data_dir=args.data_dir,
|
data_dir=args.data_dir,
|
||||||
dry_run=args.dry_run,
|
dry_run=args.dry_run,
|
||||||
threshold=args.threshold
|
threshold=args.threshold
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 执行去重
|
# 执行去重
|
||||||
await deduplicator.run()
|
await deduplicator.run()
|
||||||
@@ -396,7 +395,7 @@ async def main():
|
|||||||
finally:
|
finally:
|
||||||
# 清理资源
|
# 清理资源
|
||||||
await deduplicator.cleanup()
|
await deduplicator.cleanup()
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,24 +6,24 @@
|
|||||||
|
|
||||||
from src.memory_graph.manager import MemoryManager
|
from src.memory_graph.manager import MemoryManager
|
||||||
from src.memory_graph.models import (
|
from src.memory_graph.models import (
|
||||||
|
EdgeType,
|
||||||
Memory,
|
Memory,
|
||||||
MemoryEdge,
|
MemoryEdge,
|
||||||
MemoryNode,
|
MemoryNode,
|
||||||
MemoryStatus,
|
MemoryStatus,
|
||||||
MemoryType,
|
MemoryType,
|
||||||
NodeType,
|
NodeType,
|
||||||
EdgeType,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MemoryManager",
|
"EdgeType",
|
||||||
"Memory",
|
"Memory",
|
||||||
"MemoryNode",
|
|
||||||
"MemoryEdge",
|
"MemoryEdge",
|
||||||
|
"MemoryManager",
|
||||||
|
"MemoryNode",
|
||||||
|
"MemoryStatus",
|
||||||
"MemoryType",
|
"MemoryType",
|
||||||
"NodeType",
|
"NodeType",
|
||||||
"EdgeType",
|
|
||||||
"MemoryStatus",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|||||||
@@ -6,4 +6,4 @@ from src.memory_graph.core.builder import MemoryBuilder
|
|||||||
from src.memory_graph.core.extractor import MemoryExtractor
|
from src.memory_graph.core.extractor import MemoryExtractor
|
||||||
from src.memory_graph.core.node_merger import NodeMerger
|
from src.memory_graph.core.node_merger import NodeMerger
|
||||||
|
|
||||||
__all__ = ["NodeMerger", "MemoryExtractor", "MemoryBuilder"]
|
__all__ = ["MemoryBuilder", "MemoryExtractor", "NodeMerger"]
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -16,7 +16,6 @@ from src.memory_graph.models import (
|
|||||||
MemoryEdge,
|
MemoryEdge,
|
||||||
MemoryNode,
|
MemoryNode,
|
||||||
MemoryStatus,
|
MemoryStatus,
|
||||||
MemoryType,
|
|
||||||
NodeType,
|
NodeType,
|
||||||
)
|
)
|
||||||
from src.memory_graph.storage.graph_store import GraphStore
|
from src.memory_graph.storage.graph_store import GraphStore
|
||||||
@@ -28,7 +27,7 @@ logger = get_logger(__name__)
|
|||||||
class MemoryBuilder:
|
class MemoryBuilder:
|
||||||
"""
|
"""
|
||||||
记忆构建器
|
记忆构建器
|
||||||
|
|
||||||
负责:
|
负责:
|
||||||
1. 根据提取的元素自动构造记忆子图
|
1. 根据提取的元素自动构造记忆子图
|
||||||
2. 创建节点和边的完整结构
|
2. 创建节点和边的完整结构
|
||||||
@@ -41,11 +40,11 @@ class MemoryBuilder:
|
|||||||
self,
|
self,
|
||||||
vector_store: VectorStore,
|
vector_store: VectorStore,
|
||||||
graph_store: GraphStore,
|
graph_store: GraphStore,
|
||||||
embedding_generator: Optional[Any] = None,
|
embedding_generator: Any | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化记忆构建器
|
初始化记忆构建器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector_store: 向量存储
|
vector_store: 向量存储
|
||||||
graph_store: 图存储
|
graph_store: 图存储
|
||||||
@@ -55,13 +54,13 @@ class MemoryBuilder:
|
|||||||
self.graph_store = graph_store
|
self.graph_store = graph_store
|
||||||
self.embedding_generator = embedding_generator
|
self.embedding_generator = embedding_generator
|
||||||
|
|
||||||
async def build_memory(self, extracted_params: Dict[str, Any]) -> Memory:
|
async def build_memory(self, extracted_params: dict[str, Any]) -> Memory:
|
||||||
"""
|
"""
|
||||||
构建完整的记忆对象
|
构建完整的记忆对象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
extracted_params: 提取器返回的标准化参数
|
extracted_params: 提取器返回的标准化参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Memory 对象(状态为 STAGED)
|
Memory 对象(状态为 STAGED)
|
||||||
"""
|
"""
|
||||||
@@ -97,7 +96,7 @@ class MemoryBuilder:
|
|||||||
edges.append(memory_type_edge)
|
edges.append(memory_type_edge)
|
||||||
|
|
||||||
# 4. 如果有客体,创建客体节点并连接
|
# 4. 如果有客体,创建客体节点并连接
|
||||||
if "object" in extracted_params and extracted_params["object"]:
|
if extracted_params.get("object"):
|
||||||
object_node = await self._create_object_node(
|
object_node = await self._create_object_node(
|
||||||
content=extracted_params["object"], memory_id=memory_id
|
content=extracted_params["object"], memory_id=memory_id
|
||||||
)
|
)
|
||||||
@@ -158,14 +157,14 @@ class MemoryBuilder:
|
|||||||
) -> MemoryNode:
|
) -> MemoryNode:
|
||||||
"""
|
"""
|
||||||
创建新节点或复用已存在的相似节点
|
创建新节点或复用已存在的相似节点
|
||||||
|
|
||||||
对于主体(SUBJECT)和属性(ATTRIBUTE),检查是否已存在相同内容的节点
|
对于主体(SUBJECT)和属性(ATTRIBUTE),检查是否已存在相同内容的节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 节点内容
|
content: 节点内容
|
||||||
node_type: 节点类型
|
node_type: 节点类型
|
||||||
memory_id: 所属记忆ID
|
memory_id: 所属记忆ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MemoryNode 对象
|
MemoryNode 对象
|
||||||
"""
|
"""
|
||||||
@@ -190,11 +189,11 @@ class MemoryBuilder:
|
|||||||
async def _create_topic_node(self, content: str, memory_id: str) -> MemoryNode:
|
async def _create_topic_node(self, content: str, memory_id: str) -> MemoryNode:
|
||||||
"""
|
"""
|
||||||
创建主题节点(需要生成嵌入向量)
|
创建主题节点(需要生成嵌入向量)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 节点内容
|
content: 节点内容
|
||||||
memory_id: 所属记忆ID
|
memory_id: 所属记忆ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MemoryNode 对象
|
MemoryNode 对象
|
||||||
"""
|
"""
|
||||||
@@ -225,11 +224,11 @@ class MemoryBuilder:
|
|||||||
async def _create_object_node(self, content: str, memory_id: str) -> MemoryNode:
|
async def _create_object_node(self, content: str, memory_id: str) -> MemoryNode:
|
||||||
"""
|
"""
|
||||||
创建客体节点(需要生成嵌入向量)
|
创建客体节点(需要生成嵌入向量)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 节点内容
|
content: 节点内容
|
||||||
memory_id: 所属记忆ID
|
memory_id: 所属记忆ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MemoryNode 对象
|
MemoryNode 对象
|
||||||
"""
|
"""
|
||||||
@@ -258,22 +257,22 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
async def _process_attributes(
|
async def _process_attributes(
|
||||||
self,
|
self,
|
||||||
attributes: Dict[str, Any],
|
attributes: dict[str, Any],
|
||||||
parent_id: str,
|
parent_id: str,
|
||||||
memory_id: str,
|
memory_id: str,
|
||||||
importance: float,
|
importance: float,
|
||||||
) -> tuple[List[MemoryNode], List[MemoryEdge]]:
|
) -> tuple[list[MemoryNode], list[MemoryEdge]]:
|
||||||
"""
|
"""
|
||||||
处理属性,构建属性子图
|
处理属性,构建属性子图
|
||||||
|
|
||||||
结构:TOPIC -> ATTRIBUTE -> VALUE
|
结构:TOPIC -> ATTRIBUTE -> VALUE
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
attributes: 属性字典
|
attributes: 属性字典
|
||||||
parent_id: 父节点ID(通常是TOPIC)
|
parent_id: 父节点ID(通常是TOPIC)
|
||||||
memory_id: 所属记忆ID
|
memory_id: 所属记忆ID
|
||||||
importance: 重要性
|
importance: 重要性
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(属性节点列表, 属性边列表)
|
(属性节点列表, 属性边列表)
|
||||||
"""
|
"""
|
||||||
@@ -322,10 +321,10 @@ class MemoryBuilder:
|
|||||||
async def _generate_embedding(self, text: str) -> np.ndarray:
|
async def _generate_embedding(self, text: str) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
生成文本的嵌入向量
|
生成文本的嵌入向量
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 文本内容
|
text: 文本内容
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
嵌入向量
|
嵌入向量
|
||||||
"""
|
"""
|
||||||
@@ -341,14 +340,14 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
async def _find_existing_node(
|
async def _find_existing_node(
|
||||||
self, content: str, node_type: NodeType
|
self, content: str, node_type: NodeType
|
||||||
) -> Optional[MemoryNode]:
|
) -> MemoryNode | None:
|
||||||
"""
|
"""
|
||||||
查找已存在的完全匹配节点(用于主体和属性)
|
查找已存在的完全匹配节点(用于主体和属性)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 节点内容
|
content: 节点内容
|
||||||
node_type: 节点类型
|
node_type: 节点类型
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
已存在的节点,如果没有则返回 None
|
已存在的节点,如果没有则返回 None
|
||||||
"""
|
"""
|
||||||
@@ -369,14 +368,14 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
async def _find_similar_topic(
|
async def _find_similar_topic(
|
||||||
self, content: str, embedding: np.ndarray
|
self, content: str, embedding: np.ndarray
|
||||||
) -> Optional[MemoryNode]:
|
) -> MemoryNode | None:
|
||||||
"""
|
"""
|
||||||
查找相似的主题节点(基于语义相似度)
|
查找相似的主题节点(基于语义相似度)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 内容
|
content: 内容
|
||||||
embedding: 嵌入向量
|
embedding: 嵌入向量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相似节点,如果没有则返回 None
|
相似节点,如果没有则返回 None
|
||||||
"""
|
"""
|
||||||
@@ -414,14 +413,14 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
async def _find_similar_object(
|
async def _find_similar_object(
|
||||||
self, content: str, embedding: np.ndarray
|
self, content: str, embedding: np.ndarray
|
||||||
) -> Optional[MemoryNode]:
|
) -> MemoryNode | None:
|
||||||
"""
|
"""
|
||||||
查找相似的客体节点(基于语义相似度)
|
查找相似的客体节点(基于语义相似度)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 内容
|
content: 内容
|
||||||
embedding: 嵌入向量
|
embedding: 嵌入向量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相似节点,如果没有则返回 None
|
相似节点,如果没有则返回 None
|
||||||
"""
|
"""
|
||||||
@@ -480,13 +479,13 @@ class MemoryBuilder:
|
|||||||
) -> MemoryEdge:
|
) -> MemoryEdge:
|
||||||
"""
|
"""
|
||||||
关联两个记忆(创建因果或引用边)
|
关联两个记忆(创建因果或引用边)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_memory: 源记忆
|
source_memory: 源记忆
|
||||||
target_memory: 目标记忆
|
target_memory: 目标记忆
|
||||||
relation_type: 关系类型(如 "导致", "引用")
|
relation_type: 关系类型(如 "导致", "引用")
|
||||||
importance: 重要性
|
importance: 重要性
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
创建的边
|
创建的边
|
||||||
"""
|
"""
|
||||||
@@ -525,7 +524,7 @@ class MemoryBuilder:
|
|||||||
logger.error(f"记忆关联失败: {e}", exc_info=True)
|
logger.error(f"记忆关联失败: {e}", exc_info=True)
|
||||||
raise RuntimeError(f"记忆关联失败: {e}")
|
raise RuntimeError(f"记忆关联失败: {e}")
|
||||||
|
|
||||||
def _find_topic_node(self, memory: Memory) -> Optional[MemoryNode]:
|
def _find_topic_node(self, memory: Memory) -> MemoryNode | None:
|
||||||
"""查找记忆中的主题节点"""
|
"""查找记忆中的主题节点"""
|
||||||
for node in memory.nodes:
|
for node in memory.nodes:
|
||||||
if node.node_type == NodeType.TOPIC:
|
if node.node_type == NodeType.TOPIC:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.memory_graph.models import MemoryType
|
from src.memory_graph.models import MemoryType
|
||||||
@@ -17,7 +17,7 @@ logger = get_logger(__name__)
|
|||||||
class MemoryExtractor:
|
class MemoryExtractor:
|
||||||
"""
|
"""
|
||||||
记忆提取器
|
记忆提取器
|
||||||
|
|
||||||
负责:
|
负责:
|
||||||
1. 从工具调用参数中提取记忆元素
|
1. 从工具调用参数中提取记忆元素
|
||||||
2. 验证参数完整性和有效性
|
2. 验证参数完整性和有效性
|
||||||
@@ -25,19 +25,19 @@ class MemoryExtractor:
|
|||||||
4. 清洗和格式化数据
|
4. 清洗和格式化数据
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, time_parser: Optional[TimeParser] = None):
|
def __init__(self, time_parser: TimeParser | None = None):
|
||||||
"""
|
"""
|
||||||
初始化记忆提取器
|
初始化记忆提取器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
time_parser: 时间解析器(可选)
|
time_parser: 时间解析器(可选)
|
||||||
"""
|
"""
|
||||||
self.time_parser = time_parser or TimeParser()
|
self.time_parser = time_parser or TimeParser()
|
||||||
|
|
||||||
def extract_from_tool_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
def extract_from_tool_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
从工具参数中提取记忆元素
|
从工具参数中提取记忆元素
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: 工具调用参数,例如:
|
params: 工具调用参数,例如:
|
||||||
{
|
{
|
||||||
@@ -48,7 +48,7 @@ class MemoryExtractor:
|
|||||||
"attributes": {"时间": "今天", "地点": "家里"},
|
"attributes": {"时间": "今天", "地点": "家里"},
|
||||||
"importance": 0.3
|
"importance": 0.3
|
||||||
}
|
}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
提取和标准化后的参数字典
|
提取和标准化后的参数字典
|
||||||
"""
|
"""
|
||||||
@@ -64,11 +64,11 @@ class MemoryExtractor:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 3. 提取可选的客体
|
# 3. 提取可选的客体
|
||||||
if "object" in params and params["object"]:
|
if params.get("object"):
|
||||||
extracted["object"] = self._clean_text(params["object"])
|
extracted["object"] = self._clean_text(params["object"])
|
||||||
|
|
||||||
# 4. 提取和标准化属性
|
# 4. 提取和标准化属性
|
||||||
if "attributes" in params and params["attributes"]:
|
if params.get("attributes"):
|
||||||
extracted["attributes"] = self._process_attributes(params["attributes"])
|
extracted["attributes"] = self._process_attributes(params["attributes"])
|
||||||
else:
|
else:
|
||||||
extracted["attributes"] = {}
|
extracted["attributes"] = {}
|
||||||
@@ -86,13 +86,13 @@ class MemoryExtractor:
|
|||||||
logger.error(f"记忆提取失败: {e}", exc_info=True)
|
logger.error(f"记忆提取失败: {e}", exc_info=True)
|
||||||
raise ValueError(f"记忆提取失败: {e}")
|
raise ValueError(f"记忆提取失败: {e}")
|
||||||
|
|
||||||
def _validate_required_params(self, params: Dict[str, Any]) -> None:
|
def _validate_required_params(self, params: dict[str, Any]) -> None:
|
||||||
"""
|
"""
|
||||||
验证必需参数
|
验证必需参数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: 参数字典
|
params: 参数字典
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 如果缺少必需参数
|
ValueError: 如果缺少必需参数
|
||||||
"""
|
"""
|
||||||
@@ -105,10 +105,10 @@ class MemoryExtractor:
|
|||||||
def _clean_text(self, text: Any) -> str:
|
def _clean_text(self, text: Any) -> str:
|
||||||
"""
|
"""
|
||||||
清洗文本
|
清洗文本
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清洗后的文本
|
清洗后的文本
|
||||||
"""
|
"""
|
||||||
@@ -128,13 +128,13 @@ class MemoryExtractor:
|
|||||||
def _parse_memory_type(self, type_str: str) -> MemoryType:
|
def _parse_memory_type(self, type_str: str) -> MemoryType:
|
||||||
"""
|
"""
|
||||||
解析记忆类型
|
解析记忆类型
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type_str: 类型字符串
|
type_str: 类型字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MemoryType 枚举
|
MemoryType 枚举
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 如果类型无效
|
ValueError: 如果类型无效
|
||||||
"""
|
"""
|
||||||
@@ -166,10 +166,10 @@ class MemoryExtractor:
|
|||||||
def _parse_importance(self, importance: Any) -> float:
|
def _parse_importance(self, importance: Any) -> float:
|
||||||
"""
|
"""
|
||||||
解析重要性值
|
解析重要性值
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
importance: 重要性值(可以是数字、字符串等)
|
importance: 重要性值(可以是数字、字符串等)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
0-1之间的浮点数
|
0-1之间的浮点数
|
||||||
"""
|
"""
|
||||||
@@ -181,13 +181,13 @@ class MemoryExtractor:
|
|||||||
logger.warning(f"无效的重要性值: {importance},使用默认值 0.5")
|
logger.warning(f"无效的重要性值: {importance},使用默认值 0.5")
|
||||||
return 0.5
|
return 0.5
|
||||||
|
|
||||||
def _process_attributes(self, attributes: Dict[str, Any]) -> Dict[str, Any]:
|
def _process_attributes(self, attributes: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
处理属性字典
|
处理属性字典
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
attributes: 原始属性字典
|
attributes: 原始属性字典
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
处理后的属性字典
|
处理后的属性字典
|
||||||
"""
|
"""
|
||||||
@@ -222,10 +222,10 @@ class MemoryExtractor:
|
|||||||
|
|
||||||
return processed
|
return processed
|
||||||
|
|
||||||
def extract_link_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
def extract_link_params(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
提取记忆关联参数(用于 link_memories 工具)
|
提取记忆关联参数(用于 link_memories 工具)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
params: 工具参数,例如:
|
params: 工具参数,例如:
|
||||||
{
|
{
|
||||||
@@ -234,7 +234,7 @@ class MemoryExtractor:
|
|||||||
"relation_type": "导致",
|
"relation_type": "导致",
|
||||||
"importance": 0.6
|
"importance": 0.6
|
||||||
}
|
}
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
提取后的参数
|
提取后的参数
|
||||||
"""
|
"""
|
||||||
@@ -266,10 +266,10 @@ class MemoryExtractor:
|
|||||||
def validate_relation_type(self, relation_type: str) -> str:
|
def validate_relation_type(self, relation_type: str) -> str:
|
||||||
"""
|
"""
|
||||||
验证关系类型
|
验证关系类型
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
relation_type: 关系类型字符串
|
relation_type: 关系类型字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
标准化的关系类型
|
标准化的关系类型
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,11 +4,6 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.official_configs import MemoryConfig
|
from src.config.official_configs import MemoryConfig
|
||||||
from src.memory_graph.models import MemoryNode, NodeType
|
from src.memory_graph.models import MemoryNode, NodeType
|
||||||
@@ -21,7 +16,7 @@ logger = get_logger(__name__)
|
|||||||
class NodeMerger:
|
class NodeMerger:
|
||||||
"""
|
"""
|
||||||
节点合并器
|
节点合并器
|
||||||
|
|
||||||
负责:
|
负责:
|
||||||
1. 基于语义相似度查找重复节点
|
1. 基于语义相似度查找重复节点
|
||||||
2. 验证上下文匹配
|
2. 验证上下文匹配
|
||||||
@@ -36,7 +31,7 @@ class NodeMerger:
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化节点合并器
|
初始化节点合并器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vector_store: 向量存储
|
vector_store: 向量存储
|
||||||
graph_store: 图存储
|
graph_store: 图存储
|
||||||
@@ -54,17 +49,17 @@ class NodeMerger:
|
|||||||
async def find_similar_nodes(
|
async def find_similar_nodes(
|
||||||
self,
|
self,
|
||||||
node: MemoryNode,
|
node: MemoryNode,
|
||||||
threshold: Optional[float] = None,
|
threshold: float | None = None,
|
||||||
limit: int = 5,
|
limit: int = 5,
|
||||||
) -> List[Tuple[MemoryNode, float]]:
|
) -> list[tuple[MemoryNode, float]]:
|
||||||
"""
|
"""
|
||||||
查找与指定节点相似的节点
|
查找与指定节点相似的节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node: 查询节点
|
node: 查询节点
|
||||||
threshold: 相似度阈值(可选,默认使用配置值)
|
threshold: 相似度阈值(可选,默认使用配置值)
|
||||||
limit: 返回结果数量
|
limit: 返回结果数量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (similar_node, similarity)
|
List of (similar_node, similarity)
|
||||||
"""
|
"""
|
||||||
@@ -112,12 +107,12 @@ class NodeMerger:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
判断两个节点是否应该合并
|
判断两个节点是否应该合并
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_node: 源节点
|
source_node: 源节点
|
||||||
target_node: 目标节点
|
target_node: 目标节点
|
||||||
similarity: 语义相似度
|
similarity: 语义相似度
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
是否应该合并
|
是否应该合并
|
||||||
"""
|
"""
|
||||||
@@ -157,16 +152,16 @@ class NodeMerger:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
检查两个节点的上下文是否匹配
|
检查两个节点的上下文是否匹配
|
||||||
|
|
||||||
上下文匹配的标准:
|
上下文匹配的标准:
|
||||||
1. 节点类型相同
|
1. 节点类型相同
|
||||||
2. 邻居节点有重叠
|
2. 邻居节点有重叠
|
||||||
3. 邻居节点的内容相似
|
3. 邻居节点的内容相似
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_node: 源节点
|
source_node: 源节点
|
||||||
target_node: 目标节点
|
target_node: 目标节点
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
是否匹配
|
是否匹配
|
||||||
"""
|
"""
|
||||||
@@ -207,7 +202,7 @@ class NodeMerger:
|
|||||||
# 如果有 30% 以上的邻居重叠,认为上下文匹配
|
# 如果有 30% 以上的邻居重叠,认为上下文匹配
|
||||||
return overlap_ratio > 0.3
|
return overlap_ratio > 0.3
|
||||||
|
|
||||||
def _get_node_content(self, node_id: str) -> Optional[str]:
|
def _get_node_content(self, node_id: str) -> str | None:
|
||||||
"""获取节点的内容"""
|
"""获取节点的内容"""
|
||||||
memories = self.graph_store.get_memories_by_node(node_id)
|
memories = self.graph_store.get_memories_by_node(node_id)
|
||||||
if memories:
|
if memories:
|
||||||
@@ -223,13 +218,13 @@ class NodeMerger:
|
|||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
合并两个节点
|
合并两个节点
|
||||||
|
|
||||||
将 source 节点的所有边转移到 target 节点,然后删除 source
|
将 source 节点的所有边转移到 target 节点,然后删除 source
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: 源节点(将被删除)
|
source: 源节点(将被删除)
|
||||||
target: 目标节点(保留)
|
target: 目标节点(保留)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
是否成功
|
是否成功
|
||||||
"""
|
"""
|
||||||
@@ -255,7 +250,7 @@ class NodeMerger:
|
|||||||
def _update_memory_references(self, old_node_id: str, new_node_id: str) -> None:
|
def _update_memory_references(self, old_node_id: str, new_node_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
更新记忆中的节点引用
|
更新记忆中的节点引用
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
old_node_id: 旧节点ID
|
old_node_id: 旧节点ID
|
||||||
new_node_id: 新节点ID
|
new_node_id: 新节点ID
|
||||||
@@ -280,16 +275,16 @@ class NodeMerger:
|
|||||||
|
|
||||||
async def batch_merge_similar_nodes(
|
async def batch_merge_similar_nodes(
|
||||||
self,
|
self,
|
||||||
nodes: List[MemoryNode],
|
nodes: list[MemoryNode],
|
||||||
progress_callback: Optional[callable] = None,
|
progress_callback: callable | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
批量处理节点合并
|
批量处理节点合并
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodes: 要处理的节点列表
|
nodes: 要处理的节点列表
|
||||||
progress_callback: 进度回调函数
|
progress_callback: 进度回调函数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
统计信息字典
|
统计信息字典
|
||||||
"""
|
"""
|
||||||
@@ -344,14 +339,14 @@ class NodeMerger:
|
|||||||
self,
|
self,
|
||||||
min_similarity: float = 0.85,
|
min_similarity: float = 0.85,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
) -> List[Tuple[str, str, float]]:
|
) -> list[tuple[str, str, float]]:
|
||||||
"""
|
"""
|
||||||
获取待合并的候选节点对
|
获取待合并的候选节点对
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
min_similarity: 最小相似度
|
min_similarity: 最小相似度
|
||||||
limit: 最大返回数量
|
limit: 最大返回数量
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (node_id_1, node_id_2, similarity)
|
List of (node_id_1, node_id_2, similarity)
|
||||||
"""
|
"""
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -7,7 +7,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.memory_graph.manager import MemoryManager
|
from src.memory_graph.manager import MemoryManager
|
||||||
@@ -15,56 +14,56 @@ from src.memory_graph.manager import MemoryManager
|
|||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
# 全局 MemoryManager 实例
|
# 全局 MemoryManager 实例
|
||||||
_memory_manager: Optional[MemoryManager] = None
|
_memory_manager: MemoryManager | None = None
|
||||||
_initialized: bool = False
|
_initialized: bool = False
|
||||||
|
|
||||||
|
|
||||||
async def initialize_memory_manager(
|
async def initialize_memory_manager(
|
||||||
data_dir: Optional[Path | str] = None,
|
data_dir: Path | str | None = None,
|
||||||
) -> Optional[MemoryManager]:
|
) -> MemoryManager | None:
|
||||||
"""
|
"""
|
||||||
初始化全局 MemoryManager
|
初始化全局 MemoryManager
|
||||||
|
|
||||||
直接从 global_config.memory 读取配置
|
直接从 global_config.memory 读取配置
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_dir: 数据目录(可选,默认从配置读取)
|
data_dir: 数据目录(可选,默认从配置读取)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MemoryManager 实例,如果禁用则返回 None
|
MemoryManager 实例,如果禁用则返回 None
|
||||||
"""
|
"""
|
||||||
global _memory_manager, _initialized
|
global _memory_manager, _initialized
|
||||||
|
|
||||||
if _initialized and _memory_manager:
|
if _initialized and _memory_manager:
|
||||||
logger.info("MemoryManager 已经初始化,返回现有实例")
|
logger.info("MemoryManager 已经初始化,返回现有实例")
|
||||||
return _memory_manager
|
return _memory_manager
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
# 检查是否启用
|
# 检查是否启用
|
||||||
if not global_config.memory or not getattr(global_config.memory, 'enable', False):
|
if not global_config.memory or not getattr(global_config.memory, "enable", False):
|
||||||
logger.info("记忆图系统已在配置中禁用")
|
logger.info("记忆图系统已在配置中禁用")
|
||||||
_initialized = False
|
_initialized = False
|
||||||
_memory_manager = None
|
_memory_manager = None
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 处理数据目录
|
# 处理数据目录
|
||||||
if data_dir is None:
|
if data_dir is None:
|
||||||
data_dir = getattr(global_config.memory, 'data_dir', 'data/memory_graph')
|
data_dir = getattr(global_config.memory, "data_dir", "data/memory_graph")
|
||||||
if isinstance(data_dir, str):
|
if isinstance(data_dir, str):
|
||||||
data_dir = Path(data_dir)
|
data_dir = Path(data_dir)
|
||||||
|
|
||||||
logger.info(f"正在初始化全局 MemoryManager (data_dir={data_dir})...")
|
logger.info(f"正在初始化全局 MemoryManager (data_dir={data_dir})...")
|
||||||
|
|
||||||
_memory_manager = MemoryManager(data_dir=data_dir)
|
_memory_manager = MemoryManager(data_dir=data_dir)
|
||||||
await _memory_manager.initialize()
|
await _memory_manager.initialize()
|
||||||
|
|
||||||
_initialized = True
|
_initialized = True
|
||||||
logger.info("✅ 全局 MemoryManager 初始化成功")
|
logger.info("✅ 全局 MemoryManager 初始化成功")
|
||||||
|
|
||||||
return _memory_manager
|
return _memory_manager
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"初始化 MemoryManager 失败: {e}", exc_info=True)
|
logger.error(f"初始化 MemoryManager 失败: {e}", exc_info=True)
|
||||||
_initialized = False
|
_initialized = False
|
||||||
@@ -72,24 +71,24 @@ async def initialize_memory_manager(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_memory_manager() -> Optional[MemoryManager]:
|
def get_memory_manager() -> MemoryManager | None:
|
||||||
"""
|
"""
|
||||||
获取全局 MemoryManager 实例
|
获取全局 MemoryManager 实例
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
MemoryManager 实例,如果未初始化则返回 None
|
MemoryManager 实例,如果未初始化则返回 None
|
||||||
"""
|
"""
|
||||||
if not _initialized or _memory_manager is None:
|
if not _initialized or _memory_manager is None:
|
||||||
logger.warning("MemoryManager 尚未初始化,请先调用 initialize_memory_manager()")
|
logger.warning("MemoryManager 尚未初始化,请先调用 initialize_memory_manager()")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return _memory_manager
|
return _memory_manager
|
||||||
|
|
||||||
|
|
||||||
async def shutdown_memory_manager():
|
async def shutdown_memory_manager():
|
||||||
"""关闭全局 MemoryManager"""
|
"""关闭全局 MemoryManager"""
|
||||||
global _memory_manager, _initialized
|
global _memory_manager, _initialized
|
||||||
|
|
||||||
if _memory_manager:
|
if _memory_manager:
|
||||||
try:
|
try:
|
||||||
logger.info("正在关闭全局 MemoryManager...")
|
logger.info("正在关闭全局 MemoryManager...")
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import uuid
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -60,8 +60,8 @@ class MemoryNode:
|
|||||||
id: str # 节点唯一ID
|
id: str # 节点唯一ID
|
||||||
content: str # 节点内容(如:"我"、"吃饭"、"白米饭")
|
content: str # 节点内容(如:"我"、"吃饭"、"白米饭")
|
||||||
node_type: NodeType # 节点类型
|
node_type: NodeType # 节点类型
|
||||||
embedding: Optional[np.ndarray] = None # 语义向量(仅主题/客体需要)
|
embedding: np.ndarray | None = None # 语义向量(仅主题/客体需要)
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
||||||
created_at: datetime = field(default_factory=datetime.now)
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -69,7 +69,7 @@ class MemoryNode:
|
|||||||
if not self.id:
|
if not self.id:
|
||||||
self.id = str(uuid.uuid4())
|
self.id = str(uuid.uuid4())
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""转换为字典(用于序列化)"""
|
"""转换为字典(用于序列化)"""
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
@@ -81,7 +81,7 @@ class MemoryNode:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> MemoryNode:
|
def from_dict(cls, data: dict[str, Any]) -> MemoryNode:
|
||||||
"""从字典创建节点"""
|
"""从字典创建节点"""
|
||||||
embedding = None
|
embedding = None
|
||||||
if data.get("embedding") is not None:
|
if data.get("embedding") is not None:
|
||||||
@@ -114,7 +114,7 @@ class MemoryEdge:
|
|||||||
relation: str # 关系名称(如:"是"、"做"、"时间"、"因为")
|
relation: str # 关系名称(如:"是"、"做"、"时间"、"因为")
|
||||||
edge_type: EdgeType # 边类型
|
edge_type: EdgeType # 边类型
|
||||||
importance: float = 0.5 # 重要性 [0-1]
|
importance: float = 0.5 # 重要性 [0-1]
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
||||||
created_at: datetime = field(default_factory=datetime.now)
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -124,7 +124,7 @@ class MemoryEdge:
|
|||||||
# 确保重要性在有效范围内
|
# 确保重要性在有效范围内
|
||||||
self.importance = max(0.0, min(1.0, self.importance))
|
self.importance = max(0.0, min(1.0, self.importance))
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""转换为字典(用于序列化)"""
|
"""转换为字典(用于序列化)"""
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
@@ -138,7 +138,7 @@ class MemoryEdge:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> MemoryEdge:
|
def from_dict(cls, data: dict[str, Any]) -> MemoryEdge:
|
||||||
"""从字典创建边"""
|
"""从字典创建边"""
|
||||||
return cls(
|
return cls(
|
||||||
id=data["id"],
|
id=data["id"],
|
||||||
@@ -162,8 +162,8 @@ class Memory:
|
|||||||
id: str # 记忆唯一ID
|
id: str # 记忆唯一ID
|
||||||
subject_id: str # 主体节点ID
|
subject_id: str # 主体节点ID
|
||||||
memory_type: MemoryType # 记忆类型
|
memory_type: MemoryType # 记忆类型
|
||||||
nodes: List[MemoryNode] # 该记忆包含的所有节点
|
nodes: list[MemoryNode] # 该记忆包含的所有节点
|
||||||
edges: List[MemoryEdge] # 该记忆包含的所有边
|
edges: list[MemoryEdge] # 该记忆包含的所有边
|
||||||
importance: float = 0.5 # 整体重要性 [0-1]
|
importance: float = 0.5 # 整体重要性 [0-1]
|
||||||
activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘
|
activation: float = 0.0 # 激活度 [0-1],用于记忆整合和遗忘
|
||||||
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
|
status: MemoryStatus = MemoryStatus.STAGED # 记忆状态
|
||||||
@@ -171,7 +171,7 @@ class Memory:
|
|||||||
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
|
last_accessed: datetime = field(default_factory=datetime.now) # 最后访问时间
|
||||||
access_count: int = 0 # 访问次数
|
access_count: int = 0 # 访问次数
|
||||||
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
decay_factor: float = 1.0 # 衰减因子(随时间变化)
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
metadata: dict[str, Any] = field(default_factory=dict) # 扩展元数据
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""后初始化处理"""
|
"""后初始化处理"""
|
||||||
@@ -181,7 +181,7 @@ class Memory:
|
|||||||
self.importance = max(0.0, min(1.0, self.importance))
|
self.importance = max(0.0, min(1.0, self.importance))
|
||||||
self.activation = max(0.0, min(1.0, self.activation))
|
self.activation = max(0.0, min(1.0, self.activation))
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""转换为字典(用于序列化)"""
|
"""转换为字典(用于序列化)"""
|
||||||
return {
|
return {
|
||||||
"id": self.id,
|
"id": self.id,
|
||||||
@@ -200,7 +200,7 @@ class Memory:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> Memory:
|
def from_dict(cls, data: dict[str, Any]) -> Memory:
|
||||||
"""从字典创建记忆"""
|
"""从字典创建记忆"""
|
||||||
return cls(
|
return cls(
|
||||||
id=data["id"],
|
id=data["id"],
|
||||||
@@ -223,14 +223,14 @@ class Memory:
|
|||||||
self.last_accessed = datetime.now()
|
self.last_accessed = datetime.now()
|
||||||
self.access_count += 1
|
self.access_count += 1
|
||||||
|
|
||||||
def get_node_by_id(self, node_id: str) -> Optional[MemoryNode]:
|
def get_node_by_id(self, node_id: str) -> MemoryNode | None:
|
||||||
"""根据ID获取节点"""
|
"""根据ID获取节点"""
|
||||||
for node in self.nodes:
|
for node in self.nodes:
|
||||||
if node.id == node_id:
|
if node.id == node_id:
|
||||||
return node
|
return node
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_subject_node(self) -> Optional[MemoryNode]:
|
def get_subject_node(self) -> MemoryNode | None:
|
||||||
"""获取主体节点"""
|
"""获取主体节点"""
|
||||||
return self.get_node_by_id(self.subject_id)
|
return self.get_node_by_id(self.subject_id)
|
||||||
|
|
||||||
@@ -274,10 +274,10 @@ class StagedMemory:
|
|||||||
memory: Memory # 原始记忆对象
|
memory: Memory # 原始记忆对象
|
||||||
status: MemoryStatus = MemoryStatus.STAGED # 状态
|
status: MemoryStatus = MemoryStatus.STAGED # 状态
|
||||||
created_at: datetime = field(default_factory=datetime.now)
|
created_at: datetime = field(default_factory=datetime.now)
|
||||||
consolidated_at: Optional[datetime] = None # 整理时间
|
consolidated_at: datetime | None = None # 整理时间
|
||||||
merge_history: List[str] = field(default_factory=list) # 被合并的节点ID列表
|
merge_history: list[str] = field(default_factory=list) # 被合并的节点ID列表
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""转换为字典"""
|
"""转换为字典"""
|
||||||
return {
|
return {
|
||||||
"memory": self.memory.to_dict(),
|
"memory": self.memory.to_dict(),
|
||||||
@@ -288,7 +288,7 @@ class StagedMemory:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> StagedMemory:
|
def from_dict(cls, data: dict[str, Any]) -> StagedMemory:
|
||||||
"""从字典创建临时记忆"""
|
"""从字典创建临时记忆"""
|
||||||
return cls(
|
return cls(
|
||||||
memory=Memory.from_dict(data["memory"]),
|
memory=Memory.from_dict(data["memory"]),
|
||||||
|
|||||||
@@ -52,16 +52,16 @@ class CreateMemoryTool(BaseTool):
|
|||||||
示例:"我最近在学Python,想找数据分析的工作"
|
示例:"我最近在学Python,想找数据分析的工作"
|
||||||
→ 调用1:{{subject:"[从历史提取真实名字]", memory_type:"事实", topic:"学习", object:"Python", attributes:{{时间:"最近", 状态:"进行中"}}, importance:0.7}}
|
→ 调用1:{{subject:"[从历史提取真实名字]", memory_type:"事实", topic:"学习", object:"Python", attributes:{{时间:"最近", 状态:"进行中"}}, importance:0.7}}
|
||||||
→ 调用2:{{subject:"[从历史提取真实名字]", memory_type:"目标", topic:"求职", object:"数据分析岗位", attributes:{{状态:"计划中"}}, importance:0.8}}"""
|
→ 调用2:{{subject:"[从历史提取真实名字]", memory_type:"目标", topic:"求职", object:"数据分析岗位", attributes:{{状态:"计划中"}}, importance:0.8}}"""
|
||||||
|
|
||||||
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
|
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
|
||||||
("subject", ToolParamType.STRING, "记忆主体(重要!)。从对话历史中提取真实发送人名字。示例:如果看到'Prou(12345678): 我喜欢...',subject应填'Prou';如果看到'张三: 我在...',subject应填'张三'。❌禁止使用'用户'这种泛指,必须用具体名字!", True, None),
|
("subject", ToolParamType.STRING, "记忆主体(重要!)。从对话历史中提取真实发送人名字。示例:如果看到'Prou(12345678): 我喜欢...',subject应填'Prou';如果看到'张三: 我在...',subject应填'张三'。❌禁止使用'用户'这种泛指,必须用具体名字!", True, None),
|
||||||
("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]),
|
("memory_type", ToolParamType.STRING, "记忆类型。【事件】=有明确时间点的动作(昨天吃饭、明天开会)【事实】=稳定状态(职业是程序员、住在北京)【观点】=主观看法(喜欢/讨厌/认为)【关系】=人际关系(朋友、同事)", True, ["事件", "事实", "关系", "观点"]),
|
||||||
("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None),
|
("topic", ToolParamType.STRING, "记忆的核心内容(做什么/是什么状态/什么关系)。必须明确、具体,包含关键动词或状态词", True, None),
|
||||||
("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整可以不填,如果有明确对象建议填写", False, None),
|
("object", ToolParamType.STRING, "记忆涉及的对象或目标。如果topic已经很完整可以不填,如果有明确对象建议填写", False, None),
|
||||||
("attributes", ToolParamType.STRING, "详细属性,JSON格式字符串。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{\"时间\":\"2025-11-06 12:00\",\"地点\":\"公司\",\"状态\":\"进行中\",\"原因\":\"项目需要\"}", False, None),
|
("attributes", ToolParamType.STRING, '详细属性,JSON格式字符串。强烈建议包含:时间(具体到日期和小时分钟)、地点、状态、原因等上下文信息。例:{"时间":"2025-11-06 12:00","地点":"公司","状态":"进行中","原因":"项目需要"}', False, None),
|
||||||
("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考:日常琐事0.3-0.4,一般对话0.5-0.6,重要信息0.7-0.8,核心记忆0.9-1.0。不确定时用0.5", False, None),
|
("importance", ToolParamType.FLOAT, "重要性评分 0.0-1.0。参考:日常琐事0.3-0.4,一般对话0.5-0.6,重要信息0.7-0.8,核心记忆0.9-1.0。不确定时用0.5", False, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
available_for_llm = True
|
available_for_llm = True
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
@@ -69,20 +69,20 @@ class CreateMemoryTool(BaseTool):
|
|||||||
try:
|
try:
|
||||||
# 获取全局 memory_manager
|
# 获取全局 memory_manager
|
||||||
from src.memory_graph.manager_singleton import get_memory_manager
|
from src.memory_graph.manager_singleton import get_memory_manager
|
||||||
|
|
||||||
manager = get_memory_manager()
|
manager = get_memory_manager()
|
||||||
if not manager:
|
if not manager:
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": "记忆系统未初始化"
|
"content": "记忆系统未初始化"
|
||||||
}
|
}
|
||||||
|
|
||||||
# 提取参数
|
# 提取参数
|
||||||
subject = function_args.get("subject", "")
|
subject = function_args.get("subject", "")
|
||||||
memory_type = function_args.get("memory_type", "")
|
memory_type = function_args.get("memory_type", "")
|
||||||
topic = function_args.get("topic", "")
|
topic = function_args.get("topic", "")
|
||||||
obj = function_args.get("object")
|
obj = function_args.get("object")
|
||||||
|
|
||||||
# 处理 attributes(可能是字符串或字典)
|
# 处理 attributes(可能是字符串或字典)
|
||||||
attributes_raw = function_args.get("attributes", {})
|
attributes_raw = function_args.get("attributes", {})
|
||||||
if isinstance(attributes_raw, str):
|
if isinstance(attributes_raw, str):
|
||||||
@@ -93,9 +93,9 @@ class CreateMemoryTool(BaseTool):
|
|||||||
attributes = {}
|
attributes = {}
|
||||||
else:
|
else:
|
||||||
attributes = attributes_raw
|
attributes = attributes_raw
|
||||||
|
|
||||||
importance = function_args.get("importance", 0.5)
|
importance = function_args.get("importance", 0.5)
|
||||||
|
|
||||||
# 创建记忆
|
# 创建记忆
|
||||||
memory = await manager.create_memory(
|
memory = await manager.create_memory(
|
||||||
subject=subject,
|
subject=subject,
|
||||||
@@ -105,7 +105,7 @@ class CreateMemoryTool(BaseTool):
|
|||||||
attributes=attributes,
|
attributes=attributes,
|
||||||
importance=importance,
|
importance=importance,
|
||||||
)
|
)
|
||||||
|
|
||||||
if memory:
|
if memory:
|
||||||
logger.info(f"[CreateMemoryTool] 成功创建记忆: {memory.id}")
|
logger.info(f"[CreateMemoryTool] 成功创建记忆: {memory.id}")
|
||||||
return {
|
return {
|
||||||
@@ -119,12 +119,12 @@ class CreateMemoryTool(BaseTool):
|
|||||||
"content": "创建记忆失败",
|
"content": "创建记忆失败",
|
||||||
"memory_id": None,
|
"memory_id": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[CreateMemoryTool] 执行失败: {e}", exc_info=True)
|
logger.error(f"[CreateMemoryTool] 执行失败: {e}", exc_info=True)
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": f"创建记忆时出错: {str(e)}"
|
"content": f"创建记忆时出错: {e!s}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -133,33 +133,33 @@ class LinkMemoriesTool(BaseTool):
|
|||||||
|
|
||||||
name = "link_memories"
|
name = "link_memories"
|
||||||
description = "在两个记忆之间建立关联关系。用于连接相关的记忆,形成知识网络。"
|
description = "在两个记忆之间建立关联关系。用于连接相关的记忆,形成知识网络。"
|
||||||
|
|
||||||
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
|
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
|
||||||
("source_query", ToolParamType.STRING, "源记忆的搜索查询(如记忆的主题关键词)", True, None),
|
("source_query", ToolParamType.STRING, "源记忆的搜索查询(如记忆的主题关键词)", True, None),
|
||||||
("target_query", ToolParamType.STRING, "目标记忆的搜索查询", True, None),
|
("target_query", ToolParamType.STRING, "目标记忆的搜索查询", True, None),
|
||||||
("relation", ToolParamType.STRING, "关系类型", True, ["导致", "引用", "相似", "相反", "部分"]),
|
("relation", ToolParamType.STRING, "关系类型", True, ["导致", "引用", "相似", "相反", "部分"]),
|
||||||
("strength", ToolParamType.FLOAT, "关系强度(0.0-1.0),默认0.7", False, None),
|
("strength", ToolParamType.FLOAT, "关系强度(0.0-1.0),默认0.7", False, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
available_for_llm = False # 暂不对 LLM 开放
|
available_for_llm = False # 暂不对 LLM 开放
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""执行关联记忆"""
|
"""执行关联记忆"""
|
||||||
try:
|
try:
|
||||||
from src.memory_graph.manager_singleton import get_memory_manager
|
from src.memory_graph.manager_singleton import get_memory_manager
|
||||||
|
|
||||||
manager = get_memory_manager()
|
manager = get_memory_manager()
|
||||||
if not manager:
|
if not manager:
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": "记忆系统未初始化"
|
"content": "记忆系统未初始化"
|
||||||
}
|
}
|
||||||
|
|
||||||
source_query = function_args.get("source_query", "")
|
source_query = function_args.get("source_query", "")
|
||||||
target_query = function_args.get("target_query", "")
|
target_query = function_args.get("target_query", "")
|
||||||
relation = function_args.get("relation", "引用")
|
relation = function_args.get("relation", "引用")
|
||||||
strength = function_args.get("strength", 0.7)
|
strength = function_args.get("strength", 0.7)
|
||||||
|
|
||||||
# 关联记忆
|
# 关联记忆
|
||||||
success = await manager.link_memories(
|
success = await manager.link_memories(
|
||||||
source_description=source_query,
|
source_description=source_query,
|
||||||
@@ -167,7 +167,7 @@ class LinkMemoriesTool(BaseTool):
|
|||||||
relation_type=relation,
|
relation_type=relation,
|
||||||
importance=strength,
|
importance=strength,
|
||||||
)
|
)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info(f"[LinkMemoriesTool] 成功关联记忆: {source_query} -> {target_query}")
|
logger.info(f"[LinkMemoriesTool] 成功关联记忆: {source_query} -> {target_query}")
|
||||||
return {
|
return {
|
||||||
@@ -179,12 +179,12 @@ class LinkMemoriesTool(BaseTool):
|
|||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": "关联记忆失败,可能找不到匹配的记忆"
|
"content": "关联记忆失败,可能找不到匹配的记忆"
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[LinkMemoriesTool] 执行失败: {e}", exc_info=True)
|
logger.error(f"[LinkMemoriesTool] 执行失败: {e}", exc_info=True)
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": f"关联记忆时出错: {str(e)}"
|
"content": f"关联记忆时出错: {e!s}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -193,39 +193,39 @@ class SearchMemoriesTool(BaseTool):
|
|||||||
|
|
||||||
name = "search_memories"
|
name = "search_memories"
|
||||||
description = "搜索相关的记忆。根据查询词搜索记忆库,返回最相关的记忆。"
|
description = "搜索相关的记忆。根据查询词搜索记忆库,返回最相关的记忆。"
|
||||||
|
|
||||||
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
|
parameters: ClassVar[list[tuple[str, ToolParamType, str, bool, list[str] | None]]] = [
|
||||||
("query", ToolParamType.STRING, "搜索查询词,描述想要找什么样的记忆", True, None),
|
("query", ToolParamType.STRING, "搜索查询词,描述想要找什么样的记忆", True, None),
|
||||||
("top_k", ToolParamType.INTEGER, "返回的记忆数量,默认5", False, None),
|
("top_k", ToolParamType.INTEGER, "返回的记忆数量,默认5", False, None),
|
||||||
("min_importance", ToolParamType.FLOAT, "最低重要性阈值(0.0-1.0),只返回重要性不低于此值的记忆", False, None),
|
("min_importance", ToolParamType.FLOAT, "最低重要性阈值(0.0-1.0),只返回重要性不低于此值的记忆", False, None),
|
||||||
]
|
]
|
||||||
|
|
||||||
available_for_llm = False # 暂不对 LLM 开放,记忆检索在提示词构建时自动执行
|
available_for_llm = False # 暂不对 LLM 开放,记忆检索在提示词构建时自动执行
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""执行搜索记忆"""
|
"""执行搜索记忆"""
|
||||||
try:
|
try:
|
||||||
from src.memory_graph.manager_singleton import get_memory_manager
|
from src.memory_graph.manager_singleton import get_memory_manager
|
||||||
|
|
||||||
manager = get_memory_manager()
|
manager = get_memory_manager()
|
||||||
if not manager:
|
if not manager:
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": "记忆系统未初始化"
|
"content": "记忆系统未初始化"
|
||||||
}
|
}
|
||||||
|
|
||||||
query = function_args.get("query", "")
|
query = function_args.get("query", "")
|
||||||
top_k = function_args.get("top_k", 5)
|
top_k = function_args.get("top_k", 5)
|
||||||
min_importance_raw = function_args.get("min_importance")
|
min_importance_raw = function_args.get("min_importance")
|
||||||
min_importance = float(min_importance_raw) if min_importance_raw is not None else 0.0
|
min_importance = float(min_importance_raw) if min_importance_raw is not None else 0.0
|
||||||
|
|
||||||
# 搜索记忆
|
# 搜索记忆
|
||||||
memories = await manager.search_memories(
|
memories = await manager.search_memories(
|
||||||
query=query,
|
query=query,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
min_importance=min_importance,
|
min_importance=min_importance,
|
||||||
)
|
)
|
||||||
|
|
||||||
if memories:
|
if memories:
|
||||||
# 格式化结果
|
# 格式化结果
|
||||||
result_lines = [f"找到 {len(memories)} 条相关记忆:\n"]
|
result_lines = [f"找到 {len(memories)} 条相关记忆:\n"]
|
||||||
@@ -236,10 +236,10 @@ class SearchMemoriesTool(BaseTool):
|
|||||||
result_lines.append(
|
result_lines.append(
|
||||||
f"{i}. [{mem_type}] {topic} (重要性: {importance:.2f})"
|
f"{i}. [{mem_type}] {topic} (重要性: {importance:.2f})"
|
||||||
)
|
)
|
||||||
|
|
||||||
result_text = "\n".join(result_lines)
|
result_text = "\n".join(result_lines)
|
||||||
logger.info(f"[SearchMemoriesTool] 搜索成功: 查询='{query}', 结果数={len(memories)}")
|
logger.info(f"[SearchMemoriesTool] 搜索成功: 查询='{query}', 结果数={len(memories)}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": result_text
|
"content": result_text
|
||||||
@@ -249,10 +249,10 @@ class SearchMemoriesTool(BaseTool):
|
|||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": f"未找到与 '{query}' 相关的记忆"
|
"content": f"未找到与 '{query}' 相关的记忆"
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[SearchMemoriesTool] 执行失败: {e}", exc_info=True)
|
logger.error(f"[SearchMemoriesTool] 执行失败: {e}", exc_info=True)
|
||||||
return {
|
return {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"content": f"搜索记忆时出错: {str(e)}"
|
"content": f"搜索记忆时出错: {e!s}"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,4 +5,4 @@
|
|||||||
from src.memory_graph.storage.graph_store import GraphStore
|
from src.memory_graph.storage.graph_store import GraphStore
|
||||||
from src.memory_graph.storage.vector_store import VectorStore
|
from src.memory_graph.storage.vector_store import VectorStore
|
||||||
|
|
||||||
__all__ = ["VectorStore", "GraphStore"]
|
__all__ = ["GraphStore", "VectorStore"]
|
||||||
|
|||||||
@@ -4,12 +4,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Dict, List, Optional, Set, Tuple
|
|
||||||
|
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.memory_graph.models import Memory, MemoryEdge, MemoryNode
|
from src.memory_graph.models import Memory, MemoryEdge
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -17,7 +15,7 @@ logger = get_logger(__name__)
|
|||||||
class GraphStore:
|
class GraphStore:
|
||||||
"""
|
"""
|
||||||
图存储封装类
|
图存储封装类
|
||||||
|
|
||||||
负责:
|
负责:
|
||||||
1. 记忆图的构建和维护
|
1. 记忆图的构建和维护
|
||||||
2. 节点和边的快速查询
|
2. 节点和边的快速查询
|
||||||
@@ -31,17 +29,17 @@ class GraphStore:
|
|||||||
self.graph = nx.DiGraph()
|
self.graph = nx.DiGraph()
|
||||||
|
|
||||||
# 索引:记忆ID -> 记忆对象
|
# 索引:记忆ID -> 记忆对象
|
||||||
self.memory_index: Dict[str, Memory] = {}
|
self.memory_index: dict[str, Memory] = {}
|
||||||
|
|
||||||
# 索引:节点ID -> 所属记忆ID集合
|
# 索引:节点ID -> 所属记忆ID集合
|
||||||
self.node_to_memories: Dict[str, Set[str]] = {}
|
self.node_to_memories: dict[str, set[str]] = {}
|
||||||
|
|
||||||
logger.info("初始化图存储")
|
logger.info("初始化图存储")
|
||||||
|
|
||||||
def add_memory(self, memory: Memory) -> None:
|
def add_memory(self, memory: Memory) -> None:
|
||||||
"""
|
"""
|
||||||
添加记忆到图
|
添加记忆到图
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory: 要添加的记忆
|
memory: 要添加的记忆
|
||||||
"""
|
"""
|
||||||
@@ -84,34 +82,34 @@ class GraphStore:
|
|||||||
logger.error(f"添加记忆失败: {e}", exc_info=True)
|
logger.error(f"添加记忆失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_memory_by_id(self, memory_id: str) -> Optional[Memory]:
|
def get_memory_by_id(self, memory_id: str) -> Memory | None:
|
||||||
"""
|
"""
|
||||||
根据ID获取记忆
|
根据ID获取记忆
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_id: 记忆ID
|
memory_id: 记忆ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
记忆对象或 None
|
记忆对象或 None
|
||||||
"""
|
"""
|
||||||
return self.memory_index.get(memory_id)
|
return self.memory_index.get(memory_id)
|
||||||
|
|
||||||
def get_all_memories(self) -> List[Memory]:
|
def get_all_memories(self) -> list[Memory]:
|
||||||
"""
|
"""
|
||||||
获取所有记忆
|
获取所有记忆
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
所有记忆的列表
|
所有记忆的列表
|
||||||
"""
|
"""
|
||||||
return list(self.memory_index.values())
|
return list(self.memory_index.values())
|
||||||
|
|
||||||
def get_memories_by_node(self, node_id: str) -> List[Memory]:
|
def get_memories_by_node(self, node_id: str) -> list[Memory]:
|
||||||
"""
|
"""
|
||||||
获取包含指定节点的所有记忆
|
获取包含指定节点的所有记忆
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
记忆列表
|
记忆列表
|
||||||
"""
|
"""
|
||||||
@@ -121,14 +119,14 @@ class GraphStore:
|
|||||||
memory_ids = self.node_to_memories[node_id]
|
memory_ids = self.node_to_memories[node_id]
|
||||||
return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]
|
return [self.memory_index[mid] for mid in memory_ids if mid in self.memory_index]
|
||||||
|
|
||||||
def get_edges_from_node(self, node_id: str, relation_types: Optional[List[str]] = None) -> List[Dict]:
|
def get_edges_from_node(self, node_id: str, relation_types: list[str] | None = None) -> list[dict]:
|
||||||
"""
|
"""
|
||||||
获取从指定节点出发的所有边
|
获取从指定节点出发的所有边
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 源节点ID
|
node_id: 源节点ID
|
||||||
relation_types: 关系类型过滤(可选)
|
relation_types: 关系类型过滤(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
边信息列表
|
边信息列表
|
||||||
"""
|
"""
|
||||||
@@ -155,16 +153,16 @@ class GraphStore:
|
|||||||
return edges
|
return edges
|
||||||
|
|
||||||
def get_neighbors(
|
def get_neighbors(
|
||||||
self, node_id: str, direction: str = "out", relation_types: Optional[List[str]] = None
|
self, node_id: str, direction: str = "out", relation_types: list[str] | None = None
|
||||||
) -> List[Tuple[str, Dict]]:
|
) -> list[tuple[str, dict]]:
|
||||||
"""
|
"""
|
||||||
获取节点的邻居节点
|
获取节点的邻居节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
direction: 方向 ("out"=出边, "in"=入边, "both"=双向)
|
direction: 方向 ("out"=出边, "in"=入边, "both"=双向)
|
||||||
relation_types: 关系类型过滤
|
relation_types: 关系类型过滤
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (neighbor_id, edge_data)
|
List of (neighbor_id, edge_data)
|
||||||
"""
|
"""
|
||||||
@@ -187,15 +185,15 @@ class GraphStore:
|
|||||||
|
|
||||||
return neighbors
|
return neighbors
|
||||||
|
|
||||||
def find_path(self, source_id: str, target_id: str, max_length: Optional[int] = None) -> Optional[List[str]]:
|
def find_path(self, source_id: str, target_id: str, max_length: int | None = None) -> list[str] | None:
|
||||||
"""
|
"""
|
||||||
查找两个节点之间的最短路径
|
查找两个节点之间的最短路径
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_id: 源节点ID
|
source_id: 源节点ID
|
||||||
target_id: 目标节点ID
|
target_id: 目标节点ID
|
||||||
max_length: 最大路径长度(可选)
|
max_length: 最大路径长度(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
路径节点ID列表,或 None(如果不存在路径)
|
路径节点ID列表,或 None(如果不存在路径)
|
||||||
"""
|
"""
|
||||||
@@ -220,18 +218,18 @@ class GraphStore:
|
|||||||
|
|
||||||
def bfs_expand(
|
def bfs_expand(
|
||||||
self,
|
self,
|
||||||
start_nodes: List[str],
|
start_nodes: list[str],
|
||||||
depth: int = 1,
|
depth: int = 1,
|
||||||
relation_types: Optional[List[str]] = None,
|
relation_types: list[str] | None = None,
|
||||||
) -> Set[str]:
|
) -> set[str]:
|
||||||
"""
|
"""
|
||||||
从起始节点进行广度优先搜索扩展
|
从起始节点进行广度优先搜索扩展
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
start_nodes: 起始节点ID列表
|
start_nodes: 起始节点ID列表
|
||||||
depth: 扩展深度
|
depth: 扩展深度
|
||||||
relation_types: 关系类型过滤
|
relation_types: 关系类型过滤
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
扩展到的所有节点ID集合
|
扩展到的所有节点ID集合
|
||||||
"""
|
"""
|
||||||
@@ -256,13 +254,13 @@ class GraphStore:
|
|||||||
|
|
||||||
return visited
|
return visited
|
||||||
|
|
||||||
def get_subgraph(self, node_ids: List[str]) -> nx.DiGraph:
|
def get_subgraph(self, node_ids: list[str]) -> nx.DiGraph:
|
||||||
"""
|
"""
|
||||||
获取包含指定节点的子图
|
获取包含指定节点的子图
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_ids: 节点ID列表
|
node_ids: 节点ID列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
NetworkX 子图
|
NetworkX 子图
|
||||||
"""
|
"""
|
||||||
@@ -271,7 +269,7 @@ class GraphStore:
|
|||||||
def merge_nodes(self, source_id: str, target_id: str) -> None:
|
def merge_nodes(self, source_id: str, target_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
合并两个节点(将source的所有边转移到target,然后删除source)
|
合并两个节点(将source的所有边转移到target,然后删除source)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source_id: 源节点ID(将被删除)
|
source_id: 源节点ID(将被删除)
|
||||||
target_id: 目标节点ID(保留)
|
target_id: 目标节点ID(保留)
|
||||||
@@ -308,13 +306,13 @@ class GraphStore:
|
|||||||
logger.error(f"合并节点失败: {e}", exc_info=True)
|
logger.error(f"合并节点失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_node_degree(self, node_id: str) -> Tuple[int, int]:
|
def get_node_degree(self, node_id: str) -> tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
获取节点的度数
|
获取节点的度数
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(in_degree, out_degree)
|
(in_degree, out_degree)
|
||||||
"""
|
"""
|
||||||
@@ -323,7 +321,7 @@ class GraphStore:
|
|||||||
|
|
||||||
return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))
|
return (self.graph.in_degree(node_id), self.graph.out_degree(node_id))
|
||||||
|
|
||||||
def get_statistics(self) -> Dict[str, int]:
|
def get_statistics(self) -> dict[str, int]:
|
||||||
"""获取图的统计信息"""
|
"""获取图的统计信息"""
|
||||||
return {
|
return {
|
||||||
"total_nodes": self.graph.number_of_nodes(),
|
"total_nodes": self.graph.number_of_nodes(),
|
||||||
@@ -332,10 +330,10 @@ class GraphStore:
|
|||||||
"connected_components": nx.number_weakly_connected_components(self.graph),
|
"connected_components": nx.number_weakly_connected_components(self.graph),
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_dict(self) -> Dict:
|
def to_dict(self) -> dict:
|
||||||
"""
|
"""
|
||||||
将图转换为字典(用于持久化)
|
将图转换为字典(用于持久化)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
图的字典表示
|
图的字典表示
|
||||||
"""
|
"""
|
||||||
@@ -356,13 +354,13 @@ class GraphStore:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict) -> GraphStore:
|
def from_dict(cls, data: dict) -> GraphStore:
|
||||||
"""
|
"""
|
||||||
从字典加载图
|
从字典加载图
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 图的字典表示
|
data: 图的字典表示
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
GraphStore 实例
|
GraphStore 实例
|
||||||
"""
|
"""
|
||||||
@@ -406,7 +404,6 @@ class GraphStore:
|
|||||||
规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。
|
规则:对于图中每条边(u, v, data),会尝试将该边注入到所有包含 u 或 v 的记忆中(避免遗漏跨记忆边)。
|
||||||
已存在的边(通过 edge.id 检查)将不会重复添加。
|
已存在的边(通过 edge.id 检查)将不会重复添加。
|
||||||
"""
|
"""
|
||||||
from src.memory_graph.models import MemoryEdge
|
|
||||||
|
|
||||||
# 构建快速查重索引:memory_id -> set(edge_id)
|
# 构建快速查重索引:memory_id -> set(edge_id)
|
||||||
existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()}
|
existing_edges = {mid: {e.id for e in mem.edges} for mid, mem in self.memory_index.items()}
|
||||||
@@ -465,10 +462,10 @@ class GraphStore:
|
|||||||
def remove_memory(self, memory_id: str) -> bool:
|
def remove_memory(self, memory_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
从图中删除指定记忆
|
从图中删除指定记忆
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_id: 要删除的记忆ID
|
memory_id: 要删除的记忆ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
是否删除成功
|
是否删除成功
|
||||||
"""
|
"""
|
||||||
@@ -477,9 +474,9 @@ class GraphStore:
|
|||||||
if memory_id not in self.memory_index:
|
if memory_id not in self.memory_index:
|
||||||
logger.warning(f"记忆不存在,无法删除: {memory_id}")
|
logger.warning(f"记忆不存在,无法删除: {memory_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
memory = self.memory_index[memory_id]
|
memory = self.memory_index[memory_id]
|
||||||
|
|
||||||
# 2. 从节点映射中移除此记忆
|
# 2. 从节点映射中移除此记忆
|
||||||
for node in memory.nodes:
|
for node in memory.nodes:
|
||||||
if node.id in self.node_to_memories:
|
if node.id in self.node_to_memories:
|
||||||
@@ -489,13 +486,13 @@ class GraphStore:
|
|||||||
if self.graph.has_node(node.id):
|
if self.graph.has_node(node.id):
|
||||||
self.graph.remove_node(node.id)
|
self.graph.remove_node(node.id)
|
||||||
del self.node_to_memories[node.id]
|
del self.node_to_memories[node.id]
|
||||||
|
|
||||||
# 3. 从记忆索引中移除
|
# 3. 从记忆索引中移除
|
||||||
del self.memory_index[memory_id]
|
del self.memory_index[memory_id]
|
||||||
|
|
||||||
logger.info(f"成功删除记忆: {memory_id}")
|
logger.info(f"成功删除记忆: {memory_id}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除记忆失败 {memory_id}: {e}", exc_info=True)
|
logger.error(f"删除记忆失败 {memory_id}: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -8,14 +8,12 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.memory_graph.models import Memory, StagedMemory
|
from src.memory_graph.models import StagedMemory
|
||||||
from src.memory_graph.storage.graph_store import GraphStore
|
from src.memory_graph.storage.graph_store import GraphStore
|
||||||
from src.memory_graph.storage.vector_store import VectorStore
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -23,7 +21,7 @@ logger = get_logger(__name__)
|
|||||||
class PersistenceManager:
|
class PersistenceManager:
|
||||||
"""
|
"""
|
||||||
持久化管理器
|
持久化管理器
|
||||||
|
|
||||||
负责:
|
负责:
|
||||||
1. 图数据的保存和加载
|
1. 图数据的保存和加载
|
||||||
2. 定期自动保存
|
2. 定期自动保存
|
||||||
@@ -39,7 +37,7 @@ class PersistenceManager:
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化持久化管理器
|
初始化持久化管理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_dir: 数据存储目录
|
data_dir: 数据存储目录
|
||||||
graph_file_name: 图数据文件名
|
graph_file_name: 图数据文件名
|
||||||
@@ -55,7 +53,7 @@ class PersistenceManager:
|
|||||||
self.backup_dir.mkdir(parents=True, exist_ok=True)
|
self.backup_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
self.auto_save_interval = auto_save_interval
|
self.auto_save_interval = auto_save_interval
|
||||||
self._auto_save_task: Optional[asyncio.Task] = None
|
self._auto_save_task: asyncio.Task | None = None
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
logger.info(f"初始化持久化管理器: data_dir={data_dir}")
|
logger.info(f"初始化持久化管理器: data_dir={data_dir}")
|
||||||
@@ -63,7 +61,7 @@ class PersistenceManager:
|
|||||||
async def save_graph_store(self, graph_store: GraphStore) -> None:
|
async def save_graph_store(self, graph_store: GraphStore) -> None:
|
||||||
"""
|
"""
|
||||||
保存图存储到文件
|
保存图存储到文件
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_store: 图存储对象
|
graph_store: 图存储对象
|
||||||
"""
|
"""
|
||||||
@@ -95,10 +93,10 @@ class PersistenceManager:
|
|||||||
logger.error(f"保存图数据失败: {e}", exc_info=True)
|
logger.error(f"保存图数据失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def load_graph_store(self) -> Optional[GraphStore]:
|
async def load_graph_store(self) -> GraphStore | None:
|
||||||
"""
|
"""
|
||||||
从文件加载图存储
|
从文件加载图存储
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
GraphStore 对象,如果文件不存在则返回 None
|
GraphStore 对象,如果文件不存在则返回 None
|
||||||
"""
|
"""
|
||||||
@@ -129,7 +127,7 @@ class PersistenceManager:
|
|||||||
async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None:
|
async def save_staged_memories(self, staged_memories: list[StagedMemory]) -> None:
|
||||||
"""
|
"""
|
||||||
保存临时记忆列表
|
保存临时记忆列表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
staged_memories: 临时记忆列表
|
staged_memories: 临时记忆列表
|
||||||
"""
|
"""
|
||||||
@@ -158,7 +156,7 @@ class PersistenceManager:
|
|||||||
async def load_staged_memories(self) -> list[StagedMemory]:
|
async def load_staged_memories(self) -> list[StagedMemory]:
|
||||||
"""
|
"""
|
||||||
加载临时记忆列表
|
加载临时记忆列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
临时记忆列表
|
临时记忆列表
|
||||||
"""
|
"""
|
||||||
@@ -179,10 +177,10 @@ class PersistenceManager:
|
|||||||
logger.error(f"加载临时记忆失败: {e}", exc_info=True)
|
logger.error(f"加载临时记忆失败: {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def create_backup(self) -> Optional[Path]:
|
async def create_backup(self) -> Path | None:
|
||||||
"""
|
"""
|
||||||
创建当前数据的备份
|
创建当前数据的备份
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
备份文件路径,如果失败则返回 None
|
备份文件路径,如果失败则返回 None
|
||||||
"""
|
"""
|
||||||
@@ -208,7 +206,7 @@ class PersistenceManager:
|
|||||||
logger.error(f"创建备份失败: {e}", exc_info=True)
|
logger.error(f"创建备份失败: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _load_from_backup(self) -> Optional[GraphStore]:
|
async def _load_from_backup(self) -> GraphStore | None:
|
||||||
"""从最新的备份加载数据"""
|
"""从最新的备份加载数据"""
|
||||||
try:
|
try:
|
||||||
# 查找最新的备份文件
|
# 查找最新的备份文件
|
||||||
@@ -236,7 +234,7 @@ class PersistenceManager:
|
|||||||
async def _cleanup_old_backups(self, keep: int = 10) -> None:
|
async def _cleanup_old_backups(self, keep: int = 10) -> None:
|
||||||
"""
|
"""
|
||||||
清理旧备份,只保留最近的几个
|
清理旧备份,只保留最近的几个
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
keep: 保留的备份数量
|
keep: 保留的备份数量
|
||||||
"""
|
"""
|
||||||
@@ -254,11 +252,11 @@ class PersistenceManager:
|
|||||||
async def start_auto_save(
|
async def start_auto_save(
|
||||||
self,
|
self,
|
||||||
graph_store: GraphStore,
|
graph_store: GraphStore,
|
||||||
staged_memories_getter: callable = None,
|
staged_memories_getter: callable | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
启动自动保存任务
|
启动自动保存任务
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_store: 图存储对象
|
graph_store: 图存储对象
|
||||||
staged_memories_getter: 获取临时记忆的回调函数
|
staged_memories_getter: 获取临时记忆的回调函数
|
||||||
@@ -310,7 +308,7 @@ class PersistenceManager:
|
|||||||
async def export_to_json(self, output_file: Path, graph_store: GraphStore) -> None:
|
async def export_to_json(self, output_file: Path, graph_store: GraphStore) -> None:
|
||||||
"""
|
"""
|
||||||
导出图数据到指定的 JSON 文件(用于数据迁移或分析)
|
导出图数据到指定的 JSON 文件(用于数据迁移或分析)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
output_file: 输出文件路径
|
output_file: 输出文件路径
|
||||||
graph_store: 图存储对象
|
graph_store: 图存储对象
|
||||||
@@ -334,13 +332,13 @@ class PersistenceManager:
|
|||||||
logger.error(f"导出图数据失败: {e}", exc_info=True)
|
logger.error(f"导出图数据失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def import_from_json(self, input_file: Path) -> Optional[GraphStore]:
|
async def import_from_json(self, input_file: Path) -> GraphStore | None:
|
||||||
"""
|
"""
|
||||||
从 JSON 文件导入图数据
|
从 JSON 文件导入图数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_file: 输入文件路径
|
input_file: 输入文件路径
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
GraphStore 对象
|
GraphStore 对象
|
||||||
"""
|
"""
|
||||||
@@ -360,7 +358,7 @@ class PersistenceManager:
|
|||||||
def get_data_size(self) -> dict[str, int]:
|
def get_data_size(self) -> dict[str, int]:
|
||||||
"""
|
"""
|
||||||
获取数据文件的大小信息
|
获取数据文件的大小信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
文件大小字典(字节)
|
文件大小字典(字节)
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,9 +4,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -19,7 +18,7 @@ logger = get_logger(__name__)
|
|||||||
class VectorStore:
|
class VectorStore:
|
||||||
"""
|
"""
|
||||||
向量存储封装类
|
向量存储封装类
|
||||||
|
|
||||||
负责:
|
负责:
|
||||||
1. 节点的语义向量存储和检索
|
1. 节点的语义向量存储和检索
|
||||||
2. 基于相似度的向量搜索
|
2. 基于相似度的向量搜索
|
||||||
@@ -29,12 +28,12 @@ class VectorStore:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
collection_name: str = "memory_nodes",
|
collection_name: str = "memory_nodes",
|
||||||
data_dir: Optional[Path] = None,
|
data_dir: Path | None = None,
|
||||||
embedding_function: Optional[Any] = None,
|
embedding_function: Any | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化向量存储
|
初始化向量存储
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: ChromaDB 集合名称
|
collection_name: ChromaDB 集合名称
|
||||||
data_dir: 数据存储目录
|
data_dir: 数据存储目录
|
||||||
@@ -80,7 +79,7 @@ class VectorStore:
|
|||||||
async def add_node(self, node: MemoryNode) -> None:
|
async def add_node(self, node: MemoryNode) -> None:
|
||||||
"""
|
"""
|
||||||
添加节点到向量存储
|
添加节点到向量存储
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node: 要添加的节点
|
node: 要添加的节点
|
||||||
"""
|
"""
|
||||||
@@ -98,17 +97,17 @@ class VectorStore:
|
|||||||
"node_type": node.node_type.value,
|
"node_type": node.node_type.value,
|
||||||
"created_at": node.created_at.isoformat(),
|
"created_at": node.created_at.isoformat(),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 处理额外的元数据,将 list 转换为 JSON 字符串
|
# 处理额外的元数据,将 list 转换为 JSON 字符串
|
||||||
for key, value in node.metadata.items():
|
for key, value in node.metadata.items():
|
||||||
if isinstance(value, (list, dict)):
|
if isinstance(value, (list, dict)):
|
||||||
import orjson
|
import orjson
|
||||||
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
|
||||||
elif isinstance(value, (str, int, float, bool)) or value is None:
|
elif isinstance(value, (str, int, float, bool)) or value is None:
|
||||||
metadata[key] = value
|
metadata[key] = value
|
||||||
else:
|
else:
|
||||||
metadata[key] = str(value)
|
metadata[key] = str(value)
|
||||||
|
|
||||||
self.collection.add(
|
self.collection.add(
|
||||||
ids=[node.id],
|
ids=[node.id],
|
||||||
embeddings=[node.embedding.tolist()],
|
embeddings=[node.embedding.tolist()],
|
||||||
@@ -122,10 +121,10 @@ class VectorStore:
|
|||||||
logger.error(f"添加节点失败: {e}", exc_info=True)
|
logger.error(f"添加节点失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def add_nodes_batch(self, nodes: List[MemoryNode]) -> None:
|
async def add_nodes_batch(self, nodes: list[MemoryNode]) -> None:
|
||||||
"""
|
"""
|
||||||
批量添加节点
|
批量添加节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
nodes: 节点列表
|
nodes: 节点列表
|
||||||
"""
|
"""
|
||||||
@@ -151,13 +150,13 @@ class VectorStore:
|
|||||||
}
|
}
|
||||||
for key, value in n.metadata.items():
|
for key, value in n.metadata.items():
|
||||||
if isinstance(value, (list, dict)):
|
if isinstance(value, (list, dict)):
|
||||||
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
|
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
|
||||||
elif isinstance(value, (str, int, float, bool)) or value is None:
|
elif isinstance(value, (str, int, float, bool)) or value is None:
|
||||||
metadata[key] = value # type: ignore
|
metadata[key] = value # type: ignore
|
||||||
else:
|
else:
|
||||||
metadata[key] = str(value)
|
metadata[key] = str(value)
|
||||||
metadatas.append(metadata)
|
metadatas.append(metadata)
|
||||||
|
|
||||||
self.collection.add(
|
self.collection.add(
|
||||||
ids=[n.id for n in valid_nodes],
|
ids=[n.id for n in valid_nodes],
|
||||||
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
|
embeddings=[n.embedding.tolist() for n in valid_nodes], # type: ignore
|
||||||
@@ -175,18 +174,18 @@ class VectorStore:
|
|||||||
self,
|
self,
|
||||||
query_embedding: np.ndarray,
|
query_embedding: np.ndarray,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
node_types: Optional[List[NodeType]] = None,
|
node_types: list[NodeType] | None = None,
|
||||||
min_similarity: float = 0.0,
|
min_similarity: float = 0.0,
|
||||||
) -> List[Tuple[str, float, Dict[str, Any]]]:
|
) -> list[tuple[str, float, dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
搜索相似节点
|
搜索相似节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_embedding: 查询向量
|
query_embedding: 查询向量
|
||||||
limit: 返回结果数量
|
limit: 返回结果数量
|
||||||
node_types: 限制节点类型(可选)
|
node_types: 限制节点类型(可选)
|
||||||
min_similarity: 最小相似度阈值
|
min_similarity: 最小相似度阈值
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of (node_id, similarity, metadata)
|
List of (node_id, similarity, metadata)
|
||||||
"""
|
"""
|
||||||
@@ -214,7 +213,7 @@ class VectorStore:
|
|||||||
if ids is not None and len(ids) > 0 and len(ids[0]) > 0:
|
if ids is not None and len(ids) > 0 and len(ids[0]) > 0:
|
||||||
distances = results.get("distances")
|
distances = results.get("distances")
|
||||||
metadatas = results.get("metadatas")
|
metadatas = results.get("metadatas")
|
||||||
|
|
||||||
for i, node_id in enumerate(ids[0]):
|
for i, node_id in enumerate(ids[0]):
|
||||||
# ChromaDB 返回的是距离,需要转换为相似度
|
# ChromaDB 返回的是距离,需要转换为相似度
|
||||||
# 余弦距离: distance = 1 - similarity
|
# 余弦距离: distance = 1 - similarity
|
||||||
@@ -223,15 +222,15 @@ class VectorStore:
|
|||||||
|
|
||||||
if similarity >= min_similarity:
|
if similarity >= min_similarity:
|
||||||
metadata = metadatas[0][i] if metadatas is not None and len(metadatas) > 0 else {} # type: ignore
|
metadata = metadatas[0][i] if metadatas is not None and len(metadatas) > 0 else {} # type: ignore
|
||||||
|
|
||||||
# 解析 JSON 字符串回列表/字典
|
# 解析 JSON 字符串回列表/字典
|
||||||
for key, value in list(metadata.items()):
|
for key, value in list(metadata.items()):
|
||||||
if isinstance(value, str) and (value.startswith('[') or value.startswith('{')):
|
if isinstance(value, str) and (value.startswith("[") or value.startswith("{")):
|
||||||
try:
|
try:
|
||||||
metadata[key] = orjson.loads(value)
|
metadata[key] = orjson.loads(value)
|
||||||
except:
|
except Exception:
|
||||||
pass # 保持原值
|
pass # 保持原值
|
||||||
|
|
||||||
similar_nodes.append((node_id, similarity, metadata))
|
similar_nodes.append((node_id, similarity, metadata))
|
||||||
|
|
||||||
logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果")
|
logger.debug(f"相似节点搜索: 找到 {len(similar_nodes)} 个结果")
|
||||||
@@ -243,19 +242,19 @@ class VectorStore:
|
|||||||
|
|
||||||
async def search_with_multiple_queries(
|
async def search_with_multiple_queries(
|
||||||
self,
|
self,
|
||||||
query_embeddings: List[np.ndarray],
|
query_embeddings: list[np.ndarray],
|
||||||
query_weights: Optional[List[float]] = None,
|
query_weights: list[float] | None = None,
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
node_types: Optional[List[NodeType]] = None,
|
node_types: list[NodeType] | None = None,
|
||||||
min_similarity: float = 0.0,
|
min_similarity: float = 0.0,
|
||||||
fusion_strategy: str = "weighted_max",
|
fusion_strategy: str = "weighted_max",
|
||||||
) -> List[Tuple[str, float, Dict[str, Any]]]:
|
) -> list[tuple[str, float, dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
多查询融合搜索
|
多查询融合搜索
|
||||||
|
|
||||||
使用多个查询向量进行搜索,然后融合结果。
|
使用多个查询向量进行搜索,然后融合结果。
|
||||||
这能解决单一查询向量无法同时关注多个关键概念的问题。
|
这能解决单一查询向量无法同时关注多个关键概念的问题。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query_embeddings: 查询向量列表
|
query_embeddings: 查询向量列表
|
||||||
query_weights: 每个查询的权重(可选,默认均等)
|
query_weights: 每个查询的权重(可选,默认均等)
|
||||||
@@ -266,7 +265,7 @@ class VectorStore:
|
|||||||
- "weighted_max": 加权最大值(推荐)
|
- "weighted_max": 加权最大值(推荐)
|
||||||
- "weighted_sum": 加权求和
|
- "weighted_sum": 加权求和
|
||||||
- "rrf": Reciprocal Rank Fusion
|
- "rrf": Reciprocal Rank Fusion
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
融合后的节点列表 [(node_id, fused_score, metadata), ...]
|
融合后的节点列表 [(node_id, fused_score, metadata), ...]
|
||||||
"""
|
"""
|
||||||
@@ -279,7 +278,7 @@ class VectorStore:
|
|||||||
# 默认权重均等
|
# 默认权重均等
|
||||||
if query_weights is None:
|
if query_weights is None:
|
||||||
query_weights = [1.0 / len(query_embeddings)] * len(query_embeddings)
|
query_weights = [1.0 / len(query_embeddings)] * len(query_embeddings)
|
||||||
|
|
||||||
# 归一化权重
|
# 归一化权重
|
||||||
total_weight = sum(query_weights)
|
total_weight = sum(query_weights)
|
||||||
if total_weight > 0:
|
if total_weight > 0:
|
||||||
@@ -287,7 +286,7 @@ class VectorStore:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 对每个查询执行搜索
|
# 1. 对每个查询执行搜索
|
||||||
all_results: Dict[str, Dict[str, Any]] = {} # node_id -> {scores, metadata}
|
all_results: dict[str, dict[str, Any]] = {} # node_id -> {scores, metadata}
|
||||||
|
|
||||||
for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
|
for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
|
||||||
# 搜索更多结果以提高融合质量
|
# 搜索更多结果以提高融合质量
|
||||||
@@ -307,13 +306,13 @@ class VectorStore:
|
|||||||
"ranks": [],
|
"ranks": [],
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
|
|
||||||
all_results[node_id]["scores"].append((similarity, weight))
|
all_results[node_id]["scores"].append((similarity, weight))
|
||||||
all_results[node_id]["ranks"].append((rank, weight))
|
all_results[node_id]["ranks"].append((rank, weight))
|
||||||
|
|
||||||
# 2. 融合分数
|
# 2. 融合分数
|
||||||
fused_results = []
|
fused_results = []
|
||||||
|
|
||||||
for node_id, data in all_results.items():
|
for node_id, data in all_results.items():
|
||||||
scores = data["scores"]
|
scores = data["scores"]
|
||||||
ranks = data["ranks"]
|
ranks = data["ranks"]
|
||||||
@@ -356,13 +355,13 @@ class VectorStore:
|
|||||||
logger.error(f"多查询融合搜索失败: {e}", exc_info=True)
|
logger.error(f"多查询融合搜索失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
|
async def get_node_by_id(self, node_id: str) -> dict[str, Any] | None:
|
||||||
"""
|
"""
|
||||||
根据ID获取节点元数据
|
根据ID获取节点元数据
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点元数据或 None
|
节点元数据或 None
|
||||||
"""
|
"""
|
||||||
@@ -378,7 +377,7 @@ class VectorStore:
|
|||||||
if ids is not None and len(ids) > 0:
|
if ids is not None and len(ids) > 0:
|
||||||
metadatas = result.get("metadatas")
|
metadatas = result.get("metadatas")
|
||||||
embeddings = result.get("embeddings")
|
embeddings = result.get("embeddings")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"id": ids[0],
|
"id": ids[0],
|
||||||
"metadata": metadatas[0] if metadatas is not None and len(metadatas) > 0 else {},
|
"metadata": metadatas[0] if metadatas is not None and len(metadatas) > 0 else {},
|
||||||
@@ -394,7 +393,7 @@ class VectorStore:
|
|||||||
async def delete_node(self, node_id: str) -> None:
|
async def delete_node(self, node_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
删除节点
|
删除节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
"""
|
"""
|
||||||
@@ -412,7 +411,7 @@ class VectorStore:
|
|||||||
async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None:
|
async def update_node_embedding(self, node_id: str, embedding: np.ndarray) -> None:
|
||||||
"""
|
"""
|
||||||
更新节点的 embedding
|
更新节点的 embedding
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点ID
|
node_id: 节点ID
|
||||||
embedding: 新的向量
|
embedding: 新的向量
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.memory_graph.core.builder import MemoryBuilder
|
from src.memory_graph.core.builder import MemoryBuilder
|
||||||
from src.memory_graph.core.extractor import MemoryExtractor
|
from src.memory_graph.core.extractor import MemoryExtractor
|
||||||
from src.memory_graph.models import Memory, MemoryStatus
|
from src.memory_graph.models import Memory
|
||||||
from src.memory_graph.storage.graph_store import GraphStore
|
from src.memory_graph.storage.graph_store import GraphStore
|
||||||
from src.memory_graph.storage.persistence import PersistenceManager
|
from src.memory_graph.storage.persistence import PersistenceManager
|
||||||
from src.memory_graph.storage.vector_store import VectorStore
|
from src.memory_graph.storage.vector_store import VectorStore
|
||||||
@@ -21,7 +21,7 @@ logger = get_logger(__name__)
|
|||||||
class MemoryTools:
|
class MemoryTools:
|
||||||
"""
|
"""
|
||||||
记忆系统工具集
|
记忆系统工具集
|
||||||
|
|
||||||
提供给 LLM 使用的工具接口:
|
提供给 LLM 使用的工具接口:
|
||||||
1. create_memory: 创建新记忆
|
1. create_memory: 创建新记忆
|
||||||
2. link_memories: 关联两个记忆
|
2. link_memories: 关联两个记忆
|
||||||
@@ -33,7 +33,7 @@ class MemoryTools:
|
|||||||
vector_store: VectorStore,
|
vector_store: VectorStore,
|
||||||
graph_store: GraphStore,
|
graph_store: GraphStore,
|
||||||
persistence_manager: PersistenceManager,
|
persistence_manager: PersistenceManager,
|
||||||
embedding_generator: Optional[EmbeddingGenerator] = None,
|
embedding_generator: EmbeddingGenerator | None = None,
|
||||||
max_expand_depth: int = 1,
|
max_expand_depth: int = 1,
|
||||||
expand_semantic_threshold: float = 0.3,
|
expand_semantic_threshold: float = 0.3,
|
||||||
):
|
):
|
||||||
@@ -72,10 +72,10 @@ class MemoryTools:
|
|||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_create_memory_schema() -> Dict[str, Any]:
|
def get_create_memory_schema() -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取 create_memory 工具的 JSON schema
|
获取 create_memory 工具的 JSON schema
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
工具 schema 定义
|
工具 schema 定义
|
||||||
"""
|
"""
|
||||||
@@ -145,15 +145,15 @@ class MemoryTools:
|
|||||||
"description": "时间信息(强烈建议填写):\n- 具体日期:'2025-11-05'、'2025年11月'\n- 相对时间:'今天'、'昨天'、'上周'、'最近'、'3天前'\n- 时间段:'今天下午'、'上个月'、'这学期'",
|
"description": "时间信息(强烈建议填写):\n- 具体日期:'2025-11-05'、'2025年11月'\n- 相对时间:'今天'、'昨天'、'上周'、'最近'、'3天前'\n- 时间段:'今天下午'、'上个月'、'这学期'",
|
||||||
},
|
},
|
||||||
"地点": {
|
"地点": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "地点信息(如涉及):\n- 具体地址、城市名、国家\n- 场所类型:'在家'、'公司'、'学校'、'咖啡店'"
|
"description": "地点信息(如涉及):\n- 具体地址、城市名、国家\n- 场所类型:'在家'、'公司'、'学校'、'咖啡店'"
|
||||||
},
|
},
|
||||||
"原因": {
|
"原因": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "为什么这样做/这样想(如明确提到)"
|
"description": "为什么这样做/这样想(如明确提到)"
|
||||||
},
|
},
|
||||||
"方式": {
|
"方式": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "怎么做的/通过什么方式(如明确提到)"
|
"description": "怎么做的/通过什么方式(如明确提到)"
|
||||||
},
|
},
|
||||||
"结果": {
|
"结果": {
|
||||||
@@ -183,10 +183,10 @@ class MemoryTools:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_link_memories_schema() -> Dict[str, Any]:
|
def get_link_memories_schema() -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取 link_memories 工具的 JSON schema
|
获取 link_memories 工具的 JSON schema
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
工具 schema 定义
|
工具 schema 定义
|
||||||
"""
|
"""
|
||||||
@@ -239,10 +239,10 @@ class MemoryTools:
|
|||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_search_memories_schema() -> Dict[str, Any]:
|
def get_search_memories_schema() -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取 search_memories 工具的 JSON schema
|
获取 search_memories 工具的 JSON schema
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
工具 schema 定义
|
工具 schema 定义
|
||||||
"""
|
"""
|
||||||
@@ -307,13 +307,13 @@ class MemoryTools:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
async def create_memory(self, **params) -> Dict[str, Any]:
|
async def create_memory(self, **params) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
执行 create_memory 工具
|
执行 create_memory 工具
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**params: 工具参数
|
**params: 工具参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行结果
|
执行结果
|
||||||
"""
|
"""
|
||||||
@@ -353,13 +353,13 @@ class MemoryTools:
|
|||||||
"message": "记忆创建失败",
|
"message": "记忆创建失败",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def link_memories(self, **params) -> Dict[str, Any]:
|
async def link_memories(self, **params) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
执行 link_memories 工具
|
执行 link_memories 工具
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**params: 工具参数
|
**params: 工具参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行结果
|
执行结果
|
||||||
"""
|
"""
|
||||||
@@ -433,15 +433,15 @@ class MemoryTools:
|
|||||||
"message": "记忆关联失败",
|
"message": "记忆关联失败",
|
||||||
}
|
}
|
||||||
|
|
||||||
async def search_memories(self, **params) -> Dict[str, Any]:
|
async def search_memories(self, **params) -> dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
执行 search_memories 工具
|
执行 search_memories 工具
|
||||||
|
|
||||||
使用多策略检索优化:
|
使用多策略检索优化:
|
||||||
1. 查询分解(识别主要实体和概念)
|
1. 查询分解(识别主要实体和概念)
|
||||||
2. 多查询并行检索
|
2. 多查询并行检索
|
||||||
3. 结果融合和重排
|
3. 结果融合和重排
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
**params: 工具参数
|
**params: 工具参数
|
||||||
- query: 查询字符串
|
- query: 查询字符串
|
||||||
@@ -449,7 +449,7 @@ class MemoryTools:
|
|||||||
- expand_depth: 扩展深度(暂未使用)
|
- expand_depth: 扩展深度(暂未使用)
|
||||||
- use_multi_query: 是否使用多查询策略(默认True)
|
- use_multi_query: 是否使用多查询策略(默认True)
|
||||||
- context: 查询上下文(可选)
|
- context: 查询上下文(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
搜索结果
|
搜索结果
|
||||||
"""
|
"""
|
||||||
@@ -477,7 +477,7 @@ class MemoryTools:
|
|||||||
# 2. 提取初始记忆ID(来自向量搜索)
|
# 2. 提取初始记忆ID(来自向量搜索)
|
||||||
initial_memory_ids = set()
|
initial_memory_ids = set()
|
||||||
memory_scores = {} # 记录每个记忆的初始分数
|
memory_scores = {} # 记录每个记忆的初始分数
|
||||||
|
|
||||||
for node_id, similarity, metadata in similar_nodes:
|
for node_id, similarity, metadata in similar_nodes:
|
||||||
if "memory_ids" in metadata:
|
if "memory_ids" in metadata:
|
||||||
ids = metadata["memory_ids"]
|
ids = metadata["memory_ids"]
|
||||||
@@ -486,7 +486,7 @@ class MemoryTools:
|
|||||||
import orjson
|
import orjson
|
||||||
try:
|
try:
|
||||||
ids = orjson.loads(ids)
|
ids = orjson.loads(ids)
|
||||||
except:
|
except Exception:
|
||||||
ids = [ids]
|
ids = [ids]
|
||||||
if isinstance(ids, list):
|
if isinstance(ids, list):
|
||||||
for mem_id in ids:
|
for mem_id in ids:
|
||||||
@@ -499,12 +499,12 @@ class MemoryTools:
|
|||||||
expanded_memory_scores = {}
|
expanded_memory_scores = {}
|
||||||
if expand_depth > 0 and initial_memory_ids:
|
if expand_depth > 0 and initial_memory_ids:
|
||||||
logger.info(f"开始图扩展: 初始记忆{len(initial_memory_ids)}个, 深度={expand_depth}")
|
logger.info(f"开始图扩展: 初始记忆{len(initial_memory_ids)}个, 深度={expand_depth}")
|
||||||
|
|
||||||
# 获取查询的embedding用于语义过滤
|
# 获取查询的embedding用于语义过滤
|
||||||
if self.builder.embedding_generator:
|
if self.builder.embedding_generator:
|
||||||
try:
|
try:
|
||||||
query_embedding = await self.builder.embedding_generator.generate(query)
|
query_embedding = await self.builder.embedding_generator.generate(query)
|
||||||
|
|
||||||
# 直接使用图扩展逻辑(避免循环依赖)
|
# 直接使用图扩展逻辑(避免循环依赖)
|
||||||
expanded_results = await self._expand_with_semantic_filter(
|
expanded_results = await self._expand_with_semantic_filter(
|
||||||
initial_memory_ids=list(initial_memory_ids),
|
initial_memory_ids=list(initial_memory_ids),
|
||||||
@@ -513,7 +513,7 @@ class MemoryTools:
|
|||||||
semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值
|
semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值
|
||||||
max_expanded=top_k * 2
|
max_expanded=top_k * 2
|
||||||
)
|
)
|
||||||
|
|
||||||
# 旧代码(如果需要使用Manager):
|
# 旧代码(如果需要使用Manager):
|
||||||
# from src.memory_graph.manager import MemoryManager
|
# from src.memory_graph.manager import MemoryManager
|
||||||
# manager = MemoryManager.get_instance()
|
# manager = MemoryManager.get_instance()
|
||||||
@@ -524,19 +524,18 @@ class MemoryTools:
|
|||||||
# semantic_threshold=0.5,
|
# semantic_threshold=0.5,
|
||||||
# max_expanded=top_k * 2
|
# max_expanded=top_k * 2
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# 合并扩展结果
|
# 合并扩展结果
|
||||||
for mem_id, score in expanded_results:
|
expanded_memory_scores.update(dict(expanded_results))
|
||||||
expanded_memory_scores[mem_id] = score
|
|
||||||
|
|
||||||
logger.info(f"图扩展完成: 新增{len(expanded_memory_scores)}个相关记忆")
|
logger.info(f"图扩展完成: 新增{len(expanded_memory_scores)}个相关记忆")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"图扩展失败: {e}")
|
logger.warning(f"图扩展失败: {e}")
|
||||||
|
|
||||||
# 4. 合并初始记忆和扩展记忆
|
# 4. 合并初始记忆和扩展记忆
|
||||||
all_memory_ids = set(initial_memory_ids) | set(expanded_memory_scores.keys())
|
all_memory_ids = set(initial_memory_ids) | set(expanded_memory_scores.keys())
|
||||||
|
|
||||||
# 计算最终分数:初始记忆保持原分数,扩展记忆使用扩展分数
|
# 计算最终分数:初始记忆保持原分数,扩展记忆使用扩展分数
|
||||||
final_scores = {}
|
final_scores = {}
|
||||||
for mem_id in all_memory_ids:
|
for mem_id in all_memory_ids:
|
||||||
@@ -546,7 +545,7 @@ class MemoryTools:
|
|||||||
elif mem_id in expanded_memory_scores:
|
elif mem_id in expanded_memory_scores:
|
||||||
# 扩展记忆:使用图扩展分数(稍微降权)
|
# 扩展记忆:使用图扩展分数(稍微降权)
|
||||||
final_scores[mem_id] = expanded_memory_scores[mem_id] * 0.8
|
final_scores[mem_id] = expanded_memory_scores[mem_id] * 0.8
|
||||||
|
|
||||||
# 按分数排序
|
# 按分数排序
|
||||||
sorted_memory_ids = sorted(
|
sorted_memory_ids = sorted(
|
||||||
final_scores.keys(),
|
final_scores.keys(),
|
||||||
@@ -562,7 +561,7 @@ class MemoryTools:
|
|||||||
# 综合评分:相似度(60%) + 重要性(30%) + 时效性(10%)
|
# 综合评分:相似度(60%) + 重要性(30%) + 时效性(10%)
|
||||||
similarity_score = final_scores[memory_id]
|
similarity_score = final_scores[memory_id]
|
||||||
importance_score = memory.importance
|
importance_score = memory.importance
|
||||||
|
|
||||||
# 计算时效性分数(最近的记忆得分更高)
|
# 计算时效性分数(最近的记忆得分更高)
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
@@ -573,16 +572,16 @@ class MemoryTools:
|
|||||||
memory_time = memory.created_at
|
memory_time = memory.created_at
|
||||||
age_days = (now - memory_time).total_seconds() / 86400
|
age_days = (now - memory_time).total_seconds() / 86400
|
||||||
recency_score = 1.0 / (1.0 + age_days / 30) # 30天半衰期
|
recency_score = 1.0 / (1.0 + age_days / 30) # 30天半衰期
|
||||||
|
|
||||||
# 综合分数
|
# 综合分数
|
||||||
final_score = (
|
final_score = (
|
||||||
similarity_score * 0.6 +
|
similarity_score * 0.6 +
|
||||||
importance_score * 0.3 +
|
importance_score * 0.3 +
|
||||||
recency_score * 0.1
|
recency_score * 0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
memories_with_scores.append((memory, final_score))
|
memories_with_scores.append((memory, final_score))
|
||||||
|
|
||||||
# 按综合分数排序
|
# 按综合分数排序
|
||||||
memories_with_scores.sort(key=lambda x: x[1], reverse=True)
|
memories_with_scores.sort(key=lambda x: x[1], reverse=True)
|
||||||
memories = [mem for mem, _ in memories_with_scores[:top_k]]
|
memories = [mem for mem, _ in memories_with_scores[:top_k]]
|
||||||
@@ -624,16 +623,16 @@ class MemoryTools:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def _generate_multi_queries_simple(
|
async def _generate_multi_queries_simple(
|
||||||
self, query: str, context: Optional[Dict[str, Any]] = None
|
self, query: str, context: dict[str, Any] | None = None
|
||||||
) -> List[Tuple[str, float]]:
|
) -> list[tuple[str, float]]:
|
||||||
"""
|
"""
|
||||||
简化版多查询生成(直接在 Tools 层实现,避免循环依赖)
|
简化版多查询生成(直接在 Tools 层实现,避免循环依赖)
|
||||||
|
|
||||||
让小模型直接生成3-5个不同角度的查询语句。
|
让小模型直接生成3-5个不同角度的查询语句。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
from src.llm_models.utils_model import LLMRequest
|
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
llm = LLMRequest(
|
llm = LLMRequest(
|
||||||
model_set=model_config.model_task_config.utils_small,
|
model_set=model_config.model_task_config.utils_small,
|
||||||
@@ -648,10 +647,10 @@ class MemoryTools:
|
|||||||
# 处理聊天历史,提取最近5条左右的对话
|
# 处理聊天历史,提取最近5条左右的对话
|
||||||
recent_chat = ""
|
recent_chat = ""
|
||||||
if chat_history:
|
if chat_history:
|
||||||
lines = chat_history.strip().split('\n')
|
lines = chat_history.strip().split("\n")
|
||||||
# 取最近5条消息
|
# 取最近5条消息
|
||||||
recent_lines = lines[-5:] if len(lines) > 5 else lines
|
recent_lines = lines[-5:] if len(lines) > 5 else lines
|
||||||
recent_chat = '\n'.join(recent_lines)
|
recent_chat = "\n".join(recent_lines)
|
||||||
|
|
||||||
prompt = f"""基于聊天上下文为查询生成3-5个不同角度的搜索语句(JSON格式)。
|
prompt = f"""基于聊天上下文为查询生成3-5个不同角度的搜索语句(JSON格式)。
|
||||||
|
|
||||||
@@ -685,36 +684,38 @@ class MemoryTools:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
|
response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
|
||||||
|
|
||||||
import orjson, re
|
import re
|
||||||
response = re.sub(r'```json\s*', '', response)
|
|
||||||
response = re.sub(r'```\s*$', '', response).strip()
|
import orjson
|
||||||
|
response = re.sub(r"```json\s*", "", response)
|
||||||
|
response = re.sub(r"```\s*$", "", response).strip()
|
||||||
|
|
||||||
data = orjson.loads(response)
|
data = orjson.loads(response)
|
||||||
queries = data.get("queries", [])
|
queries = data.get("queries", [])
|
||||||
|
|
||||||
result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
|
result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
|
||||||
for item in queries if item.get("text", "").strip()]
|
for item in queries if item.get("text", "").strip()]
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
logger.info(f"生成查询: {[q for q, _ in result]}")
|
logger.info(f"生成查询: {[q for q, _ in result]}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"多查询生成失败: {e}")
|
logger.warning(f"多查询生成失败: {e}")
|
||||||
|
|
||||||
return [(query, 1.0)]
|
return [(query, 1.0)]
|
||||||
|
|
||||||
async def _single_query_search(
|
async def _single_query_search(
|
||||||
self, query: str, top_k: int
|
self, query: str, top_k: int
|
||||||
) -> List[Tuple[str, float, Dict[str, Any]]]:
|
) -> list[tuple[str, float, dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
传统的单查询搜索
|
传统的单查询搜索
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 查询字符串
|
query: 查询字符串
|
||||||
top_k: 返回结果数
|
top_k: 返回结果数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相似节点列表 [(node_id, similarity, metadata), ...]
|
相似节点列表 [(node_id, similarity, metadata), ...]
|
||||||
"""
|
"""
|
||||||
@@ -735,30 +736,30 @@ class MemoryTools:
|
|||||||
return similar_nodes
|
return similar_nodes
|
||||||
|
|
||||||
async def _multi_query_search(
|
async def _multi_query_search(
|
||||||
self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None
|
self, query: str, top_k: int, context: dict[str, Any] | None = None
|
||||||
) -> List[Tuple[str, float, Dict[str, Any]]]:
|
) -> list[tuple[str, float, dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
多查询策略搜索(简化版)
|
多查询策略搜索(简化版)
|
||||||
|
|
||||||
直接使用小模型生成多个查询,无需复杂的分解和组合。
|
直接使用小模型生成多个查询,无需复杂的分解和组合。
|
||||||
|
|
||||||
步骤:
|
步骤:
|
||||||
1. 让小模型生成3-5个不同角度的查询
|
1. 让小模型生成3-5个不同角度的查询
|
||||||
2. 为每个查询生成嵌入
|
2. 为每个查询生成嵌入
|
||||||
3. 并行搜索并融合结果
|
3. 并行搜索并融合结果
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
query: 查询字符串
|
query: 查询字符串
|
||||||
top_k: 返回结果数
|
top_k: 返回结果数
|
||||||
context: 查询上下文
|
context: 查询上下文
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
融合后的相似节点列表
|
融合后的相似节点列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 1. 使用小模型生成多个查询
|
# 1. 使用小模型生成多个查询
|
||||||
multi_queries = await self._generate_multi_queries_simple(query, context)
|
multi_queries = await self._generate_multi_queries_simple(query, context)
|
||||||
|
|
||||||
logger.debug(f"生成 {len(multi_queries)} 个查询: {multi_queries}")
|
logger.debug(f"生成 {len(multi_queries)} 个查询: {multi_queries}")
|
||||||
|
|
||||||
# 2. 生成所有查询的嵌入
|
# 2. 生成所有查询的嵌入
|
||||||
@@ -800,13 +801,13 @@ class MemoryTools:
|
|||||||
if node.embedding is not None:
|
if node.embedding is not None:
|
||||||
await self.vector_store.add_node(node)
|
await self.vector_store.add_node(node)
|
||||||
|
|
||||||
async def _find_memory_by_description(self, description: str) -> Optional[Memory]:
|
async def _find_memory_by_description(self, description: str) -> Memory | None:
|
||||||
"""
|
"""
|
||||||
通过描述查找记忆
|
通过描述查找记忆
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
description: 记忆描述
|
description: 记忆描述
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
找到的记忆,如果没有则返回 None
|
找到的记忆,如果没有则返回 None
|
||||||
"""
|
"""
|
||||||
@@ -827,13 +828,13 @@ class MemoryTools:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取最相似节点关联的记忆
|
# 获取最相似节点关联的记忆
|
||||||
node_id, similarity, metadata = similar_nodes[0]
|
_node_id, _similarity, metadata = similar_nodes[0]
|
||||||
|
|
||||||
if "memory_ids" not in metadata or not metadata["memory_ids"]:
|
if "memory_ids" not in metadata or not metadata["memory_ids"]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
ids = metadata["memory_ids"]
|
ids = metadata["memory_ids"]
|
||||||
|
|
||||||
# 确保是列表
|
# 确保是列表
|
||||||
if isinstance(ids, str):
|
if isinstance(ids, str):
|
||||||
import orjson
|
import orjson
|
||||||
@@ -842,11 +843,11 @@ class MemoryTools:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"JSON 解析失败: {e}")
|
logger.warning(f"JSON 解析失败: {e}")
|
||||||
ids = [ids]
|
ids = [ids]
|
||||||
|
|
||||||
if isinstance(ids, list) and ids:
|
if isinstance(ids, list) and ids:
|
||||||
memory_id = ids[0]
|
memory_id = ids[0]
|
||||||
return self.graph_store.get_memory_by_id(memory_id)
|
return self.graph_store.get_memory_by_id(memory_id)
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _summarize_memory(self, memory: Memory) -> str:
|
def _summarize_memory(self, memory: Memory) -> str:
|
||||||
@@ -862,103 +863,102 @@ class MemoryTools:
|
|||||||
|
|
||||||
async def _expand_with_semantic_filter(
|
async def _expand_with_semantic_filter(
|
||||||
self,
|
self,
|
||||||
initial_memory_ids: List[str],
|
initial_memory_ids: list[str],
|
||||||
query_embedding,
|
query_embedding,
|
||||||
max_depth: int = 2,
|
max_depth: int = 2,
|
||||||
semantic_threshold: float = 0.5,
|
semantic_threshold: float = 0.5,
|
||||||
max_expanded: int = 20
|
max_expanded: int = 20
|
||||||
) -> List[Tuple[str, float]]:
|
) -> list[tuple[str, float]]:
|
||||||
"""
|
"""
|
||||||
从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤
|
从初始记忆集合出发,沿图结构扩展,并用语义相似度过滤
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
initial_memory_ids: 初始记忆ID集合
|
initial_memory_ids: 初始记忆ID集合
|
||||||
query_embedding: 查询向量
|
query_embedding: 查询向量
|
||||||
max_depth: 最大扩展深度
|
max_depth: 最大扩展深度
|
||||||
semantic_threshold: 语义相似度阈值
|
semantic_threshold: 语义相似度阈值
|
||||||
max_expanded: 最多扩展多少个记忆
|
max_expanded: 最多扩展多少个记忆
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[(memory_id, relevance_score)]
|
List[(memory_id, relevance_score)]
|
||||||
"""
|
"""
|
||||||
if not initial_memory_ids or query_embedding is None:
|
if not initial_memory_ids or query_embedding is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
visited_memories = set(initial_memory_ids)
|
visited_memories = set(initial_memory_ids)
|
||||||
expanded_memories: Dict[str, float] = {}
|
expanded_memories: dict[str, float] = {}
|
||||||
|
|
||||||
current_level = initial_memory_ids
|
current_level = initial_memory_ids
|
||||||
|
|
||||||
for depth in range(max_depth):
|
for depth in range(max_depth):
|
||||||
next_level = []
|
next_level = []
|
||||||
|
|
||||||
for memory_id in current_level:
|
for memory_id in current_level:
|
||||||
memory = self.graph_store.get_memory_by_id(memory_id)
|
memory = self.graph_store.get_memory_by_id(memory_id)
|
||||||
if not memory:
|
if not memory:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for node in memory.nodes:
|
for node in memory.nodes:
|
||||||
if not node.has_embedding():
|
if not node.has_embedding():
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
neighbors = list(self.graph_store.graph.neighbors(node.id))
|
neighbors = list(self.graph_store.graph.neighbors(node.id))
|
||||||
except:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for neighbor_id in neighbors:
|
for neighbor_id in neighbors:
|
||||||
neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id)
|
neighbor_node_data = self.graph_store.graph.nodes.get(neighbor_id)
|
||||||
if not neighbor_node_data:
|
if not neighbor_node_data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id)
|
neighbor_vector_data = await self.vector_store.get_node_by_id(neighbor_id)
|
||||||
if neighbor_vector_data is None:
|
if neighbor_vector_data is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
neighbor_embedding = neighbor_vector_data.get("embedding")
|
neighbor_embedding = neighbor_vector_data.get("embedding")
|
||||||
if neighbor_embedding is None:
|
if neighbor_embedding is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算语义相似度
|
# 计算语义相似度
|
||||||
semantic_sim = self._cosine_similarity(
|
semantic_sim = self._cosine_similarity(
|
||||||
query_embedding,
|
query_embedding,
|
||||||
neighbor_embedding
|
neighbor_embedding
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取边权重
|
# 获取边权重
|
||||||
try:
|
try:
|
||||||
edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
|
edge_data = self.graph_store.graph.get_edge_data(node.id, neighbor_id)
|
||||||
edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
|
edge_importance = edge_data.get("importance", 0.5) if edge_data else 0.5
|
||||||
except:
|
except Exception:
|
||||||
edge_importance = 0.5
|
edge_importance = 0.5
|
||||||
|
|
||||||
# 综合评分
|
# 综合评分
|
||||||
depth_decay = 1.0 / (depth + 1)
|
depth_decay = 1.0 / (depth + 1)
|
||||||
relevance_score = (
|
relevance_score = (
|
||||||
semantic_sim * 0.7 +
|
semantic_sim * 0.7 +
|
||||||
edge_importance * 0.2 +
|
edge_importance * 0.2 +
|
||||||
depth_decay * 0.1
|
depth_decay * 0.1
|
||||||
)
|
)
|
||||||
|
|
||||||
if relevance_score < semantic_threshold:
|
if relevance_score < semantic_threshold:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 提取记忆ID
|
# 提取记忆ID
|
||||||
neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
|
neighbor_memory_ids = neighbor_node_data.get("memory_ids", [])
|
||||||
if isinstance(neighbor_memory_ids, str):
|
if isinstance(neighbor_memory_ids, str):
|
||||||
import orjson
|
import orjson
|
||||||
try:
|
try:
|
||||||
neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
|
neighbor_memory_ids = orjson.loads(neighbor_memory_ids)
|
||||||
except:
|
except Exception:
|
||||||
neighbor_memory_ids = [neighbor_memory_ids]
|
neighbor_memory_ids = [neighbor_memory_ids]
|
||||||
|
|
||||||
for neighbor_mem_id in neighbor_memory_ids:
|
for neighbor_mem_id in neighbor_memory_ids:
|
||||||
if neighbor_mem_id in visited_memories:
|
if neighbor_mem_id in visited_memories:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if neighbor_mem_id not in expanded_memories:
|
if neighbor_mem_id not in expanded_memories:
|
||||||
expanded_memories[neighbor_mem_id] = relevance_score
|
expanded_memories[neighbor_mem_id] = relevance_score
|
||||||
visited_memories.add(neighbor_mem_id)
|
visited_memories.add(neighbor_mem_id)
|
||||||
@@ -968,52 +968,52 @@ class MemoryTools:
|
|||||||
expanded_memories[neighbor_mem_id],
|
expanded_memories[neighbor_mem_id],
|
||||||
relevance_score
|
relevance_score
|
||||||
)
|
)
|
||||||
|
|
||||||
if not next_level or len(expanded_memories) >= max_expanded:
|
if not next_level or len(expanded_memories) >= max_expanded:
|
||||||
break
|
break
|
||||||
|
|
||||||
current_level = next_level[:max_expanded]
|
current_level = next_level[:max_expanded]
|
||||||
|
|
||||||
sorted_results = sorted(
|
sorted_results = sorted(
|
||||||
expanded_memories.items(),
|
expanded_memories.items(),
|
||||||
key=lambda x: x[1],
|
key=lambda x: x[1],
|
||||||
reverse=True
|
reverse=True
|
||||||
)[:max_expanded]
|
)[:max_expanded]
|
||||||
|
|
||||||
return sorted_results
|
return sorted_results
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"图扩展失败: {e}", exc_info=True)
|
logger.error(f"图扩展失败: {e}", exc_info=True)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def _cosine_similarity(self, vec1, vec2) -> float:
|
def _cosine_similarity(self, vec1, vec2) -> float:
|
||||||
"""计算余弦相似度"""
|
"""计算余弦相似度"""
|
||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
if not isinstance(vec1, np.ndarray):
|
if not isinstance(vec1, np.ndarray):
|
||||||
vec1 = np.array(vec1)
|
vec1 = np.array(vec1)
|
||||||
if not isinstance(vec2, np.ndarray):
|
if not isinstance(vec2, np.ndarray):
|
||||||
vec2 = np.array(vec2)
|
vec2 = np.array(vec2)
|
||||||
|
|
||||||
vec1_norm = np.linalg.norm(vec1)
|
vec1_norm = np.linalg.norm(vec1)
|
||||||
vec2_norm = np.linalg.norm(vec2)
|
vec2_norm = np.linalg.norm(vec2)
|
||||||
|
|
||||||
if vec1_norm == 0 or vec2_norm == 0:
|
if vec1_norm == 0 or vec2_norm == 0:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
|
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
|
||||||
return float(similarity)
|
return float(similarity)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"计算余弦相似度失败: {e}")
|
logger.warning(f"计算余弦相似度失败: {e}")
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_tool_schemas() -> List[Dict[str, Any]]:
|
def get_all_tool_schemas() -> list[dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取所有工具的 schema
|
获取所有工具的 schema
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
工具 schema 列表
|
工具 schema 列表
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,4 +5,4 @@
|
|||||||
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
|
from src.memory_graph.utils.embeddings import EmbeddingGenerator, get_embedding_generator
|
||||||
from src.memory_graph.utils.time_parser import TimeParser
|
from src.memory_graph.utils.time_parser import TimeParser
|
||||||
|
|
||||||
__all__ = ["TimeParser", "EmbeddingGenerator", "get_embedding_generator"]
|
__all__ = ["EmbeddingGenerator", "TimeParser", "get_embedding_generator"]
|
||||||
|
|||||||
@@ -5,8 +5,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import lru_cache
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -18,12 +16,12 @@ logger = get_logger(__name__)
|
|||||||
class EmbeddingGenerator:
|
class EmbeddingGenerator:
|
||||||
"""
|
"""
|
||||||
嵌入向量生成器
|
嵌入向量生成器
|
||||||
|
|
||||||
策略:
|
策略:
|
||||||
1. 优先使用配置的 embedding API(通过 LLMRequest)
|
1. 优先使用配置的 embedding API(通过 LLMRequest)
|
||||||
2. 如果 API 不可用,回退到本地 sentence-transformers
|
2. 如果 API 不可用,回退到本地 sentence-transformers
|
||||||
3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)
|
3. 如果 sentence-transformers 未安装,使用随机向量(仅测试)
|
||||||
|
|
||||||
优点:
|
优点:
|
||||||
- 降低本地运算负载
|
- 降低本地运算负载
|
||||||
- 即使未安装 sentence-transformers 也可正常运行
|
- 即使未安装 sentence-transformers 也可正常运行
|
||||||
@@ -37,19 +35,19 @@ class EmbeddingGenerator:
|
|||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化嵌入生成器
|
初始化嵌入生成器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
use_api: 是否优先使用 API(默认 True)
|
use_api: 是否优先使用 API(默认 True)
|
||||||
fallback_model_name: 回退本地模型名称
|
fallback_model_name: 回退本地模型名称
|
||||||
"""
|
"""
|
||||||
self.use_api = use_api
|
self.use_api = use_api
|
||||||
self.fallback_model_name = fallback_model_name
|
self.fallback_model_name = fallback_model_name
|
||||||
|
|
||||||
# API 相关
|
# API 相关
|
||||||
self._llm_request = None
|
self._llm_request = None
|
||||||
self._api_available = False
|
self._api_available = False
|
||||||
self._api_dimension = None
|
self._api_dimension = None
|
||||||
|
|
||||||
# 本地模型相关
|
# 本地模型相关
|
||||||
self._local_model = None
|
self._local_model = None
|
||||||
self._local_model_loaded = False
|
self._local_model_loaded = False
|
||||||
@@ -58,24 +56,24 @@ class EmbeddingGenerator:
|
|||||||
"""初始化 embedding API"""
|
"""初始化 embedding API"""
|
||||||
if self._api_available:
|
if self._api_available:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
embedding_config = model_config.model_task_config.embedding
|
embedding_config = model_config.model_task_config.embedding
|
||||||
self._llm_request = LLMRequest(
|
self._llm_request = LLMRequest(
|
||||||
model_set=embedding_config,
|
model_set=embedding_config,
|
||||||
request_type="memory_graph.embedding"
|
request_type="memory_graph.embedding"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取嵌入维度
|
# 获取嵌入维度
|
||||||
if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension:
|
if hasattr(embedding_config, "embedding_dimension") and embedding_config.embedding_dimension:
|
||||||
self._api_dimension = embedding_config.embedding_dimension
|
self._api_dimension = embedding_config.embedding_dimension
|
||||||
|
|
||||||
self._api_available = True
|
self._api_available = True
|
||||||
logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
|
logger.info(f"✅ Embedding API 初始化成功 (维度: {self._api_dimension})")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
|
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
|
||||||
self._api_available = False
|
self._api_available = False
|
||||||
@@ -103,15 +101,15 @@ class EmbeddingGenerator:
|
|||||||
async def generate(self, text: str) -> np.ndarray:
|
async def generate(self, text: str) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
生成单个文本的嵌入向量
|
生成单个文本的嵌入向量
|
||||||
|
|
||||||
策略:
|
策略:
|
||||||
1. 优先使用 API
|
1. 优先使用 API
|
||||||
2. API 失败则使用本地模型
|
2. API 失败则使用本地模型
|
||||||
3. 本地模型不可用则使用随机向量
|
3. 本地模型不可用则使用随机向量
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: 输入文本
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
嵌入向量
|
嵌入向量
|
||||||
"""
|
"""
|
||||||
@@ -126,12 +124,12 @@ class EmbeddingGenerator:
|
|||||||
embedding = await self._generate_with_api(text)
|
embedding = await self._generate_with_api(text)
|
||||||
if embedding is not None:
|
if embedding is not None:
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
# 策略 2: 使用本地模型
|
# 策略 2: 使用本地模型
|
||||||
embedding = await self._generate_with_local_model(text)
|
embedding = await self._generate_with_local_model(text)
|
||||||
if embedding is not None:
|
if embedding is not None:
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
# 策略 3: 随机向量(仅测试)
|
# 策略 3: 随机向量(仅测试)
|
||||||
logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
|
logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...")
|
||||||
dim = self._get_dimension()
|
dim = self._get_dimension()
|
||||||
@@ -142,47 +140,47 @@ class EmbeddingGenerator:
|
|||||||
dim = self._get_dimension()
|
dim = self._get_dimension()
|
||||||
return np.random.rand(dim).astype(np.float32)
|
return np.random.rand(dim).astype(np.float32)
|
||||||
|
|
||||||
async def _generate_with_api(self, text: str) -> Optional[np.ndarray]:
|
async def _generate_with_api(self, text: str) -> np.ndarray | None:
|
||||||
"""使用 API 生成嵌入"""
|
"""使用 API 生成嵌入"""
|
||||||
try:
|
try:
|
||||||
# 初始化 API
|
# 初始化 API
|
||||||
if not self._api_available:
|
if not self._api_available:
|
||||||
await self._initialize_api()
|
await self._initialize_api()
|
||||||
|
|
||||||
if not self._api_available or not self._llm_request:
|
if not self._api_available or not self._llm_request:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 调用 API
|
# 调用 API
|
||||||
embedding_list, model_name = await self._llm_request.get_embedding(text)
|
embedding_list, model_name = await self._llm_request.get_embedding(text)
|
||||||
|
|
||||||
if embedding_list and len(embedding_list) > 0:
|
if embedding_list and len(embedding_list) > 0:
|
||||||
embedding = np.array(embedding_list, dtype=np.float32)
|
embedding = np.array(embedding_list, dtype=np.float32)
|
||||||
logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
|
logger.debug(f"🌐 API 生成嵌入: {text[:30]}... -> {len(embedding)}维 (模型: {model_name})")
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"API 嵌入生成失败: {e}")
|
logger.debug(f"API 嵌入生成失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _generate_with_local_model(self, text: str) -> Optional[np.ndarray]:
|
async def _generate_with_local_model(self, text: str) -> np.ndarray | None:
|
||||||
"""使用本地模型生成嵌入"""
|
"""使用本地模型生成嵌入"""
|
||||||
try:
|
try:
|
||||||
# 加载本地模型
|
# 加载本地模型
|
||||||
if not self._local_model_loaded:
|
if not self._local_model_loaded:
|
||||||
self._load_local_model()
|
self._load_local_model()
|
||||||
|
|
||||||
if not self._local_model_loaded or not self._local_model:
|
if not self._local_model_loaded or not self._local_model:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 在线程池中运行
|
# 在线程池中运行
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
embedding = await loop.run_in_executor(None, self._encode_single_local, text)
|
embedding = await loop.run_in_executor(None, self._encode_single_local, text)
|
||||||
|
|
||||||
logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维")
|
logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维")
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"本地模型嵌入生成失败: {e}")
|
logger.debug(f"本地模型嵌入生成失败: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -199,24 +197,24 @@ class EmbeddingGenerator:
|
|||||||
# 优先使用 API 维度
|
# 优先使用 API 维度
|
||||||
if self._api_dimension:
|
if self._api_dimension:
|
||||||
return self._api_dimension
|
return self._api_dimension
|
||||||
|
|
||||||
# 其次使用本地模型维度
|
# 其次使用本地模型维度
|
||||||
if self._local_model_loaded and self._local_model:
|
if self._local_model_loaded and self._local_model:
|
||||||
try:
|
try:
|
||||||
return self._local_model.get_sentence_embedding_dimension()
|
return self._local_model.get_sentence_embedding_dimension()
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 默认 384(sentence-transformers 常用维度)
|
# 默认 384(sentence-transformers 常用维度)
|
||||||
return 384
|
return 384
|
||||||
|
|
||||||
async def generate_batch(self, texts: List[str]) -> List[np.ndarray]:
|
async def generate_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||||
"""
|
"""
|
||||||
批量生成嵌入向量
|
批量生成嵌入向量
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
texts: 文本列表
|
texts: 文本列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
嵌入向量列表
|
嵌入向量列表
|
||||||
"""
|
"""
|
||||||
@@ -236,13 +234,13 @@ class EmbeddingGenerator:
|
|||||||
results = await self._generate_batch_with_api(valid_texts)
|
results = await self._generate_batch_with_api(valid_texts)
|
||||||
if results:
|
if results:
|
||||||
return results
|
return results
|
||||||
|
|
||||||
# 回退到逐个生成
|
# 回退到逐个生成
|
||||||
results = []
|
results = []
|
||||||
for text in valid_texts:
|
for text in valid_texts:
|
||||||
embedding = await self.generate(text)
|
embedding = await self.generate(text)
|
||||||
results.append(embedding)
|
results.append(embedding)
|
||||||
|
|
||||||
logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
|
logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@@ -251,7 +249,7 @@ class EmbeddingGenerator:
|
|||||||
dim = self._get_dimension()
|
dim = self._get_dimension()
|
||||||
return [np.random.rand(dim).astype(np.float32) for _ in texts]
|
return [np.random.rand(dim).astype(np.float32) for _ in texts]
|
||||||
|
|
||||||
async def _generate_batch_with_api(self, texts: List[str]) -> Optional[List[np.ndarray]]:
|
async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None:
|
||||||
"""使用 API 批量生成"""
|
"""使用 API 批量生成"""
|
||||||
try:
|
try:
|
||||||
# 对于大多数 API,批量调用就是多次单独调用
|
# 对于大多数 API,批量调用就是多次单独调用
|
||||||
@@ -273,7 +271,7 @@ class EmbeddingGenerator:
|
|||||||
|
|
||||||
|
|
||||||
# 全局单例
|
# 全局单例
|
||||||
_global_generator: Optional[EmbeddingGenerator] = None
|
_global_generator: EmbeddingGenerator | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_embedding_generator(
|
def get_embedding_generator(
|
||||||
@@ -282,11 +280,11 @@ def get_embedding_generator(
|
|||||||
) -> EmbeddingGenerator:
|
) -> EmbeddingGenerator:
|
||||||
"""
|
"""
|
||||||
获取全局嵌入生成器单例
|
获取全局嵌入生成器单例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
use_api: 是否优先使用 API
|
use_api: 是否优先使用 API
|
||||||
fallback_model_name: 回退本地模型名称
|
fallback_model_name: 回退本地模型名称
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EmbeddingGenerator 实例
|
EmbeddingGenerator 实例
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -5,10 +5,9 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from src.memory_graph.models import Memory, MemoryNode, NodeType, EdgeType, MemoryType
|
from src.memory_graph.models import EdgeType, Memory, MemoryType, NodeType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -16,18 +15,18 @@ logger = logging.getLogger(__name__)
|
|||||||
def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> str:
|
def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) -> str:
|
||||||
"""
|
"""
|
||||||
将记忆对象格式化为适合提示词的自然语言描述
|
将记忆对象格式化为适合提示词的自然语言描述
|
||||||
|
|
||||||
根据记忆的图结构,构建完整的主谓宾描述,包含:
|
根据记忆的图结构,构建完整的主谓宾描述,包含:
|
||||||
- 主语(subject node)
|
- 主语(subject node)
|
||||||
- 谓语/动作(topic node)
|
- 谓语/动作(topic node)
|
||||||
- 宾语/对象(object node,如果存在)
|
- 宾语/对象(object node,如果存在)
|
||||||
- 属性信息(attributes,如时间、地点等)
|
- 属性信息(attributes,如时间、地点等)
|
||||||
- 关系信息(记忆之间的关系)
|
- 关系信息(记忆之间的关系)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory: 记忆对象
|
memory: 记忆对象
|
||||||
include_metadata: 是否包含元数据(时间、重要性等)
|
include_metadata: 是否包含元数据(时间、重要性等)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化后的自然语言描述
|
格式化后的自然语言描述
|
||||||
"""
|
"""
|
||||||
@@ -37,24 +36,22 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
|||||||
if not subject_node:
|
if not subject_node:
|
||||||
logger.warning(f"记忆 {memory.id} 缺少主体节点")
|
logger.warning(f"记忆 {memory.id} 缺少主体节点")
|
||||||
return "(记忆格式错误:缺少主体)"
|
return "(记忆格式错误:缺少主体)"
|
||||||
|
|
||||||
subject_text = subject_node.content
|
subject_text = subject_node.content
|
||||||
|
|
||||||
# 2. 查找主题节点(谓语/动作)
|
# 2. 查找主题节点(谓语/动作)
|
||||||
topic_node = None
|
topic_node = None
|
||||||
memory_type_relation = None
|
|
||||||
for edge in memory.edges:
|
for edge in memory.edges:
|
||||||
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
|
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
|
||||||
topic_node = memory.get_node_by_id(edge.target_id)
|
topic_node = memory.get_node_by_id(edge.target_id)
|
||||||
memory_type_relation = edge.relation
|
|
||||||
break
|
break
|
||||||
|
|
||||||
if not topic_node:
|
if not topic_node:
|
||||||
logger.warning(f"记忆 {memory.id} 缺少主题节点")
|
logger.warning(f"记忆 {memory.id} 缺少主题节点")
|
||||||
return f"{subject_text}(记忆格式错误:缺少主题)"
|
return f"{subject_text}(记忆格式错误:缺少主题)"
|
||||||
|
|
||||||
topic_text = topic_node.content
|
topic_text = topic_node.content
|
||||||
|
|
||||||
# 3. 查找客体节点(宾语)和核心关系
|
# 3. 查找客体节点(宾语)和核心关系
|
||||||
object_node = None
|
object_node = None
|
||||||
core_relation = None
|
core_relation = None
|
||||||
@@ -63,9 +60,9 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
|||||||
object_node = memory.get_node_by_id(edge.target_id)
|
object_node = memory.get_node_by_id(edge.target_id)
|
||||||
core_relation = edge.relation if edge.relation else ""
|
core_relation = edge.relation if edge.relation else ""
|
||||||
break
|
break
|
||||||
|
|
||||||
# 4. 收集属性节点
|
# 4. 收集属性节点
|
||||||
attributes: Dict[str, str] = {}
|
attributes: dict[str, str] = {}
|
||||||
for edge in memory.edges:
|
for edge in memory.edges:
|
||||||
if edge.edge_type == EdgeType.ATTRIBUTE:
|
if edge.edge_type == EdgeType.ATTRIBUTE:
|
||||||
# 查找属性节点和值节点
|
# 查找属性节点和值节点
|
||||||
@@ -73,16 +70,16 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
|||||||
if attr_node and attr_node.node_type == NodeType.ATTRIBUTE:
|
if attr_node and attr_node.node_type == NodeType.ATTRIBUTE:
|
||||||
# 查找这个属性的值
|
# 查找这个属性的值
|
||||||
for value_edge in memory.edges:
|
for value_edge in memory.edges:
|
||||||
if (value_edge.edge_type == EdgeType.ATTRIBUTE
|
if (value_edge.edge_type == EdgeType.ATTRIBUTE
|
||||||
and value_edge.source_id == attr_node.id):
|
and value_edge.source_id == attr_node.id):
|
||||||
value_node = memory.get_node_by_id(value_edge.target_id)
|
value_node = memory.get_node_by_id(value_edge.target_id)
|
||||||
if value_node and value_node.node_type == NodeType.VALUE:
|
if value_node and value_node.node_type == NodeType.VALUE:
|
||||||
attributes[attr_node.content] = value_node.content
|
attributes[attr_node.content] = value_node.content
|
||||||
break
|
break
|
||||||
|
|
||||||
# 5. 构建自然语言描述
|
# 5. 构建自然语言描述
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
# 主谓宾结构
|
# 主谓宾结构
|
||||||
if object_node is not None:
|
if object_node is not None:
|
||||||
# 有完整的主谓宾
|
# 有完整的主谓宾
|
||||||
@@ -93,7 +90,7 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
|||||||
else:
|
else:
|
||||||
# 只有主谓
|
# 只有主谓
|
||||||
parts.append(f"{subject_text}{topic_text}")
|
parts.append(f"{subject_text}{topic_text}")
|
||||||
|
|
||||||
# 添加属性信息
|
# 添加属性信息
|
||||||
if attributes:
|
if attributes:
|
||||||
attr_parts = []
|
attr_parts = []
|
||||||
@@ -106,78 +103,78 @@ def format_memory_for_prompt(memory: Memory, include_metadata: bool = False) ->
|
|||||||
for key, value in attributes.items():
|
for key, value in attributes.items():
|
||||||
if key not in ["时间", "地点"]:
|
if key not in ["时间", "地点"]:
|
||||||
attr_parts.append(f"{key}:{value}")
|
attr_parts.append(f"{key}:{value}")
|
||||||
|
|
||||||
if attr_parts:
|
if attr_parts:
|
||||||
parts.append(f"({' '.join(attr_parts)})")
|
parts.append(f"({' '.join(attr_parts)})")
|
||||||
|
|
||||||
description = "".join(parts)
|
description = "".join(parts)
|
||||||
|
|
||||||
# 6. 添加元数据(可选)
|
# 6. 添加元数据(可选)
|
||||||
if include_metadata:
|
if include_metadata:
|
||||||
metadata_parts = []
|
metadata_parts = []
|
||||||
|
|
||||||
# 记忆类型
|
# 记忆类型
|
||||||
if memory.memory_type:
|
if memory.memory_type:
|
||||||
metadata_parts.append(f"类型:{memory.memory_type.value}")
|
metadata_parts.append(f"类型:{memory.memory_type.value}")
|
||||||
|
|
||||||
# 重要性
|
# 重要性
|
||||||
if memory.importance >= 0.8:
|
if memory.importance >= 0.8:
|
||||||
metadata_parts.append("重要")
|
metadata_parts.append("重要")
|
||||||
elif memory.importance >= 0.6:
|
elif memory.importance >= 0.6:
|
||||||
metadata_parts.append("一般")
|
metadata_parts.append("一般")
|
||||||
|
|
||||||
# 时间(如果没有在属性中)
|
# 时间(如果没有在属性中)
|
||||||
if "时间" not in attributes:
|
if "时间" not in attributes:
|
||||||
time_str = _format_relative_time(memory.created_at)
|
time_str = _format_relative_time(memory.created_at)
|
||||||
if time_str:
|
if time_str:
|
||||||
metadata_parts.append(time_str)
|
metadata_parts.append(time_str)
|
||||||
|
|
||||||
if metadata_parts:
|
if metadata_parts:
|
||||||
description += f" [{', '.join(metadata_parts)}]"
|
description += f" [{', '.join(metadata_parts)}]"
|
||||||
|
|
||||||
return description
|
return description
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"格式化记忆失败: {e}", exc_info=True)
|
logger.error(f"格式化记忆失败: {e}", exc_info=True)
|
||||||
return f"(记忆格式化错误: {str(e)[:50]})"
|
return f"(记忆格式化错误: {str(e)[:50]})"
|
||||||
|
|
||||||
|
|
||||||
def format_memories_for_prompt(
|
def format_memories_for_prompt(
|
||||||
memories: List[Memory],
|
memories: list[Memory],
|
||||||
max_count: Optional[int] = None,
|
max_count: int | None = None,
|
||||||
include_metadata: bool = False,
|
include_metadata: bool = False,
|
||||||
group_by_type: bool = False
|
group_by_type: bool = False
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
批量格式化多条记忆为提示词文本
|
批量格式化多条记忆为提示词文本
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memories: 记忆列表
|
memories: 记忆列表
|
||||||
max_count: 最大记忆数量(可选)
|
max_count: 最大记忆数量(可选)
|
||||||
include_metadata: 是否包含元数据
|
include_metadata: 是否包含元数据
|
||||||
group_by_type: 是否按类型分组
|
group_by_type: 是否按类型分组
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化后的文本,包含标题和列表
|
格式化后的文本,包含标题和列表
|
||||||
"""
|
"""
|
||||||
if not memories:
|
if not memories:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
# 限制数量
|
# 限制数量
|
||||||
if max_count:
|
if max_count:
|
||||||
memories = memories[:max_count]
|
memories = memories[:max_count]
|
||||||
|
|
||||||
# 按类型分组
|
# 按类型分组
|
||||||
if group_by_type:
|
if group_by_type:
|
||||||
type_groups: Dict[MemoryType, List[Memory]] = {}
|
type_groups: dict[MemoryType, list[Memory]] = {}
|
||||||
for memory in memories:
|
for memory in memories:
|
||||||
if memory.memory_type not in type_groups:
|
if memory.memory_type not in type_groups:
|
||||||
type_groups[memory.memory_type] = []
|
type_groups[memory.memory_type] = []
|
||||||
type_groups[memory.memory_type].append(memory)
|
type_groups[memory.memory_type].append(memory)
|
||||||
|
|
||||||
# 构建分组文本
|
# 构建分组文本
|
||||||
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
||||||
|
|
||||||
type_order = [MemoryType.FACT, MemoryType.EVENT, MemoryType.RELATION, MemoryType.OPINION]
|
type_order = [MemoryType.FACT, MemoryType.EVENT, MemoryType.RELATION, MemoryType.OPINION]
|
||||||
for mem_type in type_order:
|
for mem_type in type_order:
|
||||||
if mem_type in type_groups:
|
if mem_type in type_groups:
|
||||||
@@ -186,33 +183,33 @@ def format_memories_for_prompt(
|
|||||||
desc = format_memory_for_prompt(memory, include_metadata)
|
desc = format_memory_for_prompt(memory, include_metadata)
|
||||||
parts.append(f"- {desc}")
|
parts.append(f"- {desc}")
|
||||||
parts.append("")
|
parts.append("")
|
||||||
|
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# 不分组,直接列出
|
# 不分组,直接列出
|
||||||
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
parts = ["### 🧠 相关记忆 (Relevant Memories)", ""]
|
||||||
|
|
||||||
for memory in memories:
|
for memory in memories:
|
||||||
# 获取类型标签
|
# 获取类型标签
|
||||||
type_label = memory.memory_type.value if memory.memory_type else "未知"
|
type_label = memory.memory_type.value if memory.memory_type else "未知"
|
||||||
|
|
||||||
# 格式化记忆内容
|
# 格式化记忆内容
|
||||||
desc = format_memory_for_prompt(memory, include_metadata)
|
desc = format_memory_for_prompt(memory, include_metadata)
|
||||||
|
|
||||||
# 添加类型标签
|
# 添加类型标签
|
||||||
parts.append(f"- **[{type_label}]** {desc}")
|
parts.append(f"- **[{type_label}]** {desc}")
|
||||||
|
|
||||||
return "\n".join(parts)
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def get_memory_type_label(memory_type: str) -> str:
|
def get_memory_type_label(memory_type: str) -> str:
|
||||||
"""
|
"""
|
||||||
获取记忆类型的中文标签
|
获取记忆类型的中文标签
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_type: 记忆类型(可能是英文或中文)
|
memory_type: 记忆类型(可能是英文或中文)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
中文标签
|
中文标签
|
||||||
"""
|
"""
|
||||||
@@ -243,27 +240,27 @@ def get_memory_type_label(memory_type: str) -> str:
|
|||||||
"经历": "经历",
|
"经历": "经历",
|
||||||
"情境": "情境",
|
"情境": "情境",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 转换为小写进行匹配
|
# 转换为小写进行匹配
|
||||||
memory_type_lower = memory_type.lower() if memory_type else ""
|
memory_type_lower = memory_type.lower() if memory_type else ""
|
||||||
|
|
||||||
return type_mapping.get(memory_type_lower, "未知")
|
return type_mapping.get(memory_type_lower, "未知")
|
||||||
|
|
||||||
|
|
||||||
def _format_relative_time(timestamp: datetime) -> Optional[str]:
|
def _format_relative_time(timestamp: datetime) -> str | None:
|
||||||
"""
|
"""
|
||||||
格式化相对时间(如"2天前"、"刚才")
|
格式化相对时间(如"2天前"、"刚才")
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timestamp: 时间戳
|
timestamp: 时间戳
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
相对时间描述,如果太久远则返回None
|
相对时间描述,如果太久远则返回None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
delta = now - timestamp
|
delta = now - timestamp
|
||||||
|
|
||||||
if delta.total_seconds() < 60:
|
if delta.total_seconds() < 60:
|
||||||
return "刚才"
|
return "刚才"
|
||||||
elif delta.total_seconds() < 3600:
|
elif delta.total_seconds() < 3600:
|
||||||
@@ -290,17 +287,17 @@ def _format_relative_time(timestamp: datetime) -> Optional[str]:
|
|||||||
def format_memory_summary(memory: Memory) -> str:
|
def format_memory_summary(memory: Memory) -> str:
|
||||||
"""
|
"""
|
||||||
生成记忆的简短摘要(用于日志和调试)
|
生成记忆的简短摘要(用于日志和调试)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory: 记忆对象
|
memory: 记忆对象
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
简短摘要
|
简短摘要
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
subject_node = memory.get_subject_node()
|
subject_node = memory.get_subject_node()
|
||||||
subject_text = subject_node.content if subject_node else "?"
|
subject_text = subject_node.content if subject_node else "?"
|
||||||
|
|
||||||
topic_text = "?"
|
topic_text = "?"
|
||||||
for edge in memory.edges:
|
for edge in memory.edges:
|
||||||
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
|
if edge.edge_type == EdgeType.MEMORY_TYPE and edge.source_id == memory.subject_id:
|
||||||
@@ -308,7 +305,7 @@ def format_memory_summary(memory: Memory) -> str:
|
|||||||
if topic_node:
|
if topic_node:
|
||||||
topic_text = topic_node.content
|
topic_text = topic_node.content
|
||||||
break
|
break
|
||||||
|
|
||||||
return f"{subject_text} - {memory.memory_type.value if memory.memory_type else '?'}: {topic_text}"
|
return f"{subject_text} - {memory.memory_type.value if memory.memory_type else '?'}: {topic_text}"
|
||||||
except Exception:
|
except Exception:
|
||||||
return f"记忆 {memory.id[:8]}"
|
return f"记忆 {memory.id[:8]}"
|
||||||
@@ -316,8 +313,8 @@ def format_memory_summary(memory: Memory) -> str:
|
|||||||
|
|
||||||
# 导出主要函数
|
# 导出主要函数
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'format_memory_for_prompt',
|
"format_memories_for_prompt",
|
||||||
'format_memories_for_prompt',
|
"format_memory_for_prompt",
|
||||||
'get_memory_type_label',
|
"format_memory_summary",
|
||||||
'format_memory_summary',
|
"get_memory_type_label",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -24,26 +23,26 @@ logger = get_logger(__name__)
|
|||||||
class TimeParser:
|
class TimeParser:
|
||||||
"""
|
"""
|
||||||
时间解析器
|
时间解析器
|
||||||
|
|
||||||
负责将自然语言时间表达转换为标准化的绝对时间
|
负责将自然语言时间表达转换为标准化的绝对时间
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, reference_time: Optional[datetime] = None):
|
def __init__(self, reference_time: datetime | None = None):
|
||||||
"""
|
"""
|
||||||
初始化时间解析器
|
初始化时间解析器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
reference_time: 参考时间(通常是当前时间)
|
reference_time: 参考时间(通常是当前时间)
|
||||||
"""
|
"""
|
||||||
self.reference_time = reference_time or datetime.now()
|
self.reference_time = reference_time or datetime.now()
|
||||||
|
|
||||||
def parse(self, time_str: str) -> Optional[datetime]:
|
def parse(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析时间字符串
|
解析时间字符串
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
time_str: 时间字符串
|
time_str: 时间字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
标准化的datetime对象,如果解析失败则返回None
|
标准化的datetime对象,如果解析失败则返回None
|
||||||
"""
|
"""
|
||||||
@@ -81,7 +80,7 @@ class TimeParser:
|
|||||||
logger.warning(f"无法解析时间: '{time_str}',使用当前时间")
|
logger.warning(f"无法解析时间: '{time_str}',使用当前时间")
|
||||||
return self.reference_time
|
return self.reference_time
|
||||||
|
|
||||||
def _parse_relative_day(self, time_str: str) -> Optional[datetime]:
|
def _parse_relative_day(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析相对日期:今天、明天、昨天、前天、后天
|
解析相对日期:今天、明天、昨天、前天、后天
|
||||||
"""
|
"""
|
||||||
@@ -108,7 +107,7 @@ class TimeParser:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_days_ago(self, time_str: str) -> Optional[datetime]:
|
def _parse_days_ago(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析 X天前/X天后、X周前/X周后、X个月前/X个月后
|
解析 X天前/X天后、X周前/X周后、X个月前/X个月后
|
||||||
"""
|
"""
|
||||||
@@ -172,7 +171,7 @@ class TimeParser:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_hours_ago(self, time_str: str) -> Optional[datetime]:
|
def _parse_hours_ago(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析 X小时前/X小时后、X分钟前/X分钟后
|
解析 X小时前/X小时后、X分钟前/X分钟后
|
||||||
"""
|
"""
|
||||||
@@ -204,7 +203,7 @@ class TimeParser:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_week_month_year(self, time_str: str) -> Optional[datetime]:
|
def _parse_week_month_year(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析:上周、上个月、去年、本周、本月、今年
|
解析:上周、上个月、去年、本周、本月、今年
|
||||||
"""
|
"""
|
||||||
@@ -232,7 +231,7 @@ class TimeParser:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_specific_date(self, time_str: str) -> Optional[datetime]:
|
def _parse_specific_date(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析具体日期:
|
解析具体日期:
|
||||||
- 2025-11-05
|
- 2025-11-05
|
||||||
@@ -266,7 +265,7 @@ class TimeParser:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_time_of_day(self, time_str: str) -> Optional[datetime]:
|
def _parse_time_of_day(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析一天中的时间:
|
解析一天中的时间:
|
||||||
- 早上、上午、中午、下午、晚上、深夜
|
- 早上、上午、中午、下午、晚上、深夜
|
||||||
@@ -290,7 +289,7 @@ class TimeParser:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 先检查是否有具体时间点:早上8点、下午3点
|
# 先检查是否有具体时间点:早上8点、下午3点
|
||||||
for period, default_hour in time_periods.items():
|
for period in time_periods.keys():
|
||||||
pattern = rf"{period}(\d{{1,2}})点?"
|
pattern = rf"{period}(\d{{1,2}})点?"
|
||||||
match = re.search(pattern, time_str)
|
match = re.search(pattern, time_str)
|
||||||
if match:
|
if match:
|
||||||
@@ -314,13 +313,13 @@ class TimeParser:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_combined_time(self, time_str: str) -> Optional[datetime]:
|
def _parse_combined_time(self, time_str: str) -> datetime | None:
|
||||||
"""
|
"""
|
||||||
解析组合时间表达:今天下午、昨天晚上、明天早上
|
解析组合时间表达:今天下午、昨天晚上、明天早上
|
||||||
"""
|
"""
|
||||||
# 先解析日期部分
|
# 先解析日期部分
|
||||||
date_result = None
|
date_result = None
|
||||||
|
|
||||||
# 相对日期关键词
|
# 相对日期关键词
|
||||||
relative_days = {
|
relative_days = {
|
||||||
"今天": 0, "今日": 0,
|
"今天": 0, "今日": 0,
|
||||||
@@ -330,16 +329,16 @@ class TimeParser:
|
|||||||
"后天": 2, "后日": 2,
|
"后天": 2, "后日": 2,
|
||||||
"大前天": -3, "大后天": 3,
|
"大前天": -3, "大后天": 3,
|
||||||
}
|
}
|
||||||
|
|
||||||
for keyword, days in relative_days.items():
|
for keyword, days in relative_days.items():
|
||||||
if keyword in time_str:
|
if keyword in time_str:
|
||||||
date_result = self.reference_time + timedelta(days=days)
|
date_result = self.reference_time + timedelta(days=days)
|
||||||
date_result = date_result.replace(hour=0, minute=0, second=0, microsecond=0)
|
date_result = date_result.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||||
break
|
break
|
||||||
|
|
||||||
if not date_result:
|
if not date_result:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 再解析时间段部分
|
# 再解析时间段部分
|
||||||
time_periods = {
|
time_periods = {
|
||||||
"早上": 8, "早晨": 8,
|
"早上": 8, "早晨": 8,
|
||||||
@@ -351,7 +350,7 @@ class TimeParser:
|
|||||||
"深夜": 23,
|
"深夜": 23,
|
||||||
"凌晨": 2,
|
"凌晨": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
for period, hour in time_periods.items():
|
for period, hour in time_periods.items():
|
||||||
if period in time_str:
|
if period in time_str:
|
||||||
# 检查是否有具体时间点
|
# 检查是否有具体时间点
|
||||||
@@ -363,17 +362,17 @@ class TimeParser:
|
|||||||
if period in ["下午", "晚上"] and hour < 12:
|
if period in ["下午", "晚上"] and hour < 12:
|
||||||
hour += 12
|
hour += 12
|
||||||
return date_result.replace(hour=hour)
|
return date_result.replace(hour=hour)
|
||||||
|
|
||||||
# 如果没有时间段,返回日期(默认0点)
|
# 如果没有时间段,返回日期(默认0点)
|
||||||
return date_result
|
return date_result
|
||||||
|
|
||||||
def _chinese_num_to_int(self, num_str: str) -> int:
|
def _chinese_num_to_int(self, num_str: str) -> int:
|
||||||
"""
|
"""
|
||||||
将中文数字转换为阿拉伯数字
|
将中文数字转换为阿拉伯数字
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_str: 中文数字字符串(如:"一"、"十"、"3")
|
num_str: 中文数字字符串(如:"一"、"十"、"3")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
整数
|
整数
|
||||||
"""
|
"""
|
||||||
@@ -418,11 +417,11 @@ class TimeParser:
|
|||||||
def format_time(self, dt: datetime, format_type: str = "iso") -> str:
|
def format_time(self, dt: datetime, format_type: str = "iso") -> str:
|
||||||
"""
|
"""
|
||||||
格式化时间
|
格式化时间
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dt: datetime对象
|
dt: datetime对象
|
||||||
format_type: 格式类型 ("iso", "cn", "relative")
|
format_type: 格式类型 ("iso", "cn", "relative")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化的时间字符串
|
格式化的时间字符串
|
||||||
"""
|
"""
|
||||||
@@ -461,13 +460,13 @@ class TimeParser:
|
|||||||
|
|
||||||
return str(dt)
|
return str(dt)
|
||||||
|
|
||||||
def parse_time_range(self, time_str: str) -> Tuple[Optional[datetime], Optional[datetime]]:
|
def parse_time_range(self, time_str: str) -> tuple[datetime | None, datetime | None]:
|
||||||
"""
|
"""
|
||||||
解析时间范围:最近一周、最近3天
|
解析时间范围:最近一周、最近3天
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
time_str: 时间范围字符串
|
time_str: 时间范围字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(start_time, end_time)
|
(start_time, end_time)
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user