feat: 添加三级记忆系统提示词格式化器,优化记忆块和短期记忆的格式化逻辑

This commit is contained in:
Windpicker-owo
2025-11-19 14:18:54 +08:00
parent f685ea9fa8
commit ea164ca74e
2 changed files with 430 additions and 24 deletions

View File

@@ -305,29 +305,11 @@ class UnifiedMemoryManager:
try:
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.memory_graph.utils.memory_formatter import format_memory_for_prompt
from src.memory_graph.utils.three_tier_formatter import memory_formatter
# 构建提示词 - 使用优化的格式
# 防御性处理:确保 combined_text 是字符串
perceptual_texts = []
for i, block in enumerate(perceptual_blocks):
text = block.combined_text
if isinstance(text, list):
text = " ".join(str(item) for item in text)
elif not isinstance(text, str):
text = str(text)
perceptual_texts.append(f"记忆块{i+1}:\n{text}")
perceptual_desc = "\n\n".join(str(item) for item in perceptual_texts)
# 短期记忆使用 "主体-主题(属性)" 格式
short_term_texts = []
for mem in short_term_memories:
formatted = format_memory_for_prompt(mem, include_metadata=False)
if formatted: # 只添加非空的格式化结果
short_term_texts.append(f"- {formatted}")
short_term_desc = "\n".join(str(item) for item in short_term_texts)
# 使用新的三级记忆格式化器
perceptual_desc = memory_formatter.format_perceptual_memory(perceptual_blocks)
short_term_desc = memory_formatter.format_short_term_memory(short_term_memories)
# 构建聊天历史块(如果提供)
chat_history_block = ""
@@ -342,10 +324,10 @@ class UnifiedMemoryManager:
**用户查询:**
{query}
{chat_history_block}**检索到的感知记忆**
{chat_history_block}**检索到的感知记忆(即时对话,格式:【时间 (聊天流)】消息列表)**
{perceptual_desc or '(无)'}
**检索到的短期记忆(结构化记忆,格式:主体-主题(属性)**
**检索到的短期记忆(结构化信息,自然语言描述**
{short_term_desc or '(无)'}
**任务要求:**

View File

@@ -0,0 +1,424 @@
"""
三级记忆系统提示词格式化器
根据用户需求优化三级记忆的提示词构建格式:
- 感知记忆:【时间 (聊天流名字)】+ 消息块列表
- 短期记忆:自然语言描述
- 长期记忆:[事实] 主体-主题+客体属性1内容 属性2内容
"""
import json
from datetime import datetime
from pathlib import Path
from typing import Any
from src.memory_graph.models import Memory, MemoryBlock, ShortTermMemory
class ThreeTierMemoryFormatter:
"""三级记忆系统提示词格式化器"""
def __init__(self):
"""初始化格式化器"""
pass
def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
"""
格式化感知记忆为提示词
格式:
- 【时间 (聊天流名字)】
xxx: abcd
xxx: aaaa
xxx: dasd
xxx: ddda
xxx: adwd
- 【时间 (聊天流名字)】
xxx: abcd
xxx: aaaa
...
Args:
blocks: 感知记忆块列表
Returns:
格式化后的感知记忆提示词
"""
if not blocks:
return ""
lines = []
for block in blocks:
# 提取时间和聊天流信息
time_str = self._extract_time_from_block(block)
stream_name = self._extract_stream_name_from_block(block)
# 添加块标题
lines.append(f"- 【{time_str} ({stream_name})】")
# 添加消息内容
for message in block.messages:
sender = self._extract_sender_name(message)
content = self._extract_message_content(message)
if content:
lines.append(f"{sender}: {content}")
# 块之间添加空行
lines.append("")
# 移除最后的空行并返回
if lines and lines[-1] == "":
lines.pop()
return "\n".join(lines)
def format_short_term_memory(self, memories: list[ShortTermMemory]) -> str:
"""
格式化短期记忆为提示词
使用自然语言描述的内容
Args:
memories: 短期记忆列表
Returns:
格式化后的短期记忆提示词
"""
if not memories:
return ""
lines = []
for memory in memories:
# 使用content字段作为自然语言描述
if memory.content:
lines.append(f"- {memory.content}")
return "\n".join(lines)
def format_long_term_memory(self, memories: list[Memory]) -> str:
"""
格式化长期记忆为提示词
格式:[事实] 主体-主题+客体属性1内容 属性2内容
Args:
memories: 长期记忆列表
Returns:
格式化后的长期记忆提示词
"""
if not memories:
return ""
lines = []
for memory in memories:
formatted = self._format_single_long_term_memory(memory)
if formatted:
lines.append(f"- {formatted}")
return "\n".join(lines)
def format_all_tiers(
self,
perceptual_blocks: list[MemoryBlock],
short_term_memories: list[ShortTermMemory],
long_term_memories: list[Memory]
) -> str:
"""
格式化所有三级记忆为完整的提示词
Args:
perceptual_blocks: 感知记忆块列表
short_term_memories: 短期记忆列表
long_term_memories: 长期记忆列表
Returns:
完整的三级记忆提示词
"""
sections = []
# 感知记忆
perceptual_text = self.format_perceptual_memory(perceptual_blocks)
if perceptual_text:
sections.append("### 感知记忆(即时对话)")
sections.append(perceptual_text)
sections.append("")
# 短期记忆
short_term_text = self.format_short_term_memory(short_term_memories)
if short_term_text:
sections.append("### 短期记忆(结构化信息)")
sections.append(short_term_text)
sections.append("")
# 长期记忆
long_term_text = self.format_long_term_memory(long_term_memories)
if long_term_text:
sections.append("### 长期记忆(知识图谱)")
sections.append(long_term_text)
sections.append("")
# 移除最后的空行
if sections and sections[-1] == "":
sections.pop()
return "\n".join(sections)
def _extract_time_from_block(self, block: MemoryBlock) -> str:
"""
从记忆块中提取时间信息
Args:
block: 记忆块
Returns:
格式化的时间字符串
"""
# 优先使用创建时间
if block.created_at:
return block.created_at.strftime("%H:%M")
# 如果有消息,尝试从第一条消息提取时间
if block.messages:
first_msg = block.messages[0]
timestamp = first_msg.get("timestamp")
if timestamp:
if isinstance(timestamp, datetime):
return timestamp.strftime("%H:%M")
elif isinstance(timestamp, str):
try:
dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00"))
return dt.strftime("%H:%M")
except:
pass
return "未知时间"
def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
"""
从记忆块中提取聊天流名称
Args:
block: 记忆块
Returns:
聊天流名称
"""
# 尝试从元数据中获取
if block.metadata:
stream_name = block.metadata.get("stream_name") or block.metadata.get("chat_stream")
if stream_name:
return stream_name
# 尝试从消息中提取
if block.messages:
first_msg = block.messages[0]
stream_name = first_msg.get("stream_name") or first_msg.get("chat_stream")
if stream_name:
return stream_name
return "默认聊天"
def _extract_sender_name(self, message: dict[str, Any]) -> str:
"""
从消息中提取发送者名称
Args:
message: 消息字典
Returns:
发送者名称
"""
sender = message.get("sender_name") or message.get("sender") or message.get("user_name")
if sender:
return str(sender)
# 如果没有发送者信息,使用默认值
role = message.get("role", "")
if role == "user":
return "用户"
elif role == "assistant":
return "助手"
else:
return "未知"
def _extract_message_content(self, message: dict[str, Any]) -> str:
"""
从消息中提取内容
Args:
message: 消息字典
Returns:
消息内容
"""
content = message.get("content") or message.get("text") or message.get("message")
if content:
return str(content).strip()
return ""
def _format_single_long_term_memory(self, memory: Memory) -> str:
"""
格式化单个长期记忆
格式:[事实] 主体-主题+客体属性1内容 属性2内容
Args:
memory: 长期记忆对象
Returns:
格式化后的长期记忆
"""
try:
# 获取记忆类型标签
type_label = self._get_memory_type_label(memory.memory_type)
# 获取主体节点
subject_node = memory.get_subject_node()
if not subject_node:
return ""
subject = subject_node.content
# 查找主题节点
topic_node = None
for edge in memory.edges:
edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
if edge_type == "记忆类型" and edge.source_id == memory.subject_id:
topic_node = memory.get_node_by_id(edge.target_id)
break
if not topic_node:
return f"[{type_label}] {subject}"
topic = topic_node.content
# 查找客体和属性
objects = []
attributes = []
for edge in memory.edges:
edge_type = edge.edge_type.value if hasattr(edge.edge_type, 'value') else str(edge.edge_type)
if edge_type == "核心关系" and edge.source_id == topic_node.id:
obj_node = memory.get_node_by_id(edge.target_id)
if obj_node:
if edge.relation and edge.relation != "未知":
objects.append(f"{edge.relation}{obj_node.content}")
else:
objects.append(obj_node.content)
elif edge_type == "属性关系":
attr_node = memory.get_node_by_id(edge.target_id)
if attr_node:
attr_name = edge.relation if edge.relation else "属性"
attributes.append(f"{attr_name}{attr_node.content}")
# 检查节点中的属性
for node in memory.nodes:
if hasattr(node, 'node_type') and str(node.node_type) == "属性":
# 处理 "key=value" 格式的属性
if "=" in node.content:
key, value = node.content.split("=", 1)
attributes.append(f"{key.strip()}{value.strip()}")
else:
attributes.append(f"属性:{node.content}")
# 构建最终格式
result = f"[{type_label}] {subject}-{topic}"
if objects:
result += "-" + "-".join(objects)
if attributes:
result += "" + "".join(attributes) + ""
return result
except Exception as e:
# 如果格式化失败,返回基本描述
return f"[记忆] 格式化失败: {str(e)}"
def _get_memory_type_label(self, memory_type) -> str:
"""
获取记忆类型的中文标签
Args:
memory_type: 记忆类型
Returns:
中文标签
"""
if hasattr(memory_type, 'value'):
type_value = memory_type.value
else:
type_value = str(memory_type)
type_mapping = {
"EVENT": "事件",
"event": "事件",
"事件": "事件",
"FACT": "事实",
"fact": "事实",
"事实": "事实",
"RELATION": "关系",
"relation": "关系",
"关系": "关系",
"OPINION": "观点",
"opinion": "观点",
"观点": "观点",
}
return type_mapping.get(type_value, "事实")
def format_for_context_injection(
self,
query: str,
perceptual_blocks: list[MemoryBlock],
short_term_memories: list[ShortTermMemory],
long_term_memories: list[Memory],
max_perceptual: int = 3,
max_short_term: int = 5,
max_long_term: int = 10
) -> str:
"""
为上下文注入格式化记忆
Args:
query: 用户查询
perceptual_blocks: 感知记忆块列表
short_term_memories: 短期记忆列表
long_term_memories: 长期记忆列表
max_perceptual: 最大感知记忆数量
max_short_term: 最大短期记忆数量
max_long_term: 最大长期记忆数量
Returns:
格式化的上下文
"""
sections = [f"## 用户查询:{query}", ""]
# 限制数量并格式化
limited_perceptual = perceptual_blocks[:max_perceptual]
limited_short_term = short_term_memories[:max_short_term]
limited_long_term = long_term_memories[:max_long_term]
all_tiers_text = self.format_all_tiers(
limited_perceptual,
limited_short_term,
limited_long_term
)
if all_tiers_text:
sections.append("## 相关记忆")
sections.append(all_tiers_text)
return "\n".join(sections)
# 创建全局格式化器实例
memory_formatter = ThreeTierMemoryFormatter()