feat: implement multi-query generation and fusion search; simplify memory retrieval logic
@@ -333,76 +333,100 @@ class MemoryManager:
 
     # ==================== Memory retrieval operations ====================
 
-    async def optimize_search_query(
+    async def generate_multi_queries(
         self,
         query: str,
         context: Optional[Dict[str, Any]] = None,
-    ) -> str:
+    ) -> List[Tuple[str, float]]:
         """
-        Optimize the search query with a small model
+        Generate multiple query strings with a small model (for multi-path recall)
+
+        Simplified multi-query strategy: have the small model directly generate 3-5
+        queries from different angles, avoiding complex decomposition-and-combination logic.
 
         Args:
             query: Original query
-            context: Context info (chat history, sender, etc.)
+            context: Context info (chat history, sender, participants, etc.)
 
         Returns:
-            The optimized query string
+            List of (query_string, weight) - query strings and their weights
         """
-        if not context:
-            return query
-
         try:
             from src.llm_models.utils_model import LLMRequest
             from src.config.config import model_config
 
-            # Optimize the query with a small model
             llm = LLMRequest(
                 model_set=model_config.model_task_config.utils_small,
-                request_type="memory.query_optimizer"
+                request_type="memory.multi_query_generator"
             )
 
-            # Build the optimization prompt
-            chat_history = context.get("chat_history", "")
-            sender = context.get("sender", "")
+            # Build the context info
+            chat_history = context.get("chat_history", "") if context else ""
+            sender = context.get("sender", "") if context else ""
+            participants = context.get("participants", []) if context else []
+            participants_str = ", ".join(participants) if participants else "none"
 
-            prompt = f"""You are a memory-retrieval query optimizer. Your task is to analyze the conversation history and produce one comprehensive search query.
-
-**Task:**
-Do not optimize the single message in isolation; analyze the whole conversation context and extract the core retrieval intent.
-
-**Requirements:**
-1. Read the conversation history carefully and understand the topic and its thread
-2. Identify key people, events, relations, and topics
-3. Extract the core information most worth retrieving
-4. Produce one concise but information-rich search query (15-30 characters)
-5. If a specific person is involved, name them explicitly
-6. Output only the query text, no explanation
-
-**Conversation context:**
-{chat_history[-500:] if chat_history else "(no history)"}
-
-**Current message:**
-{sender}: {query}
-
-**Generate the comprehensive query:**"""
-
-            optimized_query, _ = await llm.generate_response_async(
-                prompt,
-                temperature=0.3,
-                max_tokens=100
-            )
-
-            # Clean the output
-            optimized_query = optimized_query.strip()
-            if optimized_query and len(optimized_query) > 5:
-                logger.debug(f"[query optimization] '{query}' -> '{optimized_query}'")
-                return optimized_query
-
-            return query
+            prompt = f"""You are a memory-retrieval assistant. To improve retrieval accuracy, generate 3-5 search queries from different angles.
+
+**Core principles (important!):**
+For a complex query containing several concepts (e.g. "how does 杰瑞喵 rate the new memory system"), generate:
+1. The full query (all elements included) - weight 1.0
+2. An independent query per key concept (e.g. "the new memory system") - weight 0.8, so it is not drowned out by the subject!
+3. Subject + action combination (e.g. "杰瑞喵 rate") - weight 0.6
+4. A generalized query (e.g. "memory system") - weight 0.7
+
+**Requirements:**
+- The first must be the original query or a paraphrase of it
+- Identify every important concept in the query and generate an independent query for each
+- Keep queries concise (5-20 characters)
+- Output JSON directly, with no commentary
+
+**Known participants:** {participants_str}
+**Conversation context:** {chat_history[-300:] if chat_history else "none"}
+**Current query:** {sender}: {query}
+
+**Output JSON format:**
+```json
+{{
+    "queries": [
+        {{"text": "full query", "weight": 1.0}},
+        {{"text": "key concept 1", "weight": 0.8}},
+        {{"text": "key concept 2", "weight": 0.8}},
+        {{"text": "combined query", "weight": 0.6}}
+    ]
+}}
+```"""
+
+            response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300)
+
+            # Parse the JSON
+            import json, re
+            response = re.sub(r'```json\s*', '', response)
+            response = re.sub(r'```\s*$', '', response).strip()
+
+            try:
+                data = json.loads(response)
+                queries = data.get("queries", [])
+
+                result = []
+                for item in queries:
+                    text = item.get("text", "").strip()
+                    weight = float(item.get("weight", 0.5))
+                    if text:
+                        result.append((text, weight))
+
+                if result:
+                    logger.info(f"Generated {len(result)} queries: {[q for q, _ in result]}")
+                    return result
+
+            except json.JSONDecodeError as e:
+                logger.warning(f"JSON parsing failed: {e}, response={response[:100]}")
 
         except Exception as e:
-            logger.warning(f"Query optimization failed, using the original query: {e}")
-            return query
+            logger.warning(f"Multi-query generation failed: {e}")
+
+        # Fall back to the original query
+        return [(query, 1.0)]
 
     async def search_memories(
         self,
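The contract here changes from "one rewritten query string" to "a weighted list of queries". Below is a minimal, self-contained sketch of the fence-stripping and JSON-parsing step, with a hypothetical small-model response; the helper name and sample text are illustrative, not part of this commit:

```python
import json
import re
from typing import List, Tuple

def parse_multi_query_response(response: str, fallback: str) -> List[Tuple[str, float]]:
    """Strip Markdown code fences, parse {"queries": [...]}, fall back to the raw query."""
    response = re.sub(r'```json\s*', '', response)
    response = re.sub(r'```\s*$', '', response).strip()
    try:
        items = json.loads(response).get("queries", [])
        parsed = [(item["text"].strip(), float(item.get("weight", 0.5)))
                  for item in items if item.get("text", "").strip()]
        return parsed or [(fallback, 1.0)]
    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
        return [(fallback, 1.0)]

# Hypothetical model output, shaped as the prompt requests:
raw = '```json\n{"queries": [{"text": "how does 杰瑞喵 rate the new memory system", "weight": 1.0}, {"text": "new memory system", "weight": 0.8}]}\n```'
print(parse_multi_query_response(raw, "fallback"))
# [('how does 杰瑞喵 rate the new memory system', 1.0), ('new memory system', 0.8)]
```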
@@ -413,11 +437,16 @@ class MemoryManager:
         min_importance: float = 0.0,
         include_forgotten: bool = False,
         optimize_query: bool = True,
+        use_multi_query: bool = True,
         context: Optional[Dict[str, Any]] = None,
     ) -> List[Memory]:
         """
         Search memories
+
+        Uses multi-strategy retrieval optimization to handle complex queries.
+        For example, "how does 杰瑞喵 rate the new memory system" is decomposed
+        into several sub-queries, ensuring both key concepts, "杰瑞喵" and
+        "the new memory system", are matched.
 
         Args:
             query: Search query
             top_k: Number of results to return
@@ -425,7 +454,8 @@ class MemoryManager:
             time_range: Time-range filter (start, end)
             min_importance: Minimum importance
             include_forgotten: Whether to include forgotten memories
-            optimize_query: Whether to optimize the query with a small model
+            optimize_query: Whether to optimize the query with a small model (deprecated; superseded by use_multi_query)
+            use_multi_query: Whether to use the multi-query strategy (recommended, default True)
             context: Query context (used for optimization)
 
         Returns:
@@ -435,19 +465,18 @@ class MemoryManager:
         await self.initialize()
 
         try:
-            # Query optimization
-            search_query = query
-            if optimize_query and context:
-                search_query = await self.optimize_search_query(query, context)
-
+            # Prepare the search parameters
             params = {
-                "query": search_query,
+                "query": query,
                 "top_k": top_k,
+                "use_multi_query": use_multi_query,
+                "context": context,
             }
 
             if memory_types:
                 params["memory_types"] = memory_types
 
+            # Run the search
             result = await self.tools.search_memories(**params)
 
             if not result["success"]:
@@ -484,7 +513,10 @@ class MemoryManager:
 
                     filtered_memories.append(memory)
 
-            logger.info(f"Search finished: found {len(filtered_memories)} memories")
+            strategy = result.get("strategy", "unknown")
+            logger.info(
+                f"Search finished: found {len(filtered_memories)} memories (strategy={strategy})"
+            )
             return filtered_memories[:top_k]
 
         except Exception as e:
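Caller-side, the visible change is the new `use_multi_query` flag plus a richer context dict. A hedged usage sketch; the `manager` instance and the context values are illustrative, not from this commit:

```python
import asyncio

async def demo(manager):
    # `manager` is assumed to be an initialized MemoryManager
    memories = await manager.search_memories(
        query="how does 杰瑞喵 rate the new memory system",
        top_k=5,
        use_multi_query=True,  # default: multi-query fusion path
        context={
            "sender": "user",
            "participants": ["杰瑞喵"],
            "chat_history": "recent chat text ...",
        },
    )
    for memory in memories:
        print(memory)

# asyncio.run(demo(manager))  # run with a real MemoryManager instance
```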
src/memory_graph/retrieval/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+"""
+Memory retrieval module
+
+Provides simplified multi-query retrieval:
+- Generate multiple query strings directly with a small model
+- Multi-query fusion retrieval
+- Avoids complex query-decomposition logic
+"""
+
+__all__ = []
@@ -236,6 +236,121 @@ class VectorStore:
             logger.error(f"Similar-node search failed: {e}", exc_info=True)
             raise
 
+    async def search_with_multiple_queries(
+        self,
+        query_embeddings: List[np.ndarray],
+        query_weights: Optional[List[float]] = None,
+        limit: int = 10,
+        node_types: Optional[List[NodeType]] = None,
+        min_similarity: float = 0.0,
+        fusion_strategy: str = "weighted_max",
+    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+        """
+        Multi-query fusion search
+
+        Runs the search with several query vectors and fuses the results.
+        This addresses the problem that a single query vector cannot attend
+        to several key concepts at once.
+
+        Args:
+            query_embeddings: List of query vectors
+            query_weights: Weight per query (optional, defaults to equal weights)
+            limit: Number of final results to return
+            node_types: Restrict node types (optional)
+            min_similarity: Minimum similarity threshold
+            fusion_strategy: Fusion strategy
+                - "weighted_max": weighted maximum (recommended)
+                - "weighted_sum": weighted sum
+                - "rrf": Reciprocal Rank Fusion
+
+        Returns:
+            Fused node list [(node_id, fused_score, metadata), ...]
+        """
+        if not self.collection:
+            raise RuntimeError("Vector store not initialized")
+
+        if not query_embeddings:
+            return []
+
+        # Equal weights by default
+        if query_weights is None:
+            query_weights = [1.0 / len(query_embeddings)] * len(query_embeddings)
+
+        # Normalize the weights
+        total_weight = sum(query_weights)
+        if total_weight > 0:
+            query_weights = [w / total_weight for w in query_weights]
+
+        try:
+            # 1. Run a search for each query
+            all_results: Dict[str, Dict[str, Any]] = {}  # node_id -> {scores, ranks, metadata}
+
+            for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)):
+                # Fetch more results to improve fusion quality
+                search_limit = limit * 3
+                results = await self.search_similar_nodes(
+                    query_embedding=query_emb,
+                    limit=search_limit,
+                    node_types=node_types,
+                    min_similarity=min_similarity,
+                )
+
+                # Record every result
+                for rank, (node_id, similarity, metadata) in enumerate(results):
+                    if node_id not in all_results:
+                        all_results[node_id] = {
+                            "scores": [],
+                            "ranks": [],
+                            "metadata": metadata,
+                        }
+
+                    all_results[node_id]["scores"].append((similarity, weight))
+                    all_results[node_id]["ranks"].append((rank, weight))
+
+            # 2. Fuse the scores
+            fused_results = []
+
+            for node_id, data in all_results.items():
+                scores = data["scores"]
+                ranks = data["ranks"]
+                metadata = data["metadata"]
+
+                if fusion_strategy == "weighted_max":
+                    # Weighted maximum plus an appearance bonus
+                    max_weighted_score = max(score * weight for score, weight in scores)
+                    appearance_bonus = len(scores) * 0.05  # bonus for appearing in several result lists
+                    fused_score = max_weighted_score + appearance_bonus
+
+                elif fusion_strategy == "weighted_sum":
+                    # Weighted sum (results that appear several times may score too high)
+                    fused_score = sum(score * weight for score, weight in scores)
+
+                elif fusion_strategy == "rrf":
+                    # Reciprocal Rank Fusion
+                    # RRF score = sum(weight / (rank + k))
+                    k = 60  # RRF constant
+                    fused_score = sum(weight / (rank + k) for rank, weight in ranks)
+
+                else:
+                    # Default: weighted average
+                    fused_score = sum(score * weight for score, weight in scores) / len(scores)
+
+                fused_results.append((node_id, fused_score, metadata))
+
+            # 3. Sort and return the Top-K
+            fused_results.sort(key=lambda x: x[1], reverse=True)
+            final_results = fused_results[:limit]
+
+            logger.info(
+                f"Multi-query fusion search finished: {len(query_embeddings)} queries, "
+                f"{len(fused_results)} fused results, returning {len(final_results)}"
+            )
+
+            return final_results
+
+        except Exception as e:
+            logger.error(f"Multi-query fusion search failed: {e}", exc_info=True)
+            raise
+
     async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]:
         """
         Get node metadata by ID
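The three fusion strategies reduce to one-line formulas over (similarity, weight) and (rank, weight) pairs. A standalone numeric illustration with toy values, chosen only for the example (weights already normalized):

```python
# One node matched by two sub-queries with normalized weights 0.6 and 0.4.
scores = [(0.90, 0.6), (0.70, 0.4)]  # (similarity, query_weight)
ranks = [(0, 0.6), (2, 0.4)]         # (rank, query_weight)
k = 60                                # RRF constant, as in the method above

weighted_max = max(s * w for s, w in scores) + len(scores) * 0.05  # max + appearance bonus
weighted_sum = sum(s * w for s, w in scores)
rrf = sum(w / (r + k) for r, w in ranks)

print(f"weighted_max = {weighted_max:.3f}")  # 0.540 + 0.100 = 0.640
print(f"weighted_sum = {weighted_sum:.3f}")  # 0.540 + 0.280 = 0.820
print(f"rrf          = {rrf:.5f}")           # 0.6/60 + 0.4/62 ≈ 0.01645
```

Note how weighted_sum rewards a node simply for appearing in several result lists, the risk the docstring flags, while weighted_max keeps that bonus small and bounded.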
@@ -4,7 +4,7 @@ LLM tool interface: defines the memory system's tool schemas and execution logic
 
 from __future__ import annotations
 
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 from src.common.logger import get_logger
 from src.memory_graph.core.builder import MemoryBuilder
@@ -429,8 +429,18 @@ class MemoryTools:
         """
         Run the search_memories tool
 
+        Uses multi-strategy retrieval optimization:
+        1. Query decomposition (identify the main entities and concepts)
+        2. Parallel multi-query retrieval
+        3. Result fusion and re-ranking
+
         Args:
             **params: Tool parameters
+                - query: Query string
+                - top_k: Number of results to return (default 10)
+                - expand_depth: Expansion depth (currently unused)
+                - use_multi_query: Whether to use the multi-query strategy (default True)
+                - context: Query context (optional)
 
         Returns:
             Search results
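For reference, a parameter dict matching the documented schema; the `tools` handle below is hypothetical:

```python
params = {
    "query": "new memory system",
    "top_k": 10,
    "use_multi_query": True,
    "context": {"participants": ["杰瑞喵"]},
}
# result = await tools.search_memories(**params)
# result["strategy"] -> "multi_query"; result["results"] -> formatted memories
```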
@@ -439,33 +449,23 @@ class MemoryTools:
             query = params.get("query", "")
             top_k = params.get("top_k", 10)
             expand_depth = params.get("expand_depth", 1)
+            use_multi_query = params.get("use_multi_query", True)
+            context = params.get("context", None)
 
-            logger.info(f"Searching memories: {query} (top_k={top_k}, expand_depth={expand_depth})")
+            logger.info(f"Searching memories: {query} (top_k={top_k}, multi_query={use_multi_query})")
 
             # 0. Ensure initialization
             await self._ensure_initialized()
 
-            # 1. Generate the query embedding
-            if self.builder.embedding_generator:
-                query_embedding = await self.builder.embedding_generator.generate(query)
+            # 1. Pick the retrieval strategy
+            if use_multi_query:
+                # Multi-query strategy
+                similar_nodes = await self._multi_query_search(query, top_k, context)
             else:
-                logger.warning("No embedding generator configured, using a random vector")
-                import numpy as np
-                query_embedding = np.random.rand(384).astype(np.float32)
+                # Traditional single-query strategy
+                similar_nodes = await self._single_query_search(query, top_k)
 
-            # 2. Vector search
-            node_types_filter = None
-            if "memory_types" in params:
-                # Add type filtering
-                pass
-
-            similar_nodes = await self.vector_store.search_similar_nodes(
-                query_embedding=query_embedding,
-                limit=top_k * 2,  # fetch extra, filter later
-                node_types=node_types_filter,
-            )
-
-            # 3. Extract the memory IDs
+            # 2. Extract the memory IDs
             memory_ids = set()
             for node_id, similarity, metadata in similar_nodes:
                 if "memory_ids" in metadata:
@@ -480,14 +480,14 @@ class MemoryTools:
                     if isinstance(ids, list):
                         memory_ids.update(ids)
 
-            # 4. Fetch the full memories
+            # 3. Fetch the full memories
             memories = []
             for memory_id in list(memory_ids)[:top_k]:
                 memory = self.graph_store.get_memory_by_id(memory_id)
                 if memory:
                     memories.append(memory)
 
-            # 5. Format the results
+            # 4. Format the results
             results = []
             for memory in memories:
                 result = {
@@ -505,6 +505,7 @@ class MemoryTools:
                 "results": results,
                 "total": len(results),
                 "query": query,
+                "strategy": "multi_query" if use_multi_query else "single_query",
             }
 
         except Exception as e:
@@ -516,6 +517,145 @@ class MemoryTools:
                 "results": [],
             }
 
+    async def _generate_multi_queries_simple(
+        self, query: str, context: Optional[Dict[str, Any]] = None
+    ) -> List[Tuple[str, float]]:
+        """
+        Simplified multi-query generation (implemented in the Tools layer to avoid a circular dependency)
+
+        Has the small model directly generate 3-5 queries from different angles.
+        """
+        try:
+            from src.llm_models.utils_model import LLMRequest
+            from src.config.config import model_config
+
+            llm = LLMRequest(
+                model_set=model_config.model_task_config.utils_small,
+                request_type="memory.multi_query"
+            )
+
+            participants = context.get("participants", []) if context else []
+            prompt = f"""Generate 3-5 search queries from different angles for the query below (JSON format).
+
+**Query:** {query}
+**Participants:** {', '.join(participants) if participants else 'none'}
+
+**Principle:** for a complex query (e.g. "how does 杰瑞喵 rate the new memory system"), generate:
+1. The full query (weight 1.0)
+2. An independent query per key concept (weight 0.8) - important!
+3. Subject + action (weight 0.6)
+
+**Output JSON:**
+```json
+{{"queries": [{{"text": "query 1", "weight": 1.0}}, {{"text": "query 2", "weight": 0.8}}]}}
+```"""
+
+            response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250)
+
+            import json, re
+            response = re.sub(r'```json\s*', '', response)
+            response = re.sub(r'```\s*$', '', response).strip()
+
+            data = json.loads(response)
+            queries = data.get("queries", [])
+
+            result = [(item.get("text", "").strip(), float(item.get("weight", 0.5)))
+                      for item in queries if item.get("text", "").strip()]
+
+            if result:
+                logger.info(f"Generated queries: {[q for q, _ in result]}")
+                return result
+
+        except Exception as e:
+            logger.warning(f"Multi-query generation failed: {e}")
+
+        return [(query, 1.0)]
+
+    async def _single_query_search(
+        self, query: str, top_k: int
+    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+        """
+        Traditional single-query search
+
+        Args:
+            query: Query string
+            top_k: Number of results to return
+
+        Returns:
+            Similar-node list [(node_id, similarity, metadata), ...]
+        """
+        # Generate the query embedding
+        if self.builder.embedding_generator:
+            query_embedding = await self.builder.embedding_generator.generate(query)
+        else:
+            logger.warning("No embedding generator configured, using a random vector")
+            import numpy as np
+            query_embedding = np.random.rand(384).astype(np.float32)
+
+        # Vector search
+        similar_nodes = await self.vector_store.search_similar_nodes(
+            query_embedding=query_embedding,
+            limit=top_k * 2,  # fetch extra, filter later
+        )
+
+        return similar_nodes
+
+    async def _multi_query_search(
+        self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None
+    ) -> List[Tuple[str, float, Dict[str, Any]]]:
+        """
+        Multi-query strategy search (simplified)
+
+        Generates multiple queries directly with a small model; no complex
+        decomposition or combination needed.
+
+        Steps:
+        1. Have the small model generate 3-5 queries from different angles
+        2. Generate an embedding per query
+        3. Search in parallel and fuse the results
+
+        Args:
+            query: Query string
+            top_k: Number of results to return
+            context: Query context
+
+        Returns:
+            Fused similar-node list
+        """
+        try:
+            # 1. Generate multiple queries with the small model
+            multi_queries = await self._generate_multi_queries_simple(query, context)
+
+            logger.debug(f"Generated {len(multi_queries)} queries: {multi_queries}")
+
+            # 2. Generate embeddings for all queries
+            if not self.builder.embedding_generator:
+                logger.warning("No embedding generator configured, falling back to single-query mode")
+                return await self._single_query_search(query, top_k)
+
+            query_embeddings = []
+            query_weights = []
+
+            for sub_query, weight in multi_queries:
+                embedding = await self.builder.embedding_generator.generate(sub_query)
+                query_embeddings.append(embedding)
+                query_weights.append(weight)
+
+            # 3. Multi-query fusion search
+            similar_nodes = await self.vector_store.search_with_multiple_queries(
+                query_embeddings=query_embeddings,
+                query_weights=query_weights,
+                limit=top_k * 2,  # fetch extra, filter later
+                fusion_strategy="weighted_max",
+            )
+
+            logger.info(f"Multi-query retrieval finished: {len(similar_nodes)} nodes")
+
+            return similar_nodes
+
+        except Exception as e:
+            logger.warning(f"Multi-query search failed, falling back to single-query mode: {e}", exc_info=True)
+            return await self._single_query_search(query, top_k)
+
     async def _add_memory_to_stores(self, memory: Memory):
         """Add the memory to the stores"""
         # 1. Add to the graph store
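Step 2 of `_multi_query_search` is just an embed-per-subquery loop. A runnable sketch with a stub embedder standing in for the real generator; `FakeEmbedder` and the 8-dim vectors are assumptions for the demo:

```python
import asyncio
import numpy as np

class FakeEmbedder:
    """Stand-in for builder.embedding_generator; returns a deterministic 8-dim vector."""
    async def generate(self, text: str) -> np.ndarray:
        rng = np.random.default_rng(len(text))  # deterministic seed per text length
        return rng.random(8).astype(np.float32)

async def embed_weighted(queries):
    # Mirrors step 2: one embedding per sub-query, weights carried alongside.
    embedder = FakeEmbedder()
    embeddings, weights = [], []
    for text, weight in queries:
        embeddings.append(await embedder.generate(text))
        weights.append(weight)
    return embeddings, weights

queries = [
    ("how does 杰瑞喵 rate the new memory system", 1.0),
    ("new memory system", 0.8),
    ("杰瑞喵 rate", 0.6),
]
embeddings, weights = asyncio.run(embed_weighted(queries))
print(len(embeddings), weights)  # 3 [1.0, 0.8, 0.6]
```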