feat(memory): add automatic memory linking and related configuration support
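For context, a rough usage sketch of the entry points this commit adds. The `MemoryManager` import path and constructor arguments below are assumptions for illustration; the diff itself only adds the methods shown further down.

```python
# Illustrative sketch only. The import path and constructor arguments for
# MemoryManager are assumed; the diff below only adds the methods it calls.
import asyncio

from src.memory_graph.manager import MemoryManager  # assumed module path


async def demo() -> None:
    manager = MemoryManager()  # construction details assumed
    await manager.initialize()

    # Run auto-linking directly, overriding the configured defaults...
    stats = await manager.auto_link_memories(
        time_window_hours=24,
        max_candidates=5,
        min_confidence=0.7,
    )
    print(stats.get("linked_count"), stats.get("relation_stats"))

    # ...or rely on maintenance(), which triggers auto-linking when
    # config.auto_link_enabled is set.
    await manager.maintenance()


asyncio.run(demo())
```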
@@ -17,12 +17,13 @@ from typing import Any, Dict, List, Optional, Set, Tuple
 from src.memory_graph.config import MemoryGraphConfig
 from src.memory_graph.core.builder import MemoryBuilder
 from src.memory_graph.core.extractor import MemoryExtractor
-from src.memory_graph.models import Memory, MemoryNode, MemoryType, NodeType
+from src.memory_graph.models import Memory, MemoryEdge, MemoryNode, MemoryType, NodeType, EdgeType
 from src.memory_graph.storage.graph_store import GraphStore
 from src.memory_graph.storage.persistence import PersistenceManager
 from src.memory_graph.storage.vector_store import VectorStore
 from src.memory_graph.tools.memory_tools import MemoryTools
 from src.memory_graph.utils.embeddings import EmbeddingGenerator
+import uuid

 logger = logging.getLogger(__name__)

@@ -851,6 +852,376 @@ class MemoryManager:
             logger.error(f"Memory consolidation failed: {e}", exc_info=True)
             return {"error": str(e), "merged_count": 0, "checked_count": 0}

+    async def auto_link_memories(
+        self,
+        time_window_hours: Optional[float] = None,
+        max_candidates: Optional[int] = None,
+        min_confidence: Optional[float] = None,
+    ) -> Dict[str, Any]:
+        """
+        Automatically link related memories.
+
+        Uses an LLM to analyze the relationships between memories and create
+        relation edges for them.
+
+        Args:
+            time_window_hours: Analysis time window in hours.
+            max_candidates: Maximum number of link candidates per memory.
+            min_confidence: Minimum confidence threshold.
+
+        Returns:
+            Statistics about the linking run.
+        """
+        if not self._initialized:
+            await self.initialize()
+
+        # Use explicit arguments when given, otherwise fall back to the
+        # configured values (the time window defaults to 24 hours).
+        time_window_hours = time_window_hours if time_window_hours is not None else 24
+        max_candidates = max_candidates if max_candidates is not None else self.config.auto_link_max_candidates
+        min_confidence = min_confidence if min_confidence is not None else self.config.auto_link_min_confidence
+
+        try:
+            logger.info(f"Starting automatic memory linking (time window={time_window_hours}h)...")
+
+            result = {
+                "checked_count": 0,
+                "linked_count": 0,
+                "relation_stats": {},  # counts per relation type {type: count}
+                "relations": {},  # detailed relations {source_id: [list of relations]}
+            }
+
+            # 1. Collect memories inside the time window
+            time_threshold = datetime.now() - timedelta(hours=time_window_hours)
+            all_memories = self.graph_store.get_all_memories()
+
+            recent_memories = [
+                mem for mem in all_memories
+                if mem.created_at >= time_threshold
+                and not mem.metadata.get("forgotten", False)
+            ]
+
+            if len(recent_memories) < 2:
+                logger.info("Not enough memories, skipping automatic linking")
+                return result
+
+            logger.info(f"Found {len(recent_memories)} memories to link")
+
+            # 2. Find link candidates for each memory
+            for memory in recent_memories:
+                result["checked_count"] += 1
+
+                # Skip memories that already have many relation edges
+                existing_edges = len([
+                    e for e in memory.edges
+                    if e.edge_type == EdgeType.RELATION
+                ])
+                if existing_edges >= 10:
+                    continue
+
+                # 3. Use vector search to find candidate memories
+                candidates = await self._find_link_candidates(
+                    memory,
+                    exclude_ids={memory.id},
+                    max_results=max_candidates
+                )
+
+                if not candidates:
+                    continue
+
+                # 4. Analyze relationships with the LLM
+                relations = await self._analyze_memory_relations(
+                    source_memory=memory,
+                    candidate_memories=candidates,
+                    min_confidence=min_confidence
+                )
+
+                # 5. Create the relation edges
+                for relation in relations:
+                    try:
+                        # Build the relation edge
+                        edge = MemoryEdge(
+                            id=f"edge_{uuid.uuid4().hex[:12]}",
+                            source_id=memory.subject_id,
+                            target_id=relation["target_memory"].subject_id,
+                            relation=relation["relation_type"],
+                            edge_type=EdgeType.RELATION,
+                            importance=relation["confidence"],
+                            metadata={
+                                "auto_linked": True,
+                                "confidence": relation["confidence"],
+                                "reasoning": relation["reasoning"],
+                                "created_at": datetime.now().isoformat(),
+                            }
+                        )
+
+                        # Add the edge to the graph
+                        self.graph_store.graph.add_edge(
+                            edge.source_id,
+                            edge.target_id,
+                            edge_id=edge.id,
+                            relation=edge.relation,
+                            edge_type=edge.edge_type.value,
+                            importance=edge.importance,
+                            metadata=edge.metadata,
+                        )
+
+                        # Also record the edge on the memory itself
+                        memory.edges.append(edge)
+
+                        result["linked_count"] += 1
+
+                        # Update the per-type statistics
+                        result["relation_stats"][relation["relation_type"]] = \
+                            result["relation_stats"].get(relation["relation_type"], 0) + 1
+
+                        # Record the detailed relation
+                        if memory.id not in result["relations"]:
+                            result["relations"][memory.id] = []
+                        result["relations"][memory.id].append({
+                            "target_id": relation["target_memory"].id,
+                            "relation_type": relation["relation_type"],
+                            "confidence": relation["confidence"],
+                            "reasoning": relation["reasoning"],
+                        })
+
+                        logger.info(
+                            f"Linked: {memory.id[:8]} --[{relation['relation_type']}]--> "
+                            f"{relation['target_memory'].id[:8]} "
+                            f"(confidence={relation['confidence']:.2f})"
+                        )
+
+                    except Exception as e:
+                        logger.warning(f"Failed to create link: {e}")
+                        continue
+
+            # Persist the updated graph data
+            if result["linked_count"] > 0:
+                await self.persistence.save_graph_store(self.graph_store)
+                logger.info(f"Saved {result['linked_count']} auto-linked edges")
+
+            logger.info(f"Automatic linking finished: {result}")
+            return result
+
+        except Exception as e:
+            logger.error(f"Automatic linking failed: {e}", exc_info=True)
+            return {"error": str(e), "checked_count": 0, "linked_count": 0}
+
+    async def _find_link_candidates(
+        self,
+        memory: Memory,
+        exclude_ids: Set[str],
+        max_results: int = 5,
+    ) -> List[Memory]:
+        """
+        Find link candidates for a memory.
+
+        Uses vector similarity plus temporal proximity to locate potentially
+        related memories.
+        """
+        try:
+            # Get the memory's topic node
+            topic_node = next(
+                (n for n in memory.nodes if n.node_type == NodeType.TOPIC),
+                None
+            )
+
+            if not topic_node or not topic_node.content:
+                return []
+
+            # Search for similar memories using the topic content
+            candidates = await self.search_memories(
+                query=topic_node.content,
+                top_k=max_results * 2,
+                include_forgotten=False,
+                optimize_query=False,
+            )
+
+            # Filter out the memory itself and already linked targets
+            existing_targets = {
+                e.target_id for e in memory.edges
+                if e.edge_type == EdgeType.RELATION
+            }
+
+            filtered = [
+                c for c in candidates
+                if c.id not in exclude_ids
+                and c.id not in existing_targets
+            ]
+
+            return filtered[:max_results]
+
+        except Exception as e:
+            logger.warning(f"Failed to find link candidates: {e}")
+            return []
+
+    async def _analyze_memory_relations(
+        self,
+        source_memory: Memory,
+        candidate_memories: List[Memory],
+        min_confidence: float = 0.7,
+    ) -> List[Dict[str, Any]]:
+        """
+        Analyze the relationships between memories using the LLM.
+
+        Args:
+            source_memory: The source memory.
+            candidate_memories: The list of candidate memories.
+            min_confidence: Minimum confidence threshold.
+
+        Returns:
+            A list of relations, each containing:
+            - target_memory: the target memory
+            - relation_type: the relation type
+            - confidence: the confidence score
+            - reasoning: the reasoning behind the judgement
+        """
+        try:
+            from src.llm_models.utils_model import LLMRequest
+            from src.config.config import model_config
+
+            # Build the LLM request
+            llm = LLMRequest(
+                model_set=model_config.model_task_config.utils_small,
+                request_type="memory.relation_analysis"
+            )
+
+            # Format the memories for the prompt
+            source_desc = self._format_memory_for_llm(source_memory)
+            candidates_desc = "\n\n".join([
+                f"Memory {i+1}:\n{self._format_memory_for_llm(mem)}"
+                for i, mem in enumerate(candidate_memories)
+            ])
+
+            # Build the prompt
+            prompt = f"""You are an expert in memory relationship analysis. Determine whether the source memory has a meaningful relationship with each of the candidate memories.
+
+**Relation types:**
+- causes: A led to B happening (causal relation)
+- references: A mentions or involves B (reference relation)
+- similar: A and B describe similar content (similarity relation)
+- contradicts: A and B express opposing views (opposition relation)
+- related: A and B are connected in a way not covered above (general association)
+
+**Source memory:**
+{source_desc}
+
+**Candidate memories:**
+{candidates_desc}
+
+**Task requirements:**
+1. For each candidate memory, decide whether it is related to the source memory.
+2. If a relation exists, give the relation type and a confidence score (0.0-1.0).
+3. Briefly explain your reasoning.
+4. Only return relations with confidence >= {min_confidence}.
+
+**Output format (JSON):**
+```json
+[
+  {{
+    "candidate_id": 1,
+    "has_relation": true,
+    "relation_type": "causes",
+    "confidence": 0.85,
+    "reasoning": "Candidate 1 is a consequence of the source memory"
+  }},
+  {{
+    "candidate_id": 2,
+    "has_relation": false,
+    "reasoning": "No obvious relation between the two"
+  }}
+]
+```
+
+Analyze the memories and output the JSON result:"""
+
+            # Call the LLM
+            response, _ = await llm.generate_response_async(
+                prompt,
+                temperature=0.3,
+                max_tokens=1000,
+            )
+
+            # Parse the response
+            import json
+            import re
+
+            # Extract the JSON payload
+            json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
+            if json_match:
+                json_str = json_match.group(1)
+            else:
+                json_str = response.strip()
+
+            try:
+                analysis_results = json.loads(json_str)
+            except json.JSONDecodeError:
+                logger.warning(f"LLM returned malformed JSON, attempting a repair: {response[:200]}")
+                # Try a simple repair by stripping control characters
+                json_str = re.sub(r'[\r\n\t]', '', json_str)
+                analysis_results = json.loads(json_str)
+
+            # Convert into the result format
+            relations = []
+            for result in analysis_results:
+                if not result.get("has_relation", False):
+                    continue
+
+                confidence = result.get("confidence", 0.0)
+                if confidence < min_confidence:
+                    continue
+
+                # Convert the 1-based candidate_id from the prompt to a 0-based index
+                candidate_id = result.get("candidate_id", 0) - 1
+                if 0 <= candidate_id < len(candidate_memories):
+                    relations.append({
+                        "target_memory": candidate_memories[candidate_id],
+                        "relation_type": result.get("relation_type", "related"),
+                        "confidence": confidence,
+                        "reasoning": result.get("reasoning", ""),
+                    })
+
+            logger.debug(f"LLM analysis finished: found {len(relations)} relations")
+            return relations
+
+        except Exception as e:
+            logger.error(f"LLM relation analysis failed: {e}", exc_info=True)
+            return []
+
+    def _format_memory_for_llm(self, memory: Memory) -> str:
+        """Format a memory as LLM-readable text."""
+        try:
+            # Get the key nodes
+            subject_node = next(
+                (n for n in memory.nodes if n.node_type == NodeType.SUBJECT),
+                None
+            )
+            topic_node = next(
+                (n for n in memory.nodes if n.node_type == NodeType.TOPIC),
+                None
+            )
+            object_node = next(
+                (n for n in memory.nodes if n.node_type == NodeType.OBJECT),
+                None
+            )
+
+            parts = []
+            parts.append(f"Type: {memory.memory_type.value}")
+
+            if subject_node:
+                parts.append(f"Subject: {subject_node.content}")
+
+            if topic_node:
+                parts.append(f"Topic: {topic_node.content}")
+
+            if object_node:
+                parts.append(f"Object: {object_node.content}")
+
+            parts.append(f"Importance: {memory.importance:.2f}")
+            parts.append(f"Time: {memory.created_at.strftime('%Y-%m-%d %H:%M')}")
+
+            return " | ".join(parts)
+
+        except Exception as e:
+            logger.warning(f"Failed to format memory: {e}")
+            return f"Memory ID: {memory.id}"
+
     async def maintenance(self) -> Dict[str, Any]:
         """
         Run maintenance tasks
@@ -885,17 +1256,22 @@ class MemoryManager:
                 )
                 result["consolidated"] = consolidate_result.get("merged_count", 0)

-            # 2. Auto-forget memories
+            # 2. Automatically link memories (discover and create relations)
+            if self.config.auto_link_enabled:
+                link_result = await self.auto_link_memories()
+                result["linked"] = link_result.get("linked_count", 0)
+
+            # 3. Auto-forget memories
             if self.config.forgetting_enabled:
                 forgotten_count = await self.auto_forget_memories(
                     threshold=self.config.forgetting_activation_threshold
                 )
                 result["forgotten"] = forgotten_count

-            # 3. Clean up very old forgotten memories (optional)
+            # 4. Clean up very old forgotten memories (optional)
             # TODO: implement the cleanup logic

-            # 4. Save the data
+            # 5. Save the data
             await self.persistence.save_graph_store(self.graph_store)
             result["saved"] = True