feat: 添加三级记忆系统提示词格式化器,优化记忆块和短期记忆的格式化逻辑
This commit is contained in:
@@ -305,29 +305,11 @@ class UnifiedMemoryManager:
|
||||
try:
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.memory_graph.utils.memory_formatter import format_memory_for_prompt
|
||||
from src.memory_graph.utils.three_tier_formatter import memory_formatter
|
||||
|
||||
# 构建提示词 - 使用优化的格式
|
||||
# 防御性处理:确保 combined_text 是字符串
|
||||
perceptual_texts = []
|
||||
for i, block in enumerate(perceptual_blocks):
|
||||
text = block.combined_text
|
||||
if isinstance(text, list):
|
||||
text = " ".join(str(item) for item in text)
|
||||
elif not isinstance(text, str):
|
||||
text = str(text)
|
||||
perceptual_texts.append(f"记忆块{i+1}:\n{text}")
|
||||
|
||||
perceptual_desc = "\n\n".join(str(item) for item in perceptual_texts)
|
||||
|
||||
# 短期记忆使用 "主体-主题(属性)" 格式
|
||||
short_term_texts = []
|
||||
for mem in short_term_memories:
|
||||
formatted = format_memory_for_prompt(mem, include_metadata=False)
|
||||
if formatted: # 只添加非空的格式化结果
|
||||
short_term_texts.append(f"- {formatted}")
|
||||
|
||||
short_term_desc = "\n".join(str(item) for item in short_term_texts)
|
||||
# 使用新的三级记忆格式化器
|
||||
perceptual_desc = memory_formatter.format_perceptual_memory(perceptual_blocks)
|
||||
short_term_desc = memory_formatter.format_short_term_memory(short_term_memories)
|
||||
|
||||
# 构建聊天历史块(如果提供)
|
||||
chat_history_block = ""
|
||||
@@ -342,10 +324,10 @@ class UnifiedMemoryManager:
|
||||
**用户查询:**
|
||||
{query}
|
||||
|
||||
{chat_history_block}**检索到的感知记忆块:**
|
||||
{chat_history_block}**检索到的感知记忆(即时对话,格式:【时间 (聊天流)】消息列表):**
|
||||
{perceptual_desc or '(无)'}
|
||||
|
||||
**检索到的短期记忆(结构化记忆,格式:主体-主题(属性)):**
|
||||
**检索到的短期记忆(结构化信息,自然语言描述):**
|
||||
{short_term_desc or '(无)'}
|
||||
|
||||
**任务要求:**
|
||||
|
||||
424
src/memory_graph/utils/three_tier_formatter.py
Normal file
424
src/memory_graph/utils/three_tier_formatter.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
三级记忆系统提示词格式化器
|
||||
|
||||
根据用户需求优化三级记忆的提示词构建格式:
|
||||
- 感知记忆:【时间 (聊天流名字)】+ 消息块列表
|
||||
- 短期记忆:自然语言描述
|
||||
- 长期记忆:[事实] 主体-主题+客体(属性1:内容, 属性2:内容)
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.memory_graph.models import Memory, MemoryBlock, ShortTermMemory
|
||||
|
||||
|
||||
class ThreeTierMemoryFormatter:
|
||||
"""三级记忆系统提示词格式化器"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化格式化器"""
|
||||
pass
|
||||
|
||||
def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
|
||||
"""
|
||||
格式化感知记忆为提示词
|
||||
|
||||
格式:
|
||||
- 【时间 (聊天流名字)】
|
||||
xxx: abcd
|
||||
xxx: aaaa
|
||||
xxx: dasd
|
||||
xxx: ddda
|
||||
xxx: adwd
|
||||
|
||||
- 【时间 (聊天流名字)】
|
||||
xxx: abcd
|
||||
xxx: aaaa
|
||||
...
|
||||
|
||||
Args:
|
||||
blocks: 感知记忆块列表
|
||||
|
||||
Returns:
|
||||
格式化后的感知记忆提示词
|
||||
"""
|
||||
if not blocks:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
|
||||
for block in blocks:
|
||||
# 提取时间和聊天流信息
|
||||
time_str = self._extract_time_from_block(block)
|
||||
stream_name = self._extract_stream_name_from_block(block)
|
||||
|
||||
# 添加块标题
|
||||
lines.append(f"- 【{time_str} ({stream_name})】")
|
||||
|
||||
# 添加消息内容
|
||||
for message in block.messages:
|
||||
sender = self._extract_sender_name(message)
|
||||
content = self._extract_message_content(message)
|
||||
if content:
|
||||
lines.append(f"{sender}: {content}")
|
||||
|
||||
# 块之间添加空行
|
||||
lines.append("")
|
||||
|
||||
# 移除最后的空行并返回
|
||||
if lines and lines[-1] == "":
|
||||
lines.pop()
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def format_short_term_memory(self, memories: list[ShortTermMemory]) -> str:
|
||||
"""
|
||||
格式化短期记忆为提示词
|
||||
|
||||
使用自然语言描述的内容
|
||||
|
||||
Args:
|
||||
memories: 短期记忆列表
|
||||
|
||||
Returns:
|
||||
格式化后的短期记忆提示词
|
||||
"""
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
|
||||
for memory in memories:
|
||||
# 使用content字段作为自然语言描述
|
||||
if memory.content:
|
||||
lines.append(f"- {memory.content}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def format_long_term_memory(self, memories: list[Memory]) -> str:
|
||||
"""
|
||||
格式化长期记忆为提示词
|
||||
|
||||
格式:[事实] 主体-主题+客体(属性1:内容, 属性2:内容)
|
||||
|
||||
Args:
|
||||
memories: 长期记忆列表
|
||||
|
||||
Returns:
|
||||
格式化后的长期记忆提示词
|
||||
"""
|
||||
if not memories:
|
||||
return ""
|
||||
|
||||
lines = []
|
||||
|
||||
for memory in memories:
|
||||
formatted = self._format_single_long_term_memory(memory)
|
||||
if formatted:
|
||||
lines.append(f"- {formatted}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def format_all_tiers(
|
||||
self,
|
||||
perceptual_blocks: list[MemoryBlock],
|
||||
short_term_memories: list[ShortTermMemory],
|
||||
long_term_memories: list[Memory]
|
||||
) -> str:
|
||||
"""
|
||||
格式化所有三级记忆为完整的提示词
|
||||
|
||||
Args:
|
||||
perceptual_blocks: 感知记忆块列表
|
||||
short_term_memories: 短期记忆列表
|
||||
long_term_memories: 长期记忆列表
|
||||
|
||||
Returns:
|
||||
完整的三级记忆提示词
|
||||
"""
|
||||
sections = []
|
||||
|
||||
# 感知记忆
|
||||
perceptual_text = self.format_perceptual_memory(perceptual_blocks)
|
||||
if perceptual_text:
|
||||
sections.append("### 感知记忆(即时对话)")
|
||||
sections.append(perceptual_text)
|
||||
sections.append("")
|
||||
|
||||
# 短期记忆
|
||||
short_term_text = self.format_short_term_memory(short_term_memories)
|
||||
if short_term_text:
|
||||
sections.append("### 短期记忆(结构化信息)")
|
||||
sections.append(short_term_text)
|
||||
sections.append("")
|
||||
|
||||
# 长期记忆
|
||||
long_term_text = self.format_long_term_memory(long_term_memories)
|
||||
if long_term_text:
|
||||
sections.append("### 长期记忆(知识图谱)")
|
||||
sections.append(long_term_text)
|
||||
sections.append("")
|
||||
|
||||
# 移除最后的空行
|
||||
if sections and sections[-1] == "":
|
||||
sections.pop()
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
def _extract_time_from_block(self, block: MemoryBlock) -> str:
|
||||
"""
|
||||
从记忆块中提取时间信息
|
||||
|
||||
Args:
|
||||
block: 记忆块
|
||||
|
||||
Returns:
|
||||
格式化的时间字符串
|
||||
"""
|
||||
# 优先使用创建时间
|
||||
if block.created_at:
|
||||
return block.created_at.strftime("%H:%M")
|
||||
|
||||
# 如果有消息,尝试从第一条消息提取时间
|
||||
if block.messages:
|
||||
first_msg = block.messages[0]
|
||||
timestamp = first_msg.get("timestamp")
|
||||
if timestamp:
|
||||
if isinstance(timestamp, datetime):
|
||||
return timestamp.strftime("%H:%M")
|
||||
elif isinstance(timestamp, str):
|
||||
try:
|
||||
dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
|
||||
return dt.strftime("%H:%M")
|
||||
except:
|
||||
pass
|
||||
|
||||
return "未知时间"
|
||||
|
||||
def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
|
||||
"""
|
||||
从记忆块中提取聊天流名称
|
||||
|
||||
Args:
|
||||
block: 记忆块
|
||||
|
||||
Returns:
|
||||
聊天流名称
|
||||
"""
|
||||
# 尝试从元数据中获取
|
||||
if block.metadata:
|
||||
stream_name = block.metadata.get("stream_name") or block.metadata.get("chat_stream")
|
||||
if stream_name:
|
||||
return stream_name
|
||||
|
||||
# 尝试从消息中提取
|
||||
if block.messages:
|
||||
first_msg = block.messages[0]
|
||||
stream_name = first_msg.get("stream_name") or first_msg.get("chat_stream")
|
||||
if stream_name:
|
||||
return stream_name
|
||||
|
||||
return "默认聊天"
|
||||
|
||||
def _extract_sender_name(self, message: dict[str, Any]) -> str:
|
||||
"""
|
||||
从消息中提取发送者名称
|
||||
|
||||
Args:
|
||||
message: 消息字典
|
||||
|
||||
Returns:
|
||||
发送者名称
|
||||
"""
|
||||
sender = message.get("sender_name") or message.get("sender") or message.get("user_name")
|
||||
if sender:
|
||||
return str(sender)
|
||||
|
||||
# 如果没有发送者信息,使用默认值
|
||||
role = message.get("role", "")
|
||||
if role == "user":
|
||||
return "用户"
|
||||
elif role == "assistant":
|
||||
return "助手"
|
||||
else:
|
||||
return "未知"
|
||||
|
||||
def _extract_message_content(self, message: dict[str, Any]) -> str:
|
||||
"""
|
||||
从消息中提取内容
|
||||
|
||||
Args:
|
||||
message: 消息字典
|
||||
|
||||
Returns:
|
||||
消息内容
|
||||
"""
|
||||
content = message.get("content") or message.get("text") or message.get("message")
|
||||
if content:
|
||||
return str(content).strip()
|
||||
return ""
|
||||
|
||||
def _format_single_long_term_memory(self, memory: Memory) -> str:
|
||||
"""
|
||||
格式化单个长期记忆
|
||||
|
||||
格式:[事实] 主体-主题+客体(属性1:内容, 属性2:内容)
|
||||
|
||||
Args:
|
||||
memory: 长期记忆对象
|
||||
|
||||
Returns:
|
||||
格式化后的长期记忆
|
||||
"""
|
||||
try:
|
||||
# 获取记忆类型标签
|
||||
type_label = self._get_memory_type_label(memory.memory_type)
|
||||
|
||||
# 获取主体节点
|
||||
subject_node = memory.get_subject_node()
|
||||
if not subject_node:
|
||||
return ""
|
||||
|
||||
subject = subject_node.content
|
||||
|
||||
# 查找主题节点
|
||||
topic_node = None
|
||||
for edge in memory.edges:
|
||||
edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
|
||||
if edge_type == "记忆类型" and edge.source_id == memory.subject_id:
|
||||
topic_node = memory.get_node_by_id(edge.target_id)
|
||||
break
|
||||
|
||||
if not topic_node:
|
||||
return f"[{type_label}] {subject}"
|
||||
|
||||
topic = topic_node.content
|
||||
|
||||
# 查找客体和属性
|
||||
objects = []
|
||||
attributes = []
|
||||
|
||||
for edge in memory.edges:
|
||||
edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
|
||||
|
||||
if edge_type == "核心关系" and edge.source_id == topic_node.id:
|
||||
obj_node = memory.get_node_by_id(edge.target_id)
|
||||
if obj_node:
|
||||
if edge.relation and edge.relation != "未知":
|
||||
objects.append(f"{edge.relation}{obj_node.content}")
|
||||
else:
|
||||
objects.append(obj_node.content)
|
||||
|
||||
elif edge_type == "属性关系":
|
||||
attr_node = memory.get_node_by_id(edge.target_id)
|
||||
if attr_node:
|
||||
attr_name = edge.relation if edge.relation else "属性"
|
||||
attributes.append(f"{attr_name}:{attr_node.content}")
|
||||
|
||||
# 检查节点中的属性
|
||||
for node in memory.nodes:
|
||||
if hasattr(node, 'node_type') and str(node.node_type) == "属性":
|
||||
# 处理 "key=value" 格式的属性
|
||||
if "=" in node.content:
|
||||
key, value = node.content.split("=", 1)
|
||||
attributes.append(f"{key.strip()}:{value.strip()}")
|
||||
else:
|
||||
attributes.append(f"属性:{node.content}")
|
||||
|
||||
# 构建最终格式
|
||||
result = f"[{type_label}] {subject}-{topic}"
|
||||
|
||||
if objects:
|
||||
result += "-" + "-".join(objects)
|
||||
|
||||
if attributes:
|
||||
result += "(" + ",".join(attributes) + ")"
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# 如果格式化失败,返回基本描述
|
||||
return f"[记忆] 格式化失败: {str(e)}"
|
||||
|
||||
def _get_memory_type_label(self, memory_type) -> str:
|
||||
"""
|
||||
获取记忆类型的中文标签
|
||||
|
||||
Args:
|
||||
memory_type: 记忆类型
|
||||
|
||||
Returns:
|
||||
中文标签
|
||||
"""
|
||||
if hasattr(memory_type, 'value'):
|
||||
type_value = memory_type.value
|
||||
else:
|
||||
type_value = str(memory_type)
|
||||
|
||||
type_mapping = {
|
||||
"EVENT": "事件",
|
||||
"event": "事件",
|
||||
"事件": "事件",
|
||||
"FACT": "事实",
|
||||
"fact": "事实",
|
||||
"事实": "事实",
|
||||
"RELATION": "关系",
|
||||
"relation": "关系",
|
||||
"关系": "关系",
|
||||
"OPINION": "观点",
|
||||
"opinion": "观点",
|
||||
"观点": "观点",
|
||||
}
|
||||
|
||||
return type_mapping.get(type_value, "事实")
|
||||
|
||||
def format_for_context_injection(
|
||||
self,
|
||||
query: str,
|
||||
perceptual_blocks: list[MemoryBlock],
|
||||
short_term_memories: list[ShortTermMemory],
|
||||
long_term_memories: list[Memory],
|
||||
max_perceptual: int = 3,
|
||||
max_short_term: int = 5,
|
||||
max_long_term: int = 10
|
||||
) -> str:
|
||||
"""
|
||||
为上下文注入格式化记忆
|
||||
|
||||
Args:
|
||||
query: 用户查询
|
||||
perceptual_blocks: 感知记忆块列表
|
||||
short_term_memories: 短期记忆列表
|
||||
long_term_memories: 长期记忆列表
|
||||
max_perceptual: 最大感知记忆数量
|
||||
max_short_term: 最大短期记忆数量
|
||||
max_long_term: 最大长期记忆数量
|
||||
|
||||
Returns:
|
||||
格式化的上下文
|
||||
"""
|
||||
sections = [f"## 用户查询:{query}", ""]
|
||||
|
||||
# 限制数量并格式化
|
||||
limited_perceptual = perceptual_blocks[:max_perceptual]
|
||||
limited_short_term = short_term_memories[:max_short_term]
|
||||
limited_long_term = long_term_memories[:max_long_term]
|
||||
|
||||
all_tiers_text = self.format_all_tiers(
|
||||
limited_perceptual,
|
||||
limited_short_term,
|
||||
limited_long_term
|
||||
)
|
||||
|
||||
if all_tiers_text:
|
||||
sections.append("## 相关记忆")
|
||||
sections.append(all_tiers_text)
|
||||
|
||||
return "\n".join(sections)
|
||||
|
||||
|
||||
# 创建全局格式化器实例
|
||||
memory_formatter = ThreeTierMemoryFormatter()
|
||||
Reference in New Issue
Block a user