better:海马体2.0升级,进度 60%,炸了别怪我
This commit is contained in:
@@ -129,9 +129,9 @@ class ChatBot:
|
|||||||
|
|
||||||
# 根据话题计算激活度
|
# 根据话题计算激活度
|
||||||
topic = ""
|
topic = ""
|
||||||
# interested_rate = await HippocampusManager.get_instance().memory_activate_value(message.processed_plain_text) / 100
|
interested_rate = await HippocampusManager.get_instance().get_activate_from_text(message.processed_plain_text)
|
||||||
interested_rate = 0.1
|
# interested_rate = 0.1
|
||||||
logger.debug(f"对{message.processed_plain_text}的激活度:{interested_rate}")
|
logger.info(f"对{message.processed_plain_text}的激活度:{interested_rate}")
|
||||||
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
||||||
|
|
||||||
await self.storage.store_message(message, chat, topic[0] if topic else None)
|
await self.storage.store_message(message, chat, topic[0] if topic else None)
|
||||||
|
|||||||
@@ -80,10 +80,15 @@ class PromptBuilder:
|
|||||||
|
|
||||||
# 调用 hippocampus 的 get_relevant_memories 方法
|
# 调用 hippocampus 的 get_relevant_memories 方法
|
||||||
relevant_memories = await HippocampusManager.get_instance().get_memory_from_text(
|
relevant_memories = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
text=message_txt, num=3, max_depth=2, fast_retrieval=True
|
text=message_txt,
|
||||||
|
max_memory_num=4,
|
||||||
|
max_memory_length=2,
|
||||||
|
max_depth=3,
|
||||||
|
fast_retrieval=False
|
||||||
)
|
)
|
||||||
# memory_str = "\n".join(memory for topic, memories, _ in relevant_memories for memory in memories)
|
|
||||||
memory_str = ""
|
memory_str = ""
|
||||||
|
for topic, memories in relevant_memories:
|
||||||
|
memory_str += f"{memories}\n"
|
||||||
print(f"memory_str: {memory_str}")
|
print(f"memory_str: {memory_str}")
|
||||||
|
|
||||||
if relevant_memories:
|
if relevant_memories:
|
||||||
|
|||||||
@@ -903,7 +903,7 @@ class Hippocampus:
|
|||||||
memories.sort(key=lambda x: x[2], reverse=True)
|
memories.sort(key=lambda x: x[2], reverse=True)
|
||||||
return memories
|
return memories
|
||||||
|
|
||||||
async def get_memory_from_text(self, text: str, num: int = 5, max_depth: int = 3,
|
async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3,
|
||||||
fast_retrieval: bool = False) -> list:
|
fast_retrieval: bool = False) -> list:
|
||||||
"""从文本中提取关键词并获取相关记忆。
|
"""从文本中提取关键词并获取相关记忆。
|
||||||
|
|
||||||
@@ -935,8 +935,8 @@ class Hippocampus:
|
|||||||
keywords = keywords[:5]
|
keywords = keywords[:5]
|
||||||
else:
|
else:
|
||||||
# 使用LLM提取关键词
|
# 使用LLM提取关键词
|
||||||
topic_num = min(5, max(1, int(len(text) * 0.2))) # 根据文本长度动态调整关键词数量
|
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||||
print(f"提取关键词数量: {topic_num}")
|
# logger.info(f"提取关键词数量: {topic_num}")
|
||||||
topics_response = await self.llm_topic_judge.generate_response(
|
topics_response = await self.llm_topic_judge.generate_response(
|
||||||
self.find_topic_llm(text, topic_num)
|
self.find_topic_llm(text, topic_num)
|
||||||
)
|
)
|
||||||
@@ -952,96 +952,276 @@ class Hippocampus:
|
|||||||
if keyword.strip()
|
if keyword.strip()
|
||||||
]
|
]
|
||||||
|
|
||||||
logger.info(f"提取的关键词: {', '.join(keywords)}")
|
# logger.info(f"提取的关键词: {', '.join(keywords)}")
|
||||||
|
|
||||||
|
# 过滤掉不存在于记忆图中的关键词
|
||||||
|
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||||
|
if not valid_keywords:
|
||||||
|
logger.info("没有找到有效的关键词节点")
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||||
|
|
||||||
# 从每个关键词获取记忆
|
# 从每个关键词获取记忆
|
||||||
all_memories = []
|
all_memories = []
|
||||||
keyword_connections = [] # 存储关键词之间的连接关系
|
keyword_connections = [] # 存储关键词之间的连接关系
|
||||||
|
activation_words = set(valid_keywords) # 存储所有激活词(包括关键词和途经点)
|
||||||
|
activate_map = {} # 存储每个词的累计激活值
|
||||||
|
|
||||||
# 检查关键词之间的连接
|
# 对每个关键词进行扩散式检索
|
||||||
for i in range(len(keywords)):
|
for keyword in valid_keywords:
|
||||||
for j in range(i + 1, len(keywords)):
|
logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
|
||||||
keyword1, keyword2 = keywords[i], keywords[j]
|
# 初始化激活值
|
||||||
|
activation_values = {keyword: 1.0}
|
||||||
|
# 记录已访问的节点
|
||||||
|
visited_nodes = {keyword}
|
||||||
|
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
|
||||||
|
nodes_to_process = [(keyword, 1.0, 0)]
|
||||||
|
|
||||||
# 检查节点是否存在于图中
|
while nodes_to_process:
|
||||||
if keyword1 not in self.memory_graph.G or keyword2 not in self.memory_graph.G:
|
current_node, current_activation, current_depth = nodes_to_process.pop(0)
|
||||||
logger.debug(f"关键词 {keyword1} 或 {keyword2} 不在记忆图中")
|
|
||||||
|
# 如果激活值小于0或超过最大深度,停止扩散
|
||||||
|
if current_activation <= 0 or current_depth >= max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查直接连接
|
# 获取当前节点的所有邻居
|
||||||
if self.memory_graph.G.has_edge(keyword1, keyword2):
|
neighbors = list(self.memory_graph.G.neighbors(current_node))
|
||||||
keyword_connections.append((keyword1, keyword2, 1))
|
|
||||||
logger.info(f"发现直接连接: {keyword1} <-> {keyword2} (长度: 1)")
|
for neighbor in neighbors:
|
||||||
|
if neighbor in visited_nodes:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 检查间接连接(通过其他节点)
|
# 获取连接强度
|
||||||
for depth in range(2, max_depth + 1):
|
edge_data = self.memory_graph.G[current_node][neighbor]
|
||||||
# 使用networkx的shortest_path_length检查是否存在指定长度的路径
|
strength = edge_data.get("strength", 1)
|
||||||
try:
|
|
||||||
path_length = nx.shortest_path_length(self.memory_graph.G, keyword1, keyword2)
|
|
||||||
if path_length <= depth:
|
|
||||||
keyword_connections.append((keyword1, keyword2, path_length))
|
|
||||||
logger.info(f"发现间接连接: {keyword1} <-> {keyword2} (长度: {path_length})")
|
|
||||||
# 输出连接路径
|
|
||||||
path = nx.shortest_path(self.memory_graph.G, keyword1, keyword2)
|
|
||||||
logger.info(f"连接路径: {' -> '.join(path)}")
|
|
||||||
break
|
|
||||||
except nx.NetworkXNoPath:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not keyword_connections:
|
# 计算新的激活值
|
||||||
logger.info("未发现任何关键词之间的连接")
|
new_activation = current_activation - (1 / strength)
|
||||||
|
|
||||||
# 记录已处理的关键词连接
|
if new_activation > 0:
|
||||||
processed_connections = set()
|
activation_values[neighbor] = new_activation
|
||||||
|
visited_nodes.add(neighbor)
|
||||||
|
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
|
||||||
|
logger.debug(f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})")
|
||||||
|
|
||||||
# 从每个关键词获取记忆
|
# 更新激活映射
|
||||||
for keyword in keywords:
|
for node, activation_value in activation_values.items():
|
||||||
if keyword in self.memory_graph.G: # 只处理存在于图中的关键词
|
if activation_value > 0:
|
||||||
memories = self.get_memory_from_keyword(keyword, max_depth)
|
if node in activate_map:
|
||||||
all_memories.extend(memories)
|
activate_map[node] += activation_value
|
||||||
|
else:
|
||||||
|
activate_map[node] = activation_value
|
||||||
|
|
||||||
# 处理关键词连接相关的记忆
|
# 输出激活映射
|
||||||
for keyword1, keyword2, path_length in keyword_connections:
|
# logger.info("激活映射统计:")
|
||||||
if (keyword1, keyword2) in processed_connections or (keyword2, keyword1) in processed_connections:
|
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
|
||||||
continue
|
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
|
||||||
|
|
||||||
processed_connections.add((keyword1, keyword2))
|
# 基于激活值平方的独立概率选择
|
||||||
|
remember_map = {}
|
||||||
|
logger.info("基于激活值平方的归一化选择:")
|
||||||
|
|
||||||
# 获取连接路径上的所有节点
|
# 计算所有激活值的平方和
|
||||||
try:
|
total_squared_activation = sum(activation ** 2 for activation in activate_map.values())
|
||||||
path = nx.shortest_path(self.memory_graph.G, keyword1, keyword2)
|
if total_squared_activation > 0:
|
||||||
for node in path:
|
# 计算归一化的激活值
|
||||||
if node not in keywords: # 只处理路径上的非关键词节点
|
normalized_activations = {
|
||||||
|
node: (activation ** 2) / total_squared_activation
|
||||||
|
for node, activation in activate_map.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# 按归一化激活值排序并选择前max_memory_num个
|
||||||
|
sorted_nodes = sorted(
|
||||||
|
normalized_activations.items(),
|
||||||
|
key=lambda x: x[1],
|
||||||
|
reverse=True
|
||||||
|
)[:max_memory_num]
|
||||||
|
|
||||||
|
# 将选中的节点添加到remember_map
|
||||||
|
for node, normalized_activation in sorted_nodes:
|
||||||
|
remember_map[node] = activate_map[node] # 使用原始激活值
|
||||||
|
logger.info(f"节点 '{node}' 被选中 (归一化激活值: {normalized_activation:.2f}, 原始激活值: {activate_map[node]:.2f})")
|
||||||
|
else:
|
||||||
|
logger.info("没有有效的激活值")
|
||||||
|
|
||||||
|
# 从选中的节点中提取记忆
|
||||||
|
all_memories = []
|
||||||
|
logger.info("开始从选中的节点中提取记忆:")
|
||||||
|
for node, activation in remember_map.items():
|
||||||
|
logger.debug(f"处理节点 '{node}' (激活值: {activation:.2f}):")
|
||||||
node_data = self.memory_graph.G.nodes[node]
|
node_data = self.memory_graph.G.nodes[node]
|
||||||
memory_items = node_data.get("memory_items", [])
|
memory_items = node_data.get("memory_items", [])
|
||||||
if not isinstance(memory_items, list):
|
if not isinstance(memory_items, list):
|
||||||
memory_items = [memory_items] if memory_items else []
|
memory_items = [memory_items] if memory_items else []
|
||||||
|
|
||||||
|
if memory_items:
|
||||||
|
logger.debug(f"节点包含 {len(memory_items)} 条记忆")
|
||||||
|
# 计算每条记忆与输入文本的相似度
|
||||||
|
memory_similarities = []
|
||||||
|
for memory in memory_items:
|
||||||
# 计算与输入文本的相似度
|
# 计算与输入文本的相似度
|
||||||
node_words = set(jieba.cut(node))
|
memory_words = set(jieba.cut(memory))
|
||||||
text_words = set(jieba.cut(text))
|
text_words = set(jieba.cut(text))
|
||||||
all_words = node_words | text_words
|
all_words = memory_words | text_words
|
||||||
v1 = [1 if word in node_words else 0 for word in all_words]
|
v1 = [1 if word in memory_words else 0 for word in all_words]
|
||||||
v2 = [1 if word in text_words else 0 for word in all_words]
|
v2 = [1 if word in text_words else 0 for word in all_words]
|
||||||
similarity = cosine_similarity(v1, v2)
|
similarity = cosine_similarity(v1, v2)
|
||||||
|
memory_similarities.append((memory, similarity))
|
||||||
|
|
||||||
if similarity >= 0.3: # 相似度阈值
|
# 按相似度排序
|
||||||
all_memories.append((node, memory_items, similarity))
|
memory_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||||
except nx.NetworkXNoPath:
|
# 获取最匹配的记忆
|
||||||
|
top_memories = memory_similarities[:max_memory_length]
|
||||||
|
|
||||||
|
|
||||||
|
# 添加到结果中
|
||||||
|
for memory, similarity in top_memories:
|
||||||
|
all_memories.append((node, [memory], similarity))
|
||||||
|
logger.info(f"选中记忆: {memory} (相似度: {similarity:.2f})")
|
||||||
|
else:
|
||||||
|
logger.info("节点没有记忆")
|
||||||
|
|
||||||
|
# 去重(基于记忆内容)
|
||||||
|
logger.debug("开始记忆去重:")
|
||||||
|
seen_memories = set()
|
||||||
|
unique_memories = []
|
||||||
|
for topic, memory_items, activation_value in all_memories:
|
||||||
|
memory = memory_items[0] # 因为每个topic只有一条记忆
|
||||||
|
if memory not in seen_memories:
|
||||||
|
seen_memories.add(memory)
|
||||||
|
unique_memories.append((topic, memory_items, activation_value))
|
||||||
|
logger.debug(f"保留记忆: {memory} (来自节点: {topic}, 激活值: {activation_value:.2f})")
|
||||||
|
else:
|
||||||
|
logger.debug(f"跳过重复记忆: {memory} (来自节点: {topic})")
|
||||||
|
|
||||||
|
# 转换为(关键词, 记忆)格式
|
||||||
|
result = []
|
||||||
|
for topic, memory_items, _ in unique_memories:
|
||||||
|
memory = memory_items[0] # 因为每个topic只有一条记忆
|
||||||
|
result.append((topic, memory))
|
||||||
|
logger.info(f"选中记忆: {memory} (来自节点: {topic})")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_activate_from_text(self, text: str, max_depth: int = 3,
|
||||||
|
fast_retrieval: bool = False) -> float:
|
||||||
|
"""从文本中提取关键词并获取相关记忆。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): 输入文本
|
||||||
|
num (int, optional): 需要返回的记忆数量。默认为5。
|
||||||
|
max_depth (int, optional): 记忆检索深度。默认为2。
|
||||||
|
fast_retrieval (bool, optional): 是否使用快速检索。默认为False。
|
||||||
|
如果为True,使用jieba分词和TF-IDF提取关键词,速度更快但可能不够准确。
|
||||||
|
如果为False,使用LLM提取关键词,速度较慢但更准确。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: 激活节点数与总节点数的比值
|
||||||
|
"""
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if fast_retrieval:
|
||||||
|
# 使用jieba分词提取关键词
|
||||||
|
words = jieba.cut(text)
|
||||||
|
# 过滤掉停用词和单字词
|
||||||
|
keywords = [word for word in words if len(word) > 1]
|
||||||
|
# 去重
|
||||||
|
keywords = list(set(keywords))
|
||||||
|
# 限制关键词数量
|
||||||
|
keywords = keywords[:5]
|
||||||
|
else:
|
||||||
|
# 使用LLM提取关键词
|
||||||
|
topic_num = min(5, max(1, int(len(text) * 0.1))) # 根据文本长度动态调整关键词数量
|
||||||
|
# logger.info(f"提取关键词数量: {topic_num}")
|
||||||
|
topics_response = await self.llm_topic_judge.generate_response(
|
||||||
|
self.find_topic_llm(text, topic_num)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取关键词
|
||||||
|
keywords = re.findall(r'<([^>]+)>', topics_response[0])
|
||||||
|
if not keywords:
|
||||||
|
keywords = ['none']
|
||||||
|
else:
|
||||||
|
keywords = [
|
||||||
|
keyword.strip()
|
||||||
|
for keyword in ','.join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
|
||||||
|
if keyword.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
# logger.info(f"提取的关键词: {', '.join(keywords)}")
|
||||||
|
|
||||||
|
# 过滤掉不存在于记忆图中的关键词
|
||||||
|
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||||
|
if not valid_keywords:
|
||||||
|
logger.info("没有找到有效的关键词节点")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
logger.info(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||||
|
|
||||||
|
# 从每个关键词获取记忆
|
||||||
|
keyword_connections = [] # 存储关键词之间的连接关系
|
||||||
|
activation_words = set(valid_keywords) # 存储所有激活词(包括关键词和途经点)
|
||||||
|
activate_map = {} # 存储每个词的累计激活值
|
||||||
|
|
||||||
|
# 对每个关键词进行扩散式检索
|
||||||
|
for keyword in valid_keywords:
|
||||||
|
logger.debug(f"开始以关键词 '{keyword}' 为中心进行扩散检索 (最大深度: {max_depth}):")
|
||||||
|
# 初始化激活值
|
||||||
|
activation_values = {keyword: 1.0}
|
||||||
|
# 记录已访问的节点
|
||||||
|
visited_nodes = {keyword}
|
||||||
|
# 待处理的节点队列,每个元素是(节点, 激活值, 当前深度)
|
||||||
|
nodes_to_process = [(keyword, 1.0, 0)]
|
||||||
|
|
||||||
|
while nodes_to_process:
|
||||||
|
current_node, current_activation, current_depth = nodes_to_process.pop(0)
|
||||||
|
|
||||||
|
# 如果激活值小于0或超过最大深度,停止扩散
|
||||||
|
if current_activation <= 0 or current_depth >= max_depth:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 去重(基于主题)
|
# 获取当前节点的所有邻居
|
||||||
seen_topics = set()
|
neighbors = list(self.memory_graph.G.neighbors(current_node))
|
||||||
unique_memories = []
|
|
||||||
for topic, memory_items, similarity in all_memories:
|
|
||||||
if topic not in seen_topics:
|
|
||||||
seen_topics.add(topic)
|
|
||||||
unique_memories.append((topic, memory_items, similarity))
|
|
||||||
|
|
||||||
# 按相似度排序并返回前num个
|
for neighbor in neighbors:
|
||||||
unique_memories.sort(key=lambda x: x[2], reverse=True)
|
if neighbor in visited_nodes:
|
||||||
return unique_memories[:num]
|
continue
|
||||||
|
|
||||||
|
# 获取连接强度
|
||||||
|
edge_data = self.memory_graph.G[current_node][neighbor]
|
||||||
|
strength = edge_data.get("strength", 1)
|
||||||
|
|
||||||
|
# 计算新的激活值
|
||||||
|
new_activation = current_activation - (1 / strength)
|
||||||
|
|
||||||
|
if new_activation > 0:
|
||||||
|
activation_values[neighbor] = new_activation
|
||||||
|
visited_nodes.add(neighbor)
|
||||||
|
nodes_to_process.append((neighbor, new_activation, current_depth + 1))
|
||||||
|
logger.debug(f"节点 '{neighbor}' 被激活,激活值: {new_activation:.2f} (通过 '{current_node}' 连接,强度: {strength}, 深度: {current_depth + 1})")
|
||||||
|
|
||||||
|
# 更新激活映射
|
||||||
|
for node, activation_value in activation_values.items():
|
||||||
|
if activation_value > 0:
|
||||||
|
if node in activate_map:
|
||||||
|
activate_map[node] += activation_value
|
||||||
|
else:
|
||||||
|
activate_map[node] = activation_value
|
||||||
|
|
||||||
|
# 输出激活映射
|
||||||
|
# logger.info("激活映射统计:")
|
||||||
|
# for node, total_activation in sorted(activate_map.items(), key=lambda x: x[1], reverse=True):
|
||||||
|
# logger.info(f"节点 '{node}': 累计激活值 = {total_activation:.2f}")
|
||||||
|
|
||||||
|
# 计算激活节点数与总节点数的比值
|
||||||
|
total_nodes = len(self.memory_graph.G.nodes())
|
||||||
|
activated_nodes = len(activate_map)
|
||||||
|
activation_ratio = activated_nodes / total_nodes if total_nodes > 0 else 0
|
||||||
|
logger.info(f"激活节点数: {activated_nodes}, 总节点数: {total_nodes}, 激活比例: {activation_ratio:.2%}")
|
||||||
|
|
||||||
|
return activation_ratio
|
||||||
|
|
||||||
class HippocampusManager:
|
class HippocampusManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
@@ -1109,12 +1289,19 @@ class HippocampusManager:
|
|||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
|
return await self._hippocampus.parahippocampal_gyrus.operation_forget_topic(percentage)
|
||||||
|
|
||||||
async def get_memory_from_text(self, text: str, num: int = 5, max_depth: int = 2,
|
async def get_memory_from_text(self, text: str, max_memory_num: int = 3, max_memory_length: int = 2, max_depth: int = 3,
|
||||||
fast_retrieval: bool = False) -> list:
|
fast_retrieval: bool = False) -> list:
|
||||||
"""从文本中获取相关记忆的公共接口"""
|
"""从文本中获取相关记忆的公共接口"""
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
return await self._hippocampus.get_memory_from_text(text, num, max_depth, fast_retrieval)
|
return await self._hippocampus.get_memory_from_text(text, max_memory_num, max_memory_length, max_depth, fast_retrieval)
|
||||||
|
|
||||||
|
async def get_activate_from_text(self, text: str, max_depth: int = 3,
|
||||||
|
fast_retrieval: bool = False) -> float:
|
||||||
|
"""从文本中获取激活值的公共接口"""
|
||||||
|
if not self._initialized:
|
||||||
|
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
|
||||||
|
return await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
|
||||||
|
|
||||||
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
|
||||||
"""从关键词获取相关记忆的公共接口"""
|
"""从关键词获取相关记忆的公共接口"""
|
||||||
|
|||||||
@@ -42,21 +42,22 @@ async def test_memory_system():
|
|||||||
[03-24 10:46:49] ❦幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复?大佬们'''
|
[03-24 10:46:49] ❦幻凌慌てない(ta的id:2459587037): 为什么改了回复系数麦麦还是不怎么回复?大佬们'''
|
||||||
|
|
||||||
|
|
||||||
test_text = '''千石可乐:niko分不清AI的陪伴和人类的陪伴,是这样吗?'''
|
# test_text = '''千石可乐:分不清AI的陪伴和人类的陪伴,是这样吗?'''
|
||||||
print(f"开始测试记忆检索,测试文本: {test_text}\n")
|
print(f"开始测试记忆检索,测试文本: {test_text}\n")
|
||||||
memories = await hippocampus_manager.get_memory_from_text(
|
memories = await hippocampus_manager.get_memory_from_text(
|
||||||
text=test_text,
|
text=test_text,
|
||||||
num=3,
|
max_memory_num=3,
|
||||||
|
max_memory_length=2,
|
||||||
max_depth=3,
|
max_depth=3,
|
||||||
fast_retrieval=False
|
fast_retrieval=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
print("检索到的记忆:")
|
print("检索到的记忆:")
|
||||||
for topic, memory_items, similarity in memories:
|
for topic, memory_items in memories:
|
||||||
print(f"主题: {topic}")
|
print(f"主题: {topic}")
|
||||||
print(f"相似度: {similarity:.2f}")
|
print(f"- {memory_items}")
|
||||||
for memory in memory_items:
|
|
||||||
print(f"- {memory}")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user