diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index baa52b979..05a123130 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -624,18 +624,19 @@ class Hippocampus: max_memory_length: int = 2, max_depth: int = 3, ) -> list: - """从文本中提取关键词并获取相关记忆。 + """从关键词列表中获取相关记忆。 Args: - keywords (list): 输入文本 - max_memory_num (int, optional): 返回的记忆条目数量上限。默认为3,表示最多返回3条与输入文本相关度最高的记忆。 + keywords (list): 输入关键词列表 + max_memory_num (int, optional): 返回的记忆条目数量上限。默认为3,表示最多返回3条与输入关键词相关度最高的记忆。 max_memory_length (int, optional): 每个主题最多返回的记忆条目数量。默认为2,表示每个主题最多返回2条相似度最高的记忆。 max_depth (int, optional): 记忆检索深度。默认为3。值越大,检索范围越广,可以获取更多间接相关的记忆,但速度会变慢。 Returns: list: 记忆列表,每个元素是一个元组 (topic, memory_content) - topic: str, 记忆主题 - - memory_content: str, 该主题下的完整记忆内容 + - memory_items: list, 该主题下的记忆项列表 + - similarity: float, 与关键词的相似度 """ if not keywords: return [] @@ -734,19 +735,27 @@ class Hippocampus: # 直接使用完整的记忆内容 if memory_items: logger.debug(f"节点包含 {len(memory_items)} 条记忆") - # 计算每条记忆与输入文本的相似度 + # 计算每条记忆与输入关键词的相似度 memory_similarities = [] for memory in memory_items: - # 计算与输入文本的相似度 + # 计算与输入关键词的相似度 memory_words = set(jieba.cut(memory)) - text_words = set(jieba.cut(text)) - all_words = memory_words | text_words + # 将所有关键词合并成一个字符串来计算相似度 + keywords_text = " ".join(valid_keywords) + keywords_words = set(jieba.cut(keywords_text)) + all_words = memory_words | keywords_words v1 = [1 if word in memory_words else 0 for word in all_words] - v2 = [1 if word in text_words else 0 for word in all_words] - _ = cosine_similarity(v1, v2) # 计算但不使用,用_表示 - - # 添加完整记忆到结果中 - all_memories.append((node, memory_items, activation)) + v2 = [1 if word in keywords_words else 0 for word in all_words] + similarity = cosine_similarity(v1, v2) + memory_similarities.append((memory, similarity)) + + # 按相似度排序 + memory_similarities.sort(key=lambda x: x[1], reverse=True) + # 获取最匹配的记忆 + top_memories = memory_similarities[:max_memory_length] + + # 添加到结果中 + all_memories.extend((node, [memory], similarity) for memory, similarity in top_memories) else: logger.info("节点没有记忆") @@ -1563,7 +1572,7 @@ class ParahippocampalGyrus: edge_check_start = time.time() for source, target in edges_to_check: edge_data = self.memory_graph.G[source][target] - last_modified = edge_data.get("last_modified") + last_modified = edge_data.get("last_modified", current_time) if current_time - last_modified > 3600 * global_config.memory.memory_forget_time: current_strength = edge_data.get("strength", 1) @@ -1804,6 +1813,7 @@ class HippocampusManager: """获取所有节点名称的公共接口""" if not self._initialized: raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法") + return self._hippocampus.get_all_node_names() # 创建全局实例 hippocampus_manager = HippocampusManager()