From bc805aabee70bb4662972d9c55ba19fe0e2d570f Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 6 Nov 2025 12:01:28 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=AE=9E=E7=8E=B0=E5=A4=9A=E6=9F=A5?= =?UTF-8?q?=E8=AF=A2=E7=94=9F=E6=88=90=E4=B8=8E=E8=9E=8D=E5=90=88=E6=90=9C?= =?UTF-8?q?=E7=B4=A2=EF=BC=8C=E7=AE=80=E5=8C=96=E8=AE=B0=E5=BF=86=E6=A3=80?= =?UTF-8?q?=E7=B4=A2=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/memory_graph/manager.py | 130 ++++++++++------ src/memory_graph/retrieval/__init__.py | 10 ++ src/memory_graph/storage/vector_store.py | 115 ++++++++++++++ src/memory_graph/tools/memory_tools.py | 186 ++++++++++++++++++++--- 4 files changed, 369 insertions(+), 72 deletions(-) create mode 100644 src/memory_graph/retrieval/__init__.py diff --git a/src/memory_graph/manager.py b/src/memory_graph/manager.py index 6a8705a96..05520ba03 100644 --- a/src/memory_graph/manager.py +++ b/src/memory_graph/manager.py @@ -333,76 +333,100 @@ class MemoryManager: # ==================== 记忆检索操作 ==================== - async def optimize_search_query( + async def generate_multi_queries( self, query: str, context: Optional[Dict[str, Any]] = None, - ) -> str: + ) -> List[Tuple[str, float]]: """ - 使用小模型优化搜索查询 + 使用小模型生成多个查询语句(用于多路召回) + + 简化版多查询策略:直接让小模型生成3-5个不同角度的查询, + 避免复杂的查询分解和组合逻辑。 Args: query: 原始查询 - context: 上下文信息(聊天历史、发言人等) + context: 上下文信息(聊天历史、发言人、参与者等) Returns: - 优化后的查询字符串 + List of (query_string, weight) - 查询语句和权重 """ - if not context: - return query - try: from src.llm_models.utils_model import LLMRequest from src.config.config import model_config - # 使用小模型优化查询 llm = LLMRequest( model_set=model_config.model_task_config.utils_small, - request_type="memory.query_optimizer" + request_type="memory.multi_query_generator" ) - # 构建优化提示 - chat_history = context.get("chat_history", "") - sender = context.get("sender", "") + # 构建上下文信息 + chat_history = context.get("chat_history", "") if context else "" + sender = context.get("sender", "") if context else "" + participants = context.get("participants", []) if context else [] + participants_str = "、".join(participants) if participants else "无" - prompt = f"""你是一个记忆检索查询优化助手。你的任务是分析对话历史,生成一个综合性的搜索查询。 + prompt = f"""你是记忆检索助手。为提高检索准确率,请为查询生成3-5个不同角度的搜索语句。 -**任务说明:** -不要只优化单个消息,而是要综合分析整个对话上下文,提取出最核心的检索意图。 +**核心原则(重要!):** +对于包含多个概念的复杂查询(如"杰瑞喵如何评价新的记忆系统"),应该生成: +1. 完整查询(包含所有要素)- 权重1.0 +2. 每个关键概念的独立查询(如"新的记忆系统")- 权重0.8,避免被主体淹没! +3. 主体+动作组合(如"杰瑞喵 评价")- 权重0.6 +4. 泛化查询(如"记忆系统")- 权重0.7 **要求:** -1. 仔细阅读对话历史,理解对话的主题和脉络 -2. 识别关键人物、事件、关系和话题 -3. 提取最值得检索的核心信息点 -4. 生成一个简洁但信息丰富的搜索查询(15-30字) -5. 如果涉及特定人物,必须明确指出人名 -6. 只输出查询文本,不要解释 +- 第一个必须是原始查询或同义改写 +- 识别查询中的所有重要概念,为每个概念生成独立查询 +- 查询简洁(5-20字) +- 直接输出JSON,不要添加说明 -**对话上下文:** -{chat_history[-500:] if chat_history else "(无历史对话)"} +**已知参与者:** {participants_str} +**对话上下文:** {chat_history[-300:] if chat_history else "无"} +**当前查询:** {sender}: {query} -**当前消息:** -{sender}: {query} +**输出JSON格式:** +```json +{{ + "queries": [ + {{"text": "完整查询", "weight": 1.0}}, + {{"text": "关键概念1", "weight": 0.8}}, + {{"text": "关键概念2", "weight": 0.8}}, + {{"text": "组合查询", "weight": 0.6}} + ] +}} +```""" -**生成综合查询:**""" - - optimized_query, _ = await llm.generate_response_async( - prompt, - temperature=0.3, - max_tokens=100 - ) + response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=300) - # 清理输出 - optimized_query = optimized_query.strip() - if optimized_query and len(optimized_query) > 5: - logger.debug(f"[查询优化] '{query}' -> '{optimized_query}'") - return optimized_query + # 解析JSON + import json, re + response = re.sub(r'```json\s*', '', response) + response = re.sub(r'```\s*$', '', response).strip() - return query + try: + data = json.loads(response) + queries = data.get("queries", []) + + result = [] + for item in queries: + text = item.get("text", "").strip() + weight = float(item.get("weight", 0.5)) + if text: + result.append((text, weight)) + + if result: + logger.info(f"生成 {len(result)} 个查询: {[q for q, _ in result]}") + return result + + except json.JSONDecodeError as e: + logger.warning(f"解析失败: {e}, response={response[:100]}") except Exception as e: - logger.warning(f"查询优化失败,使用原始查询: {e}") - return query + logger.warning(f"多查询生成失败: {e}") + + # 回退到原始查询 + return [(query, 1.0)] async def search_memories( self, @@ -413,11 +437,16 @@ class MemoryManager: min_importance: float = 0.0, include_forgotten: bool = False, optimize_query: bool = True, + use_multi_query: bool = True, context: Optional[Dict[str, Any]] = None, ) -> List[Memory]: """ 搜索记忆 + 使用多策略检索优化,解决复杂查询问题。 + 例如:"杰瑞喵如何评价新的记忆系统" 会被分解为多个子查询, + 确保同时匹配"杰瑞喵"和"新的记忆系统"两个关键概念。 + Args: query: 搜索查询 top_k: 返回结果数 @@ -425,7 +454,8 @@ class MemoryManager: time_range: 时间范围过滤 (start, end) min_importance: 最小重要性 include_forgotten: 是否包含已遗忘的记忆 - optimize_query: 是否使用小模型优化查询 + optimize_query: 是否使用小模型优化查询(已弃用,被 use_multi_query 替代) + use_multi_query: 是否使用多查询策略(推荐,默认True) context: 查询上下文(用于优化) Returns: @@ -435,19 +465,18 @@ class MemoryManager: await self.initialize() try: - # 查询优化 - search_query = query - if optimize_query and context: - search_query = await self.optimize_search_query(query, context) - + # 准备搜索参数 params = { - "query": search_query, + "query": query, "top_k": top_k, + "use_multi_query": use_multi_query, + "context": context, } if memory_types: params["memory_types"] = memory_types + # 执行搜索 result = await self.tools.search_memories(**params) if not result["success"]: @@ -484,7 +513,10 @@ class MemoryManager: filtered_memories.append(memory) - logger.info(f"搜索完成: 找到 {len(filtered_memories)} 条记忆") + strategy = result.get("strategy", "unknown") + logger.info( + f"搜索完成: 找到 {len(filtered_memories)} 条记忆 (策略={strategy})" + ) return filtered_memories[:top_k] except Exception as e: diff --git a/src/memory_graph/retrieval/__init__.py b/src/memory_graph/retrieval/__init__.py new file mode 100644 index 000000000..42cb85423 --- /dev/null +++ b/src/memory_graph/retrieval/__init__.py @@ -0,0 +1,10 @@ +""" +记忆检索模块 + +提供简化的多查询检索功能: +- 直接使用小模型生成多个查询语句 +- 多查询融合检索 +- 避免复杂的查询分解逻辑 +""" + +__all__ = [] diff --git a/src/memory_graph/storage/vector_store.py b/src/memory_graph/storage/vector_store.py index 5a791916b..74a148c24 100644 --- a/src/memory_graph/storage/vector_store.py +++ b/src/memory_graph/storage/vector_store.py @@ -236,6 +236,121 @@ class VectorStore: logger.error(f"相似节点搜索失败: {e}", exc_info=True) raise + async def search_with_multiple_queries( + self, + query_embeddings: List[np.ndarray], + query_weights: Optional[List[float]] = None, + limit: int = 10, + node_types: Optional[List[NodeType]] = None, + min_similarity: float = 0.0, + fusion_strategy: str = "weighted_max", + ) -> List[Tuple[str, float, Dict[str, Any]]]: + """ + 多查询融合搜索 + + 使用多个查询向量进行搜索,然后融合结果。 + 这能解决单一查询向量无法同时关注多个关键概念的问题。 + + Args: + query_embeddings: 查询向量列表 + query_weights: 每个查询的权重(可选,默认均等) + limit: 最终返回结果数量 + node_types: 限制节点类型(可选) + min_similarity: 最小相似度阈值 + fusion_strategy: 融合策略 + - "weighted_max": 加权最大值(推荐) + - "weighted_sum": 加权求和 + - "rrf": Reciprocal Rank Fusion + + Returns: + 融合后的节点列表 [(node_id, fused_score, metadata), ...] + """ + if not self.collection: + raise RuntimeError("向量存储未初始化") + + if not query_embeddings: + return [] + + # 默认权重均等 + if query_weights is None: + query_weights = [1.0 / len(query_embeddings)] * len(query_embeddings) + + # 归一化权重 + total_weight = sum(query_weights) + if total_weight > 0: + query_weights = [w / total_weight for w in query_weights] + + try: + # 1. 对每个查询执行搜索 + all_results: Dict[str, Dict[str, Any]] = {} # node_id -> {scores, metadata} + + for i, (query_emb, weight) in enumerate(zip(query_embeddings, query_weights)): + # 搜索更多结果以提高融合质量 + search_limit = limit * 3 + results = await self.search_similar_nodes( + query_embedding=query_emb, + limit=search_limit, + node_types=node_types, + min_similarity=min_similarity, + ) + + # 记录每个结果 + for rank, (node_id, similarity, metadata) in enumerate(results): + if node_id not in all_results: + all_results[node_id] = { + "scores": [], + "ranks": [], + "metadata": metadata, + } + + all_results[node_id]["scores"].append((similarity, weight)) + all_results[node_id]["ranks"].append((rank, weight)) + + # 2. 融合分数 + fused_results = [] + + for node_id, data in all_results.items(): + scores = data["scores"] + ranks = data["ranks"] + metadata = data["metadata"] + + if fusion_strategy == "weighted_max": + # 加权最大值 + 出现次数奖励 + max_weighted_score = max(score * weight for score, weight in scores) + appearance_bonus = len(scores) * 0.05 # 出现多次有奖励 + fused_score = max_weighted_score + appearance_bonus + + elif fusion_strategy == "weighted_sum": + # 加权求和(可能导致出现多次的结果分数过高) + fused_score = sum(score * weight for score, weight in scores) + + elif fusion_strategy == "rrf": + # Reciprocal Rank Fusion + # RRF score = sum(weight / (rank + k)) + k = 60 # RRF 常数 + fused_score = sum(weight / (rank + k) for rank, weight in ranks) + + else: + # 默认使用加权平均 + fused_score = sum(score * weight for score, weight in scores) / len(scores) + + fused_results.append((node_id, fused_score, metadata)) + + # 3. 排序并返回 Top-K + fused_results.sort(key=lambda x: x[1], reverse=True) + final_results = fused_results[:limit] + + logger.info( + f"多查询融合搜索完成: {len(query_embeddings)} 个查询, " + f"融合后 {len(fused_results)} 个结果, 返回 {len(final_results)} 个" + ) + + return final_results + + except Exception as e: + logger.error(f"多查询融合搜索失败: {e}", exc_info=True) + raise + async def get_node_by_id(self, node_id: str) -> Optional[Dict[str, Any]]: """ 根据ID获取节点元数据 diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index 4d568e4eb..20e31f047 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -4,7 +4,7 @@ LLM 工具接口:定义记忆系统的工具 schema 和执行逻辑 from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from src.common.logger import get_logger from src.memory_graph.core.builder import MemoryBuilder @@ -429,8 +429,18 @@ class MemoryTools: """ 执行 search_memories 工具 + 使用多策略检索优化: + 1. 查询分解(识别主要实体和概念) + 2. 多查询并行检索 + 3. 结果融合和重排 + Args: **params: 工具参数 + - query: 查询字符串 + - top_k: 返回结果数(默认10) + - expand_depth: 扩展深度(暂未使用) + - use_multi_query: 是否使用多查询策略(默认True) + - context: 查询上下文(可选) Returns: 搜索结果 @@ -439,33 +449,23 @@ class MemoryTools: query = params.get("query", "") top_k = params.get("top_k", 10) expand_depth = params.get("expand_depth", 1) + use_multi_query = params.get("use_multi_query", True) + context = params.get("context", None) - logger.info(f"搜索记忆: {query} (top_k={top_k}, expand_depth={expand_depth})") + logger.info(f"搜索记忆: {query} (top_k={top_k}, multi_query={use_multi_query})") # 0. 确保初始化 await self._ensure_initialized() - # 1. 生成查询嵌入 - if self.builder.embedding_generator: - query_embedding = await self.builder.embedding_generator.generate(query) + # 1. 根据策略选择检索方式 + if use_multi_query: + # 多查询策略 + similar_nodes = await self._multi_query_search(query, top_k, context) else: - logger.warning("未配置嵌入生成器,使用随机向量") - import numpy as np - query_embedding = np.random.rand(384).astype(np.float32) + # 传统单查询策略 + similar_nodes = await self._single_query_search(query, top_k) - # 2. 向量搜索 - node_types_filter = None - if "memory_types" in params: - # 添加类型过滤 - pass - - similar_nodes = await self.vector_store.search_similar_nodes( - query_embedding=query_embedding, - limit=top_k * 2, # 多取一些,后续过滤 - node_types=node_types_filter, - ) - - # 3. 提取记忆ID + # 2. 提取记忆ID memory_ids = set() for node_id, similarity, metadata in similar_nodes: if "memory_ids" in metadata: @@ -480,14 +480,14 @@ class MemoryTools: if isinstance(ids, list): memory_ids.update(ids) - # 4. 获取完整记忆 + # 3. 获取完整记忆 memories = [] for memory_id in list(memory_ids)[:top_k]: memory = self.graph_store.get_memory_by_id(memory_id) if memory: memories.append(memory) - # 5. 格式化结果 + # 4. 格式化结果 results = [] for memory in memories: result = { @@ -505,6 +505,7 @@ class MemoryTools: "results": results, "total": len(results), "query": query, + "strategy": "multi_query" if use_multi_query else "single_query", } except Exception as e: @@ -516,6 +517,145 @@ class MemoryTools: "results": [], } + async def _generate_multi_queries_simple( + self, query: str, context: Optional[Dict[str, Any]] = None + ) -> List[Tuple[str, float]]: + """ + 简化版多查询生成(直接在 Tools 层实现,避免循环依赖) + + 让小模型直接生成3-5个不同角度的查询语句。 + """ + try: + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config + + llm = LLMRequest( + model_set=model_config.model_task_config.utils_small, + request_type="memory.multi_query" + ) + + participants = context.get("participants", []) if context else [] + prompt = f"""为查询生成3-5个不同角度的搜索语句(JSON格式)。 + +**查询:** {query} +**参与者:** {', '.join(participants) if participants else '无'} + +**原则:** 对复杂查询(如"杰瑞喵如何评价新的记忆系统"),应生成: +1. 完整查询(权重1.0) +2. 每个关键概念独立查询(权重0.8)- 重要! +3. 主体+动作(权重0.6) + +**输出JSON:** +```json +{{"queries": [{{"text": "查询1", "weight": 1.0}}, {{"text": "查询2", "weight": 0.8}}]}} +```""" + + response, _ = await llm.generate_response_async(prompt, temperature=0.3, max_tokens=250) + + import json, re + response = re.sub(r'```json\s*', '', response) + response = re.sub(r'```\s*$', '', response).strip() + + data = json.loads(response) + queries = data.get("queries", []) + + result = [(item.get("text", "").strip(), float(item.get("weight", 0.5))) + for item in queries if item.get("text", "").strip()] + + if result: + logger.info(f"生成查询: {[q for q, _ in result]}") + return result + + except Exception as e: + logger.warning(f"多查询生成失败: {e}") + + return [(query, 1.0)] + + async def _single_query_search( + self, query: str, top_k: int + ) -> List[Tuple[str, float, Dict[str, Any]]]: + """ + 传统的单查询搜索 + + Args: + query: 查询字符串 + top_k: 返回结果数 + + Returns: + 相似节点列表 [(node_id, similarity, metadata), ...] + """ + # 生成查询嵌入 + if self.builder.embedding_generator: + query_embedding = await self.builder.embedding_generator.generate(query) + else: + logger.warning("未配置嵌入生成器,使用随机向量") + import numpy as np + query_embedding = np.random.rand(384).astype(np.float32) + + # 向量搜索 + similar_nodes = await self.vector_store.search_similar_nodes( + query_embedding=query_embedding, + limit=top_k * 2, # 多取一些,后续过滤 + ) + + return similar_nodes + + async def _multi_query_search( + self, query: str, top_k: int, context: Optional[Dict[str, Any]] = None + ) -> List[Tuple[str, float, Dict[str, Any]]]: + """ + 多查询策略搜索(简化版) + + 直接使用小模型生成多个查询,无需复杂的分解和组合。 + + 步骤: + 1. 让小模型生成3-5个不同角度的查询 + 2. 为每个查询生成嵌入 + 3. 并行搜索并融合结果 + + Args: + query: 查询字符串 + top_k: 返回结果数 + context: 查询上下文 + + Returns: + 融合后的相似节点列表 + """ + try: + # 1. 使用小模型生成多个查询 + multi_queries = await self._generate_multi_queries_simple(query, context) + + logger.debug(f"生成 {len(multi_queries)} 个查询: {multi_queries}") + + # 2. 生成所有查询的嵌入 + if not self.builder.embedding_generator: + logger.warning("未配置嵌入生成器,回退到单查询模式") + return await self._single_query_search(query, top_k) + + query_embeddings = [] + query_weights = [] + + for sub_query, weight in multi_queries: + embedding = await self.builder.embedding_generator.generate(sub_query) + query_embeddings.append(embedding) + query_weights.append(weight) + + # 3. 多查询融合搜索 + similar_nodes = await self.vector_store.search_with_multiple_queries( + query_embeddings=query_embeddings, + query_weights=query_weights, + limit=top_k * 2, # 多取一些,后续过滤 + fusion_strategy="weighted_max", + ) + + logger.info(f"多查询检索完成: {len(similar_nodes)} 个节点") + + return similar_nodes + + except Exception as e: + logger.warning(f"多查询搜索失败,回退到单查询模式: {e}", exc_info=True) + return await self._single_query_search(query, top_k) + async def _add_memory_to_stores(self, memory: Memory): """将记忆添加到存储""" # 1. 添加到图存储