From bd4e36b1cf678dffd02b7a2a21a1d7d5d4eed561 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 20 Nov 2025 18:06:23 +0800 Subject: [PATCH] =?UTF-8?q?feat(replyer):=20=E6=B7=BB=E5=8A=A0=E6=9C=80?= =?UTF-8?q?=E8=BF=91=E6=B6=88=E6=81=AF=E6=94=AF=E6=8C=81=E4=BB=A5=E6=9E=84?= =?UTF-8?q?=E5=BB=BA=E8=AE=B0=E5=BF=86=E5=9D=97=E5=92=8C=E6=9F=A5=E8=AF=A2?= =?UTF-8?q?=E6=96=87=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/replyer/default_generator.py | 74 ++++++++++++++++++- .../utils/three_tier_formatter.py | 34 ++++++--- 2 files changed, 94 insertions(+), 14 deletions(-) diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 92ec77391..a760e6025 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -566,12 +566,18 @@ class DefaultReplyer: return f"{expression_habits_title}\n{expression_habits_block}" - async def build_memory_block(self, chat_history: str, target: str) -> str: + async def build_memory_block( + self, + chat_history: str, + target: str, + recent_messages: list[dict[str, Any]] | None = None, + ) -> str: """构建记忆块(使用三层记忆系统) Args: chat_history: 聊天历史记录 target: 目标消息内容 + recent_messages: 原始聊天消息列表(用于构建查询块) Returns: str: 记忆信息字符串 @@ -589,9 +595,12 @@ class DefaultReplyer: logger.debug("[三层记忆] 管理器未初始化") return "" + # 目标查询改为使用最近多条消息的组合块 + query_text = self._build_memory_query_text(target, recent_messages) + # 使用统一管理器的智能检索(Judge模型决策) search_result = await unified_manager.search_memories( - query_text=target, + query_text=query_text, use_judge=True, recent_chat_history=chat_history, # 传递最近聊天历史 ) @@ -629,6 +638,62 @@ class DefaultReplyer: logger.error(f"[三层记忆] 检索失败: {e}", exc_info=True) return "" + def _build_memory_query_text( + self, + fallback_text: str, + recent_messages: list[dict[str, Any]] | None, + block_size: int = 5, + ) -> str: + """ + 将最近若干条消息拼接为一个查询块,用于生成语义向量。 + + Args: + fallback_text: 如果无法拼接消息块时使用的后备文本 + recent_messages: 最近的消息列表 + block_size: 组合的消息数量 + + Returns: + str: 用于检索的查询文本 + """ + if not recent_messages: + return fallback_text + + lines: list[str] = [] + for message in recent_messages[-block_size:]: + sender = ( + message.get("sender_name") + or message.get("person_name") + or message.get("user_nickname") + or message.get("user_cardname") + or message.get("nickname") + or message.get("sender") + ) + + if not sender and isinstance(message.get("user_info"), dict): + user_info = message["user_info"] + sender = user_info.get("user_nickname") or user_info.get("user_cardname") + + sender = sender or message.get("user_id") or "未知" + + content = ( + message.get("processed_plain_text") + or message.get("display_message") + or message.get("content") + or message.get("message") + or message.get("text") + or "" + ) + + content = str(content).strip() + if content: + lines.append(f"{sender}: {content}") + + fallback_clean = fallback_text.strip() + if not lines: + return fallback_clean or fallback_text + + return "\n".join(lines[-block_size:]) + async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str: @@ -1251,7 +1316,10 @@ class DefaultReplyer: self._time_and_run_task(self.build_relation_info(sender, target), "relation_info") ), "memory_block": asyncio.create_task( - self._time_and_run_task(self.build_memory_block(chat_talking_prompt_short, target), "memory_block") + self._time_and_run_task( + self.build_memory_block(chat_talking_prompt_short, target, message_list_before_short), + "memory_block", + ) ), "tool_info": asyncio.create_task( self._time_and_run_task( diff --git a/src/memory_graph/utils/three_tier_formatter.py b/src/memory_graph/utils/three_tier_formatter.py index 934eec947..551278c81 100644 --- a/src/memory_graph/utils/three_tier_formatter.py +++ b/src/memory_graph/utils/three_tier_formatter.py @@ -312,7 +312,8 @@ class ThreeTierMemoryFormatter: # 查找客体和属性 objects = [] - attributes = {} + attributes: dict[str, str] = {} + attribute_names: dict[str, str] = {} for edge in memory.edges: edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type) @@ -320,17 +321,28 @@ class ThreeTierMemoryFormatter: if edge_type == "核心关系" and edge.source_id == topic_node.id: obj_node = memory.get_node_by_id(edge.target_id) if obj_node: - if edge.relation and edge.relation != "未知": - objects.append(f"{edge.relation}{obj_node.content}") + relation_label = (edge.relation or "").strip() + obj_text = obj_node.content + if relation_label and relation_label not in {"未知", "核心关系"}: + objects.append(f"{relation_label}:{obj_text}") else: - objects.append(obj_node.content) + objects.append(obj_text) elif edge_type == "属性关系": attr_node = memory.get_node_by_id(edge.target_id) - if attr_node: - attr_name = edge.relation if edge.relation else "属性" - # 使用字典避免重复属性,后面的会覆盖前面的 - attributes[attr_name] = attr_node.content + if not attr_node: + continue + + if edge.source_id == topic_node.id: + # 记录属性节点的名称,稍后匹配对应的值节点 + attribute_names[attr_node.id] = attr_node.content + continue + + attr_name = attribute_names.get(edge.source_id) + if not attr_name: + attr_name = edge.relation.strip() if edge.relation else "属性" + + attributes[attr_name] = attr_node.content # 检查节点中的属性(处理 "key=value" 格式) for node in memory.nodes: @@ -338,9 +350,9 @@ class ThreeTierMemoryFormatter: # 处理 "key=value" 格式的属性 if "=" in node.content: key, value = node.content.split("=", 1) - attributes[key.strip()] = value.strip() + attributes.setdefault(key.strip(), value.strip()) else: - attributes["属性"] = node.content + attributes.setdefault("属性", node.content) # 构建最终格式 result = f"[{type_label}] {subject}-{topic}" @@ -437,4 +449,4 @@ class ThreeTierMemoryFormatter: # 创建全局格式化器实例 -memory_formatter = ThreeTierMemoryFormatter() \ No newline at end of file +memory_formatter = ThreeTierMemoryFormatter()