From befb8ad3f6f1b977af463c7d5ad9347f4cd9c12f Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Fri, 7 Nov 2025 18:09:28 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BC=98=E5=8C=96=E5=B5=8C?= =?UTF-8?q?=E5=85=A5=E7=94=9F=E6=88=90=E9=80=BB=E8=BE=91=EF=BC=8C=E5=A4=B1?= =?UTF-8?q?=E8=B4=A5=E6=97=B6=E8=BF=94=E5=9B=9E=20None=EF=BC=8C=E7=AE=80?= =?UTF-8?q?=E5=8C=96=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86=EF=BC=9B=E6=9B=B4?= =?UTF-8?q?=E6=96=B0=E8=B0=83=E5=BA=A6=E5=99=A8=E4=BB=BB=E5=8A=A1=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../message_manager/scheduler_dispatcher.py | 13 +- src/memory_graph/core/builder.py | 24 ++- src/memory_graph/tools/memory_tools.py | 51 ++++--- src/memory_graph/utils/embeddings.py | 143 ++++-------------- 4 files changed, 85 insertions(+), 146 deletions(-) diff --git a/src/chat/message_manager/scheduler_dispatcher.py b/src/chat/message_manager/scheduler_dispatcher.py index 392da2212..a13008e26 100644 --- a/src/chat/message_manager/scheduler_dispatcher.py +++ b/src/chat/message_manager/scheduler_dispatcher.py @@ -417,8 +417,7 @@ class SchedulerDispatcher: stream_id: 流ID """ try: - # 从追踪中移除(因为是一次性任务) - old_schedule_id = self.stream_schedules.pop(stream_id, None) + old_schedule_id = self.stream_schedules.get(stream_id) logger.info( f"⏰ Schedule 触发: 流={stream_id[:8]}..., " @@ -445,14 +444,8 @@ class SchedulerDispatcher: if not success: self.stats["total_failures"] += 1 - # 处理完成后,检查是否需要创建新的 schedule - if stream_id in self.stream_schedules: - logger.info( - f"⚠️ 处理完成时发现已有新 schedule: 流={stream_id[:8]}..., " - f"可能是打断创建的,跳过创建新 schedule" - ) - return - + self.stream_schedules.pop(stream_id, None) + # 检查缓存中是否有待处理的消息 from src.chat.message_manager.message_manager import message_manager diff --git a/src/memory_graph/core/builder.py b/src/memory_graph/core/builder.py index a5cad1f20..5c93dd3dc 100644 --- a/src/memory_graph/core/builder.py +++ b/src/memory_graph/core/builder.py @@ -318,7 +318,7 @@ class MemoryBuilder: return nodes, edges - async def _generate_embedding(self, text: str) -> np.ndarray: + async def _generate_embedding(self, text: str) -> np.ndarray | None: """ 生成文本的嵌入向量 @@ -326,17 +326,17 @@ class MemoryBuilder: text: 文本内容 Returns: - 嵌入向量 + 嵌入向量,失败时返回 None """ if self.embedding_generator: try: embedding = await self.embedding_generator.generate(text) return embedding except Exception as e: - logger.warning(f"嵌入生成失败,使用随机向量: {e}") + logger.warning(f"嵌入生成失败,跳过: {e}") - # 回退:生成随机向量(仅用于测试) - return np.random.rand(384).astype(np.float32) + # 嵌入生成失败,返回 None + return None async def _find_existing_node( self, content: str, node_type: NodeType @@ -367,7 +367,7 @@ class MemoryBuilder: return None async def _find_similar_topic( - self, content: str, embedding: np.ndarray + self, content: str, embedding: np.ndarray | None ) -> MemoryNode | None: """ 查找相似的主题节点(基于语义相似度) @@ -379,6 +379,11 @@ class MemoryBuilder: Returns: 相似节点,如果没有则返回 None """ + # 如果嵌入为空,无法进行相似性搜索 + if embedding is None: + logger.debug("嵌入向量为空,跳过相似节点搜索") + return None + try: # 搜索相似节点(阈值 0.95) similar_nodes = await self.vector_store.search_similar_nodes( @@ -412,7 +417,7 @@ class MemoryBuilder: return None async def _find_similar_object( - self, content: str, embedding: np.ndarray + self, content: str, embedding: np.ndarray | None ) -> MemoryNode | None: """ 查找相似的客体节点(基于语义相似度) @@ -424,6 +429,11 @@ class MemoryBuilder: Returns: 相似节点,如果没有则返回 None """ + # 如果嵌入为空,无法进行相似性搜索 + if embedding is None: + logger.debug("嵌入向量为空,跳过相似节点搜索") + return None + try: # 搜索相似节点(阈值 0.95) similar_nodes = await self.vector_store.search_similar_nodes( diff --git a/src/memory_graph/tools/memory_tools.py b/src/memory_graph/tools/memory_tools.py index 325cd0af4..8986ce732 100644 --- a/src/memory_graph/tools/memory_tools.py +++ b/src/memory_graph/tools/memory_tools.py @@ -506,16 +506,18 @@ class MemoryTools: try: query_embedding = await self.builder.embedding_generator.generate(query) - # 使用共享的图扩展工具函数 - expanded_results = await expand_memories_with_semantic_filter( - graph_store=self.graph_store, - vector_store=self.vector_store, - initial_memory_ids=list(initial_memory_ids), - query_embedding=query_embedding, - max_depth=expand_depth, - semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值 - max_expanded=top_k * 2 - ) + # 只有在嵌入生成成功时才进行语义扩展 + if query_embedding is not None: + # 使用共享的图扩展工具函数 + expanded_results = await expand_memories_with_semantic_filter( + graph_store=self.graph_store, + vector_store=self.vector_store, + initial_memory_ids=list(initial_memory_ids), + query_embedding=query_embedding, + max_depth=expand_depth, + semantic_threshold=self.expand_semantic_threshold, # 使用配置的阈值 + max_expanded=top_k * 2 + ) # 合并扩展结果 expanded_memory_scores.update(dict(expanded_results)) @@ -714,12 +716,14 @@ class MemoryTools: 相似节点列表 [(node_id, similarity, metadata), ...] """ # 生成查询嵌入 + query_embedding = None if self.builder.embedding_generator: query_embedding = await self.builder.embedding_generator.generate(query) - else: - logger.warning("未配置嵌入生成器,使用随机向量") - import numpy as np - query_embedding = np.random.rand(384).astype(np.float32) + + # 如果嵌入生成失败,无法进行向量搜索 + if query_embedding is None: + logger.warning("嵌入生成失败,跳过节点搜索") + return [] # 向量搜索 similar_nodes = await self.vector_store.search_similar_nodes( @@ -766,8 +770,14 @@ class MemoryTools: for sub_query, weight in multi_queries: embedding = await self.builder.embedding_generator.generate(sub_query) - query_embeddings.append(embedding) - query_weights.append(weight) + if embedding is not None: + query_embeddings.append(embedding) + query_weights.append(weight) + + # 如果所有嵌入都生成失败,回退到单查询模式 + if not query_embeddings: + logger.warning("所有查询嵌入生成失败,回退到单查询模式") + return await self._single_query_search(query, top_k) # 3. 多查询融合搜索 similar_nodes = await self.vector_store.search_with_multiple_queries( @@ -806,11 +816,14 @@ class MemoryTools: 找到的记忆,如果没有则返回 None """ # 使用语义搜索查找最相关的记忆 + query_embedding = None if self.builder.embedding_generator: query_embedding = await self.builder.embedding_generator.generate(description) - else: - import numpy as np - query_embedding = np.random.rand(384).astype(np.float32) + + # 如果嵌入生成失败,无法进行语义搜索 + if query_embedding is None: + logger.debug("嵌入生成失败,跳过描述搜索") + return None # 搜索相似节点 similar_nodes = await self.vector_store.search_similar_nodes( diff --git a/src/memory_graph/utils/embeddings.py b/src/memory_graph/utils/embeddings.py index ae80b5aa0..30787d34f 100644 --- a/src/memory_graph/utils/embeddings.py +++ b/src/memory_graph/utils/embeddings.py @@ -1,5 +1,5 @@ """ -嵌入向量生成器:优先使用配置的 embedding API,sentence-transformers 作为备选 +嵌入向量生成器:优先使用配置的 embedding API,失败时跳过向量生成 """ from __future__ import annotations @@ -19,39 +19,33 @@ class EmbeddingGenerator: 策略: 1. 优先使用配置的 embedding API(通过 LLMRequest) - 2. 如果 API 不可用,回退到本地 sentence-transformers - 3. 如果 sentence-transformers 未安装,使用随机向量(仅测试) + 2. 如果 API 不可用或失败,跳过向量生成,返回 None 或零向量 + 3. 不再使用本地 sentence-transformers 模型,避免向量维度不匹配 优点: - - 降低本地运算负载 - - 即使未安装 sentence-transformers 也可正常运行 + - 完全避免本地运算负载 + - 避免向量维度不匹配问题 + - 简化错误处理逻辑 - 保持与现有系统的一致性 """ def __init__( self, use_api: bool = True, - fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", ): """ 初始化嵌入生成器 Args: - use_api: 是否优先使用 API(默认 True) - fallback_model_name: 回退本地模型名称 + use_api: 是否使用 API(默认 True) """ self.use_api = use_api - self.fallback_model_name = fallback_model_name # API 相关 self._llm_request = None self._api_available = False self._api_dimension = None - # 本地模型相关 - self._local_model = None - self._local_model_loaded = False - async def _initialize_api(self): """初始化 embedding API""" if self._api_available: @@ -78,67 +72,39 @@ class EmbeddingGenerator: logger.warning(f"⚠️ Embedding API 初始化失败: {e}") self._api_available = False - def _load_local_model(self): - """延迟加载本地模型""" - if not self._local_model_loaded: - try: - from sentence_transformers import SentenceTransformer - - logger.info(f"📦 加载本地嵌入模型: {self.fallback_model_name}") - self._local_model = SentenceTransformer(self.fallback_model_name) - self._local_model_loaded = True - logger.info("✅ 本地嵌入模型加载成功") - except ImportError: - logger.warning( - "⚠️ sentence-transformers 未安装,将使用随机向量(仅测试用)\n" - " 安装方法: pip install sentence-transformers" - ) - self._local_model_loaded = False - except Exception as e: - logger.warning(f"⚠️ 本地模型加载失败: {e}") - self._local_model_loaded = False - - async def generate(self, text: str) -> np.ndarray: + + async def generate(self, text: str) -> np.ndarray | None: """ 生成单个文本的嵌入向量 策略: - 1. 优先使用 API - 2. API 失败则使用本地模型 - 3. 本地模型不可用则使用随机向量 + 1. 使用 API 生成向量 + 2. API 失败则返回 None,跳过向量生成 Args: text: 输入文本 Returns: - 嵌入向量 + 嵌入向量,失败时返回 None """ if not text or not text.strip(): - logger.warning("输入文本为空,返回零向量") - dim = self._get_dimension() - return np.zeros(dim, dtype=np.float32) + logger.debug("输入文本为空,返回 None") + return None try: - # 策略 1: 使用 API + # 使用 API 生成嵌入 if self.use_api: embedding = await self._generate_with_api(text) if embedding is not None: return embedding - # 策略 2: 使用本地模型 - embedding = await self._generate_with_local_model(text) - if embedding is not None: - return embedding - - # 策略 3: 随机向量(仅测试) - logger.warning(f"⚠️ 所有嵌入策略失败,使用随机向量: {text[:30]}...") - dim = self._get_dimension() - return np.random.rand(dim).astype(np.float32) + # API 失败,记录日志并返回 None + logger.debug(f"⚠️ 嵌入生成失败,跳过: {text[:30]}...") + return None except Exception as e: - logger.error(f"❌ 嵌入生成失败: {e}", exc_info=True) - dim = self._get_dimension() - return np.random.rand(dim).astype(np.float32) + logger.error(f"❌ 嵌入生成异常: {e}", exc_info=True) + return None async def _generate_with_api(self, text: str) -> np.ndarray | None: """使用 API 生成嵌入""" @@ -164,51 +130,16 @@ class EmbeddingGenerator: logger.debug(f"API 嵌入生成失败: {e}") return None - async def _generate_with_local_model(self, text: str) -> np.ndarray | None: - """使用本地模型生成嵌入""" - try: - # 加载本地模型 - if not self._local_model_loaded: - self._load_local_model() - - if not self._local_model_loaded or not self._local_model: - return None - - # 在线程池中运行 - loop = asyncio.get_event_loop() - embedding = await loop.run_in_executor(None, self._encode_single_local, text) - - logger.debug(f"💻 本地生成嵌入: {text[:30]}... -> {len(embedding)}维") - return embedding - - except Exception as e: - logger.debug(f"本地模型嵌入生成失败: {e}") - return None - - def _encode_single_local(self, text: str) -> np.ndarray: - """同步编码单个文本(本地模型)""" - if self._local_model is None: - raise RuntimeError("本地模型未加载") - embedding = self._local_model.encode(text, convert_to_numpy=True) # type: ignore - return embedding.astype(np.float32) - + def _get_dimension(self) -> int: """获取嵌入维度""" # 优先使用 API 维度 if self._api_dimension: return self._api_dimension - # 其次使用本地模型维度 - if self._local_model_loaded and self._local_model: - try: - return self._local_model.get_sentence_embedding_dimension() - except Exception: - pass + raise ValueError("无法确定嵌入向量维度,请确保已正确配置 embedding API") - # 默认 384(sentence-transformers 常用维度) - return 384 - - async def generate_batch(self, texts: list[str]) -> list[np.ndarray]: + async def generate_batch(self, texts: list[str]) -> list[np.ndarray | None]: """ 批量生成嵌入向量 @@ -216,7 +147,7 @@ class EmbeddingGenerator: texts: 文本列表 Returns: - 嵌入向量列表 + 嵌入向量列表,失败的项目为 None """ if not texts: return [] @@ -225,9 +156,8 @@ class EmbeddingGenerator: # 过滤空文本 valid_texts = [t for t in texts if t and t.strip()] if not valid_texts: - logger.warning("所有文本为空,返回零向量列表") - dim = self._get_dimension() - return [np.zeros(dim, dtype=np.float32) for _ in texts] + logger.debug("所有文本为空,返回 None 列表") + return [None for _ in texts] # 使用 API 批量生成(如果可用) if self.use_api: @@ -241,15 +171,15 @@ class EmbeddingGenerator: embedding = await self.generate(text) results.append(embedding) - logger.info(f"✅ 批量生成嵌入: {len(texts)} 个文本") + success_count = sum(1 for r in results if r is not None) + logger.debug(f"✅ 批量生成嵌入: {success_count}/{len(texts)} 个成功") return results except Exception as e: logger.error(f"❌ 批量嵌入生成失败: {e}", exc_info=True) - dim = self._get_dimension() - return [np.random.rand(dim).astype(np.float32) for _ in texts] + return [None for _ in texts] - async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray] | None: + async def _generate_batch_with_api(self, texts: list[str]) -> list[np.ndarray | None] | None: """使用 API 批量生成""" try: # 对于大多数 API,批量调用就是多次单独调用 @@ -257,9 +187,7 @@ class EmbeddingGenerator: results = [] for text in texts: embedding = await self._generate_with_api(text) - if embedding is None: - return None # 如果任何一个失败,返回 None 触发回退 - results.append(embedding) + results.append(embedding) # 失败的项目为 None,不中断整个批量处理 return results except Exception as e: logger.debug(f"API 批量生成失败: {e}") @@ -276,22 +204,17 @@ _global_generator: EmbeddingGenerator | None = None def get_embedding_generator( use_api: bool = True, - fallback_model_name: str = "paraphrase-multilingual-MiniLM-L12-v2", ) -> EmbeddingGenerator: """ 获取全局嵌入生成器单例 Args: - use_api: 是否优先使用 API - fallback_model_name: 回退本地模型名称 + use_api: 是否使用 API Returns: EmbeddingGenerator 实例 """ global _global_generator if _global_generator is None: - _global_generator = EmbeddingGenerator( - use_api=use_api, - fallback_model_name=fallback_model_name - ) + _global_generator = EmbeddingGenerator(use_api=use_api) return _global_generator