feat(memory_tools): add preferred node type support; improve multi-query generation and memory expansion logic
@@ -303,6 +303,14 @@ class MemoryTools:
                     "maximum": 3,
                     "description": "Graph expansion depth (0-3, default 1):\n- 0: return only directly matched memories\n- 1: include first-degree related memories (recommended)\n- 2-3: include more indirectly related memories (for deep exploration)",
                 },
+                "prefer_node_types": {
+                    "type": "array",
+                    "items": {
+                        "type": "string",
+                        "enum": ["ATTRIBUTE", "REFERENCE", "ENTITY", "EVENT", "RELATION"],
+                    },
+                    "description": "Preferred node types to recall (optional):\n- ATTRIBUTE: attribute information (e.g. configuration, parameters)\n- REFERENCE: reference information (e.g. document URLs, links)\n- ENTITY: entity information (e.g. people, organizations)\n- EVENT: event information (e.g. activities, conversations)\n- RELATION: relationship information (e.g. interpersonal relationships)",
+                },
             },
             "required": ["query"],
         },
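
The schema addition above is what exposes the feature to callers. A minimal sketch of a call that exercises it, inside some async caller — the entry point `memory_tools.execute` is an assumption for illustration; only the parameter keys come from the schema:

```python
# Hypothetical call site for the extended search tool.
# `memory_tools.execute` is an assumed entry point; the parameter keys
# ("query", "top_k", "expand_depth", "prefer_node_types") come from the schema above.
result = await memory_tools.execute(
    query="MoFox-Bot documentation URL",
    top_k=10,
    expand_depth=1,                                # direct matches plus first-degree neighbors
    prefer_node_types=["REFERENCE", "ATTRIBUTE"],  # bias recall toward links and config facts
)
```
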
@@ -447,8 +455,9 @@ class MemoryTools:
             **params: Tool parameters
                 - query: query string
                 - top_k: number of results to return (default 10)
-                - expand_depth: expansion depth (not yet used)
+                - expand_depth: expansion depth (defaults to the configured value)
                 - use_multi_query: whether to use the multi-query strategy (default True)
+                - prefer_node_types: list of preferred node types to recall (optional)
                 - context: query context (optional)

         Returns:
@@ -457,24 +466,36 @@ class MemoryTools:
         try:
             query = params.get("query", "")
             top_k = params.get("top_k", 10)
-            # Use the configured default instead of a hardcoded 1
             expand_depth = params.get("expand_depth", self.max_expand_depth)
             use_multi_query = params.get("use_multi_query", True)
+            prefer_node_types = params.get("prefer_node_types", [])  # 🆕 preferred node types
             context = params.get("context", None)

-            logger.info(f"Searching memories: {query} (top_k={top_k}, expand_depth={expand_depth}, multi_query={use_multi_query})")
+            logger.info(
+                f"Searching memories: {query} (top_k={top_k}, expand_depth={expand_depth}, "
+                f"multi_query={use_multi_query}, prefer_types={prefer_node_types})"
+            )

             # 0. Ensure initialization
             await self._ensure_initialized()

             # 1. Pick the retrieval strategy
+            llm_prefer_types = []  # preferred node types identified by the LLM
+
             if use_multi_query:
-                # Multi-query strategy
-                similar_nodes = await self._multi_query_search(query, top_k, context)
+                # Multi-query strategy (returns node list + preferred types)
+                similar_nodes, llm_prefer_types = await self._multi_query_search(query, top_k, context)
             else:
                 # Traditional single-query strategy
                 similar_nodes = await self._single_query_search(query, top_k)

+            # Merge user-specified preferred types with the LLM-identified ones
+            all_prefer_types = list(set(prefer_node_types + llm_prefer_types))
+            if all_prefer_types:
+                logger.info(f"Final preferred node types: {all_prefer_types} (user-specified: {prefer_node_types}, LLM-identified: {llm_prefer_types})")
+                # Update prefer_node_types for downstream scoring
+                prefer_node_types = all_prefer_types
+
             # 2. Extract initial memory IDs (from vector search)
             initial_memory_ids = set()
             memory_scores = {}  # initial score per memory
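
The merge above is a plain set union, so user-specified and LLM-identified types deduplicate and the result order is arbitrary; a compact illustration:

```python
# Illustration of the preferred-type merge above: set union deduplicates,
# and the resulting list order is arbitrary.
prefer_node_types = ["REFERENCE"]            # specified by the caller
llm_prefer_types = ["REFERENCE", "ENTITY"]   # identified by the query-generation LLM
all_prefer_types = list(set(prefer_node_types + llm_prefer_types))
assert sorted(all_prefer_types) == ["ENTITY", "REFERENCE"]
```
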
@@ -519,10 +540,10 @@ class MemoryTools:
                     max_expanded=top_k * 2
                 )

                 # Merge the expansion results
                 expanded_memory_scores.update(dict(expanded_results))

                 logger.info(f"Graph expansion done: {len(expanded_memory_scores)} additional related memories")

             except Exception as e:
                 logger.warning(f"Graph expansion failed: {e}")
@@ -547,12 +568,12 @@ class MemoryTools:
                 reverse=True
             )[:top_k * 2]  # keep 2x the count for later filtering

-            # 5. Fetch full memories and do the final ranking
+            # 5. Fetch full memories and do the final ranking (revised dynamic weighting)
             memories_with_scores = []
             for memory_id in sorted_memory_ids:
                 memory = self.graph_store.get_memory_by_id(memory_id)
                 if memory:
-                    # Combined score: similarity (40%) + importance (20%) + recency (10%) + activation (30%)
+                    # Base scores
                     similarity_score = final_scores[memory_id]
                     importance_score = memory.importance
@@ -567,43 +588,101 @@ class MemoryTools:
                        age_days = (now - memory_time).total_seconds() / 86400
                        recency_score = 1.0 / (1.0 + age_days / 30)  # 30-day half-life

-                    # Get the activation score (read from metadata; compatible with memory.activation)
+                    # Get the activation score
                    activation_info = memory.metadata.get("activation", {})
                    activation_score = activation_info.get("level", memory.activation)

-                    # If metadata has no activation info, fall back to memory.activation
                    if activation_score == 0.0 and memory.activation > 0.0:
                        activation_score = memory.activation

-                    # Combined score - now includes activation
+                    # 🆕 Dynamic weighting: adapt to the memory type and node types
+                    memory_type = memory.memory_type.value if hasattr(memory.memory_type, 'value') else str(memory.memory_type)
+
+                    # Detect the memory's dominant node type
+                    node_types_count = {}
+                    for node in memory.nodes:
+                        nt = node.node_type.value if hasattr(node.node_type, 'value') else str(node.node_type)
+                        node_types_count[nt] = node_types_count.get(nt, 0) + 1
+
+                    dominant_node_type = max(node_types_count.items(), key=lambda x: x[1])[0] if node_types_count else "unknown"
+
+                    # Adjust the weights dynamically by node type
+                    if dominant_node_type in ["ATTRIBUTE", "REFERENCE"] or memory_type == "FACT":
+                        # Factual memories (e.g. document URLs, configuration): semantic similarity matters most
+                        weights = {
+                            "similarity": 0.65,  # semantic similarity 65% ⬆️
+                            "importance": 0.20,  # importance 20%
+                            "recency": 0.05,     # recency 5% ⬇️ (facts do not expire with time)
+                            "activation": 0.10   # activation 10% ⬇️ (avoid suppressing rarely-touched info)
+                        }
+                    elif memory_type in ["CONVERSATION", "EPISODIC"] or dominant_node_type == "EVENT":
+                        # Conversation/event memories: recency and activation matter more
+                        weights = {
+                            "similarity": 0.45,  # semantic similarity 45%
+                            "importance": 0.15,  # importance 15%
+                            "recency": 0.20,     # recency 20% ⬆️
+                            "activation": 0.20   # activation 20%
+                        }
+                    elif dominant_node_type == "ENTITY" or memory_type == "SEMANTIC":
+                        # Entity/semantic memories: balanced
+                        weights = {
+                            "similarity": 0.50,  # semantic similarity 50%
+                            "importance": 0.25,  # importance 25%
+                            "recency": 0.10,     # recency 10%
+                            "activation": 0.15   # activation 15%
+                        }
+                    else:
+                        # Default weights (conservative, biased toward semantics)
+                        weights = {
+                            "similarity": 0.55,  # semantic similarity 55%
+                            "importance": 0.20,  # importance 20%
+                            "recency": 0.10,     # recency 10%
+                            "activation": 0.15   # activation 15%
+                        }
+
+                    # Combined score
                    final_score = (
-                        similarity_score * 0.4 +   # vector similarity 40%
-                        importance_score * 0.2 +   # importance 20%
-                        recency_score * 0.1 +      # recency 10%
-                        activation_score * 0.3     # activation 30% ← new
+                        similarity_score * weights["similarity"] +
+                        importance_score * weights["importance"] +
+                        recency_score * weights["recency"] +
+                        activation_score * weights["activation"]
                    )

-                    memories_with_scores.append((memory, final_score))
+                    # 🆕 Node-type boost: extra credit for REFERENCE/ATTRIBUTE nodes (improves recall of factual info)
+                    if "REFERENCE" in node_types_count or "ATTRIBUTE" in node_types_count:
+                        final_score *= 1.1  # 10% boost
+
+                    # 🆕 Extra boost for user-specified preferred node types
+                    if prefer_node_types:
+                        for prefer_type in prefer_node_types:
+                            if prefer_type in node_types_count:
+                                final_score *= 1.15  # additional 15% boost
+                                logger.debug(f"Memory {memory.id[:8]} contains preferred node type {prefer_type}; boosted score: {final_score:.4f}")
+                                break
+
+                    memories_with_scores.append((memory, final_score, dominant_node_type))

            # Sort by combined score
            memories_with_scores.sort(key=lambda x: x[1], reverse=True)
-            memories = [mem for mem, _ in memories_with_scores[:top_k]]
+            memories = [mem for mem, _, _ in memories_with_scores[:top_k]]

-            # 6. Format the results
+            # 6. Format the results (with debug info)
            results = []
-            for memory in memories:
+            for memory, score, node_type in memories_with_scores[:top_k]:
                result = {
                    "memory_id": memory.id,
                    "importance": memory.importance,
                    "created_at": memory.created_at.isoformat(),
                    "summary": self._summarize_memory(memory),
+                    "score": round(score, 4),  # 🆕 expose the final score for debugging
+                    "dominant_node_type": node_type,  # 🆕 expose the node type
                }
                results.append(result)

            logger.info(
                f"Search complete: {len(initial_memory_ids)} initial → "
                f"{len(expanded_memory_scores)} expanded → "
-                f"{len(results)} memories returned"
+                f"{len(results)} memories returned "
+                f"(node type distribution: {', '.join(f'{nt}:{ct}' for nt, ct in sorted(set((r['dominant_node_type'], 1) for r in results))[:3])})"
            )

            return {
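
To see how the dynamic weights and the two multiplicative boosts combine, here is a worked example under the FACT/REFERENCE profile above (all input numbers are illustrative):

```python
# Worked example of the FACT/REFERENCE weight profile above (illustrative numbers).
similarity_score, importance_score, recency_score, activation_score = 0.82, 0.6, 0.3, 0.1
weights = {"similarity": 0.65, "importance": 0.20, "recency": 0.05, "activation": 0.10}

final_score = (
    similarity_score * weights["similarity"]    # 0.533
    + importance_score * weights["importance"]  # 0.120
    + recency_score * weights["recency"]        # 0.015
    + activation_score * weights["activation"]  # 0.010
)  # = 0.678
final_score *= 1.1   # REFERENCE/ATTRIBUTE node boost -> 0.7458
final_score *= 1.15  # matches a user-preferred type  -> ~0.8577
```

Because the boosts are multiplicative, a memory matching a user-preferred type can outrank one with a noticeably higher raw similarity.
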
@@ -627,11 +706,14 @@ class MemoryTools:

     async def _generate_multi_queries_simple(
         self, query: str, context: dict[str, Any] | None = None
-    ) -> list[tuple[str, float]]:
+    ) -> tuple[list[tuple[str, float]], list[str]]:
        """
        Simplified multi-query generation (implemented directly in the Tools layer to avoid circular dependencies)

-        Has the small model directly generate 3-5 query strings from different angles.
+        Has the small model directly generate 3-5 query strings from different angles, and identify preferred node types.
+
+        Returns:
+            (query list, preferred node type list)
        """
        try:
            from src.config.config import model_config
@@ -655,7 +737,7 @@ class MemoryTools:
            recent_lines = lines[-5:] if len(lines) > 5 else lines
            recent_chat = "\n".join(recent_lines)

-            prompt = f"""Based on the chat context, generate 3-5 search statements for the query from different angles (JSON format).
+            prompt = f"""Based on the chat context, generate 3-5 search statements for the query from different angles, and identify the memory types matching the query intent (JSON format).

**Current query:** {query}
**Sender:** {sender if sender else 'unknown'}
@@ -665,51 +747,178 @@ class MemoryTools:
**Recent chat history (last 5 messages):**
{recent_chat if recent_chat else 'no chat history'}

-**Analysis principles:**
+---
+
+## Step 1: Analyze the query intent and memory type
+
+### Memory type lookup table (check in priority order)
+
+| Query pattern | Preferred node type | Examples |
+|---------|-------------|------|
+| 🔗 **Looking for a link/address/URL/website/document location** | `REFERENCE` | "docs URL for xxx", "that website link" |
+| ⚙️ **Asking for configuration/parameters/settings/attribute values** | `ATTRIBUTE` | "what Python version", "database config" |
+| 👤 **Asking about a person/organization/entity identity** | `ENTITY` | "who is 拾风", "MoFox team members" |
+| 🔄 **Asking about relationships/interactions** | `RELATION` | "my relationship with the bot", "who knows whom" |
+| 📅 **Recalling events/conversations/activities** | `EVENT` | "what did we talk about last time", "yesterday's meeting" |
+| 💡 **Asking about concepts/definitions/knowledge** | no specific preference | "what is a memory graph" |
+
+### Decision rules
+- If the query contains keywords like "address", "link", "URL", "website", "document" → `REFERENCE`
+- If the query contains keywords like "config", "parameter", "setting", "version", "attribute" → `ATTRIBUTE`
+- If the query asks "who is", "what person", "team", "organization" → `ENTITY`
+- If the query asks about "relationship", "friend", "knows" → `RELATION`
+- If the query recalls "last time", "before", "discussed", "talked about" → `EVENT`
+- If there is no clear pattern → no type (empty list)
+
+---
+
+## Step 2: Generate multi-angle queries
+
+### Analysis principles
1. **Context understanding**: infer the real intent of the query from the chat history
-2. **Coreference resolution**: detect and substitute pronouns such as "he", "she", "it", "that"
+2. **Coreference resolution**: detect and substitute pronouns such as "he", "she", "it", "that" with concrete entity names
3. **Topic linking**: use recently discussed topics to generate more precise queries
4. **Query decomposition**: split complex queries into several sub-queries
+5. **Entity extraction**: explicitly extract the key entities in the query (person, project, organization names, etc.)

-**Generation strategy:**
-1. **Full query** (weight 1.0): the complete query combined with context, with coreferences resolved
-2. **Key concept query** (weight 0.8): core concepts in the query, especially entities mentioned in chat
-3. **Topic expansion query** (weight 0.7): related queries based on recent chat topics
-4. **Action/emotion query** (weight 0.6): if emotion or an action is involved, generate a related query
-5. **Precise time query** (weight 0.5): for time-related queries, generate a more specific time range, e.g. 2023-05-01 12:00
+### Generation strategy (in order)
+1. **Full query** (weight 1.0): the complete query combined with context, using the resolved entity names
+2. **Key entity query** (weight 0.9): only the core entity, with modifiers stripped (e.g. "xxx's" → "xxx")
+3. **Paraphrase query** (weight 0.8): restate the query intent with different wording
+4. **Topic expansion query** (weight 0.7): related queries based on recent chat topics
+5. **Time range query** (weight 0.6, when applicable): if time is involved, generate a concrete time range
+
+---
+
+## Output format (strict JSON)
+
-**Output JSON format:**
 ```json
-{{"queries": [{{"text": "query string", "weight": 1.0}}, {{"text": "query string", "weight": 0.8}}]}}
+{{
+  "prefer_node_types": ["REFERENCE", "ATTRIBUTE"],
+  "queries": [
+    {{"text": "full query (coreferences resolved)", "weight": 1.0}},
+    {{"text": "core entity query", "weight": 0.9}},
+    {{"text": "paraphrase query", "weight": 0.8}}
+  ]
+}}
 ```

-**Examples:**
-- Query: "how is he doing?" + chat mentions "小明 is sick" → "小明's recovery status"
-- Query: "that project" + chat discusses "memory system development" → "memory system project progress"
+**Field notes**:
+- `prefer_node_types`: array of preferred node types; allowed values: `REFERENCE`, `ATTRIBUTE`, `ENTITY`, `RELATION`, `EVENT`; use an empty array `[]` when there is no clear pattern
+- `queries`: array of queries, each with `text` (the query string) and `weight` (0.5-1.0)
+
+---
+
+## Examples
+
+### Example 1: Looking up a document URL
+**Input**:
+- Query: "Do you know the documentation URL for MoFox-Bot?"
+- Chat history: none
+
+**Output**:
+```json
+{{
+  "prefer_node_types": ["REFERENCE"],
+  "queries": [
+    {{"text": "MoFox-Bot documentation URL", "weight": 1.0}},
+    {{"text": "MoFox-Bot", "weight": 0.9}},
+    {{"text": "MoFox-Bot official documentation URL", "weight": 0.8}}
+  ]
+}}
+```
+
+### Example 2: Asking about a person
+**Input**:
+- Query: "Who is 拾风?"
+- Chat history: "拾风 and 杰瑞喵" were mentioned
+
+**Output**:
+```json
+{{
+  "prefer_node_types": ["ENTITY", "RELATION"],
+  "queries": [
+    {{"text": "拾风 identity information", "weight": 1.0}},
+    {{"text": "拾风", "weight": 0.9}},
+    {{"text": "relationship between 拾风 and 杰瑞喵", "weight": 0.8}}
+  ]
+}}
+```
+
+### Example 3: Asking about a configuration parameter
+**Input**:
+- Query: "What Python version is it?"
+- Chat history: "project environment setup" was discussed
+
+**Output**:
+```json
+{{
+  "prefer_node_types": ["ATTRIBUTE"],
+  "queries": [
+    {{"text": "Python version number", "weight": 1.0}},
+    {{"text": "Python configuration", "weight": 0.9}},
+    {{"text": "project Python environment version", "weight": 0.8}}
+  ]
+}}
+```
+
+### Example 4: Recalling a conversation (no unambiguous type)
+**Input**:
+- Query: "What did we talk about last time?"
+- Chat history: recently discussed "memory system optimization"
+
+**Output**:
+```json
+{{
+  "prefer_node_types": ["EVENT"],
+  "queries": [
+    {{"text": "recent conversation content", "weight": 1.0}},
+    {{"text": "memory system optimization discussion", "weight": 0.9}},
+    {{"text": "last chat log", "weight": 0.8}}
+  ]
+}}
+```
+
+---
+
+**Now produce the output according to the rules above (output ONLY the JSON, nothing else):**
"""

-            response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
+            response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)

            import re

            import orjson

+            # Strip Markdown code fences
            response = re.sub(r"```json\s*", "", response)
            response = re.sub(r"```\s*$", "", response).strip()

+            # Parse the JSON
            data = orjson.loads(response)

+            # Extract the query list
            queries = data.get("queries", [])
+            result_queries = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
+                              for item in queries if item.get("text", "").strip()]

-            result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
-                      for item in queries if item.get("text", "").strip()]
+            # Extract the preferred node types
+            prefer_node_types = data.get("prefer_node_types", [])
+            # Keep only valid, known types
+            valid_types = {"REFERENCE", "ATTRIBUTE", "ENTITY", "RELATION", "EVENT"}
+            prefer_node_types = [t for t in prefer_node_types if t in valid_types]

-            if result:
-                logger.info(f"Generated queries: {[q for q, _ in result]}")
-                return result
+            if result_queries:
+                logger.info(
+                    f"Generated queries: {[q for q, _ in result_queries]} "
+                    f"(preferred types: {prefer_node_types if prefer_node_types else 'none'})"
+                )
+                return result_queries, prefer_node_types

        except Exception as e:
            logger.warning(f"Multi-query generation failed: {e}")

-            return [(query, 1.0)]
+            # Fallback: return the original query and an empty node type list
+            return [(query, 1.0)], []

    async def _single_query_search(
        self, query: str, top_k: int
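
Callers can rely on the tuple contract of `_generate_multi_queries_simple` even on failure. A minimal sketch of consuming it inside some async caller (`tools` being an initialized MemoryTools instance is an assumption for illustration):

```python
# Minimal sketch of consuming the new tuple contract (setup assumed).
queries, prefer_types = await tools._generate_multi_queries_simple(
    "Do you know the documentation URL for MoFox-Bot?", context=None
)
# On any LLM failure the method degrades to ([(original_query, 1.0)], []),
# so both values are always safe to unpack.
for text, weight in queries:
    print(f"{weight:.1f}  {text}")
print("preferred:", prefer_types or "none")
```
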
@@ -744,14 +953,14 @@ class MemoryTools:

    async def _multi_query_search(
        self, query: str, top_k: int, context: dict[str, Any] | None = None
-    ) -> list[tuple[str, float, dict[str, Any]]]:
+    ) -> tuple[list[tuple[str, float, dict[str, Any]]], list[str]]:
        """
-        Multi-query strategy search (simplified)
+        Multi-query strategy search (simplified + node type identification)

-        Directly uses the small model to generate multiple queries, without complex decomposition and recombination.
+        Directly uses the small model to generate multiple queries, and identifies the preferred node types implied by the query intent.

        Steps:
-        1. Have the small model generate 3-5 queries from different angles
+        1. Have the small model generate 3-5 queries from different angles + identify preferred node types
        2. Generate an embedding for each query
        3. Search in parallel and fuse the results

@@ -761,18 +970,19 @@ class MemoryTools:
            context: query context

        Returns:
-            Fused list of similar nodes
+            (fused list of similar nodes, list of preferred node types)
        """
        try:
-            # 1. Use the small model to generate multiple queries
-            multi_queries = await self._generate_multi_queries_simple(query, context)
+            # 1. Use the small model to generate multiple queries + identify node types
+            multi_queries, prefer_node_types = await self._generate_multi_queries_simple(query, context)

-            logger.debug(f"Generated {len(multi_queries)} queries: {multi_queries}")
+            logger.debug(f"Generated {len(multi_queries)} queries: {multi_queries}, preferred types: {prefer_node_types}")

            # 2. Generate embeddings for all queries
            if not self.builder.embedding_generator:
                logger.warning("No embedding generator configured; falling back to single-query mode")
-                return await self._single_query_search(query, top_k)
+                single_results = await self._single_query_search(query, top_k)
+                return single_results, prefer_node_types

            query_embeddings = []
            query_weights = []
@@ -786,7 +996,8 @@ class MemoryTools:
            # If all embeddings failed to generate, fall back to single-query mode
            if not query_embeddings:
                logger.warning("All query embeddings failed to generate; falling back to single-query mode")
-                return await self._single_query_search(query, top_k)
+                single_results = await self._single_query_search(query, top_k)
+                return single_results, prefer_node_types

            # 3. Multi-query fusion search
            similar_nodes = await self.vector_store.search_with_multiple_queries(
@@ -796,13 +1007,14 @@ class MemoryTools:
                fusion_strategy="weighted_max",
            )

-            logger.info(f"Multi-query retrieval done: {len(similar_nodes)} nodes")
+            logger.info(f"Multi-query retrieval done: {len(similar_nodes)} nodes (preferred types: {prefer_node_types})")

-            return similar_nodes
+            return similar_nodes, prefer_node_types

        except Exception as e:
            logger.warning(f"Multi-query search failed; falling back to single-query mode: {e}", exc_info=True)
-            return await self._single_query_search(query, top_k)
+            single_results = await self._single_query_search(query, top_k)
+            return single_results, []

    async def _add_memory_to_stores(self, memory: Memory):
        """Add a memory to the stores"""
@@ -88,23 +88,46 @@ async def expand_memories_with_semantic_filter(
        # Get this memory's neighbor memories (via edge relations)
        neighbor_memory_ids = set()

-        # Walk all of the memory's edges and collect neighbor memories
+        # 🆕 Walk all of the memory's edges and collect neighbor memories (with edge-type weights)
+        edge_weights = {}  # weight of each memory reached via different edge types
+
        for edge in memory.edges:
            # Get the edge's endpoints
            target_node_id = edge.target_id
            source_node_id = edge.source_id

+            # 🆕 Weight by edge type (prefer expanding REFERENCE- and ATTRIBUTE-related edges)
+            edge_type_str = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
+            if edge_type_str == "REFERENCE":
+                edge_weight = 1.3  # REFERENCE edges get the highest weight (citation relations)
+            elif edge_type_str in ["ATTRIBUTE", "HAS_PROPERTY"]:
+                edge_weight = 1.2  # attribute edges next
+            elif edge_type_str == "TEMPORAL":
+                edge_weight = 0.7  # down-weight temporal relations (avoid drifting to unrelated points in time)
+            elif edge_type_str == "RELATION":
+                edge_weight = 0.9  # generic relations slightly down-weighted
+            else:
+                edge_weight = 1.0  # default weight
+
            # Find other memories through the nodes
            for node_id in [target_node_id, source_node_id]:
                if node_id in graph_store.node_to_memories:
-                    neighbor_memory_ids.update(graph_store.node_to_memories[node_id])
+                    for neighbor_id in graph_store.node_to_memories[node_id]:
+                        if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
+                            edge_weights[neighbor_id] = edge_weight
+
+        # Add the weighted neighbor memories as candidates
+        for neighbor_id, edge_weight in edge_weights.items():
+            neighbor_memory_ids.add((neighbor_id, edge_weight))

        # Filter out already-visited memories and the memory itself
-        neighbor_memory_ids.discard(memory_id)
-        neighbor_memory_ids -= visited_memories
+        filtered_neighbors = []
+        for neighbor_id, edge_weight in neighbor_memory_ids:
+            if neighbor_id != memory_id and neighbor_id not in visited_memories:
+                filtered_neighbors.append((neighbor_id, edge_weight))

        # Evaluate neighbor memories in batch
-        for neighbor_mem_id in neighbor_memory_ids:
+        for neighbor_mem_id, edge_weight in filtered_neighbors:
            candidates_checked += 1

            neighbor_memory = graph_store.get_memory_by_id(neighbor_mem_id)
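
The dedup above keeps the strongest edge per neighbor: when a memory is reachable over several edges, the highest edge-type weight wins. A compact illustration:

```python
# Illustration of the max-weight dedup above: when a neighbor is reachable
# over several edges, the strongest edge type wins.
edge_weights = {}
for neighbor_id, edge_weight in [("mem_a", 0.7), ("mem_a", 1.3), ("mem_b", 1.0)]:
    if neighbor_id not in edge_weights or edge_weights[neighbor_id] < edge_weight:
        edge_weights[neighbor_id] = edge_weight
assert edge_weights == {"mem_a": 1.3, "mem_b": 1.0}
```
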
@@ -123,12 +146,17 @@ async def expand_memories_with_semantic_filter(
            # Compute semantic similarity
            semantic_sim = cosine_similarity(query_embedding, topic_node.embedding)

-            # Compute edge importance (feeds into the score)
-            edge_importance = neighbor_memory.importance * 0.5  # use memory importance as the edge weight
+            # 🆕 Compute edge importance (combines the edge-type weight and memory importance)
+            edge_importance = neighbor_memory.importance * edge_weight * 0.5

-            # Combined score: semantic similarity (70%) + importance (20%) + depth decay (10%)
+            # 🆕 Combined score: semantic similarity (60%) + edge weight (20%) + importance (10%) + depth decay (10%)
            depth_decay = 1.0 / (depth + 2)  # depth decay
-            relevance_score = semantic_sim * 0.7 + edge_importance * 0.2 + depth_decay * 0.1
+            relevance_score = (
+                semantic_sim * 0.60 +     # semantic similarity dominates ⬆️
+                edge_weight * 0.20 +      # edge-type weight 🆕
+                edge_importance * 0.10 +  # importance down-weighted ⬇️
+                depth_decay * 0.10        # depth decay
+            )

            # Keep only those above the threshold
            if relevance_score < semantic_threshold:
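
A quick worked example of the revised expansion score, using the REFERENCE edge weight (all input numbers are illustrative):

```python
# Worked example of the revised expansion score (illustrative numbers).
semantic_sim = 0.75
edge_weight = 1.3                          # REFERENCE edge
edge_importance = 0.8 * edge_weight * 0.5  # memory importance 0.8 -> 0.52
depth_decay = 1.0 / (1 + 2)                # depth 1 -> ~0.333

relevance_score = (
    semantic_sim * 0.60       # 0.450
    + edge_weight * 0.20      # 0.260
    + edge_importance * 0.10  # 0.052
    + depth_decay * 0.10      # ~0.033
)  # ≈ 0.795
```

Since `edge_weight` enters the sum directly and can exceed 1.0, a REFERENCE edge contributes 0.26 on its own, which is what tilts expansion toward factual links.
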