feat: 批量生成文本embedding,优化兴趣匹配计算逻辑,支持消息兴趣值的批量更新

This commit is contained in:
Windpicker-owo
2025-11-19 16:30:44 +08:00
parent a11d251ec1
commit 14133410e6
15 changed files with 231 additions and 323 deletions

View File

@@ -308,7 +308,7 @@ class UnifiedMemoryManager:
from src.memory_graph.utils.three_tier_formatter import memory_formatter
# 使用新的三级记忆格式化器
perceptual_desc = memory_formatter.format_perceptual_memory(perceptual_blocks)
perceptual_desc = await memory_formatter.format_perceptual_memory(perceptual_blocks)
short_term_desc = memory_formatter.format_short_term_memory(short_term_memories)
# 构建聊天历史块(如果提供)

View File

@@ -1,234 +0,0 @@
"""
记忆格式化工具
提供将记忆对象格式化为提示词的功能,使用 "主体-主题(属性)" 格式。
"""
from src.memory_graph.models import Memory, MemoryNode, NodeType
from src.memory_graph.models import ShortTermMemory
def get_memory_type_label(memory_type: str) -> str:
    """
    Return the Chinese display label for a memory type.

    English type names are matched case-insensitively. The known Chinese
    labels map to themselves, and any value that is not in the table is
    returned exactly as it was passed in.

    Args:
        memory_type: Memory type name (English or Chinese).

    Returns:
        The Chinese label, or ``memory_type`` unchanged when it is unknown.
    """
    chinese_labels = ("事实", "事件", "观点", "关系", "目标", "计划")
    english_labels = {
        "fact": "事实",
        "event": "事件",
        "opinion": "观点",
        "relation": "关系",
        "goal": "目标",
        "plan": "计划",
        "unknown": "未知",
    }

    # Chinese labels are their own translation; English names are added on top.
    labels = {label: label for label in chinese_labels}
    labels.update(english_labels)

    lookup_key = memory_type.lower()
    if lookup_key in labels:
        return labels[lookup_key]
    # Unknown types fall through untouched (original casing preserved).
    return memory_type
def format_memory_for_prompt(memory: Memory | ShortTermMemory, include_metadata: bool = True) -> str:
    """
    Render a single memory object as prompt text.

    The output follows the "subject-topic(attributes)" layout, for example:
    - "张三-职业(程序员, 公司=MoFox)"
    - "小明-喜欢(Python, 原因=简洁优雅)"
    - "拾风-地址(https://mofox.com)"

    Args:
        memory: A Memory or ShortTermMemory object.
        include_metadata: Whether metadata should be included; forwarded
            unchanged to the type-specific formatter.

    Returns:
        The formatted memory text. An object that is neither a
        ShortTermMemory nor a Memory is rendered with ``str()``.
    """
    # Order matters: ShortTermMemory is checked before Memory.
    formatters = (
        (ShortTermMemory, _format_short_term_memory),
        (Memory, _format_long_term_memory),
    )
    for memory_cls, formatter in formatters:
        if isinstance(memory, memory_cls):
            return formatter(memory, include_metadata)
    return str(memory)
def _format_short_term_memory(mem: ShortTermMemory, include_metadata: bool) -> str:
    """
    Format a short-term memory as "subject-topic(object, key=value, ...)".

    When the memory has neither a subject nor a topic, its raw ``content`` is
    returned instead (list content is joined with spaces); in that case no
    object, attributes or metadata are appended.

    Args:
        mem: ShortTermMemory object.
        include_metadata: Whether to append a "[类型:..., 重要性:...]" suffix.

    Returns:
        The formatted text.
    """
    subject = mem.subject or ""
    topic = mem.topic or ""
    obj = mem.object or ""

    # Build the base part: "subject-topic", or whichever of the two exists.
    if subject and topic:
        base = f"{subject}-{topic}"
    elif subject:
        base = subject
    elif topic:
        base = topic
    else:
        # No structured fields: fall back to the raw content.
        # Defensive: content may be a list rather than a string.
        if isinstance(mem.content, list):
            return " ".join(str(item) for item in mem.content)
        return str(mem.content) if mem.content else ""

    # Collect the object followed by every attribute with a truthy value.
    attr_parts = []
    if obj:
        # Coerce to str so a non-string object cannot break the join below.
        attr_parts.append(str(obj))
    if mem.attributes:
        attr_parts.extend(f"{key}={value}" for key, value in mem.attributes.items() if value)

    result = f"{base}({', '.join(attr_parts)})" if attr_parts else base

    # Optional metadata suffix.
    if include_metadata:
        metadata_parts = []
        if mem.memory_type:
            metadata_parts.append(f"类型:{get_memory_type_label(mem.memory_type)}")
        if mem.importance > 0:
            metadata_parts.append(f"重要性:{mem.importance:.2f}")
        if metadata_parts:
            result = f"{result} [{', '.join(metadata_parts)}]"
    return result
def _format_long_term_memory(mem: Memory, include_metadata: bool) -> str:
    """
    Format a long-term memory (Memory object) as "subject-topic(attributes)".

    Fallbacks:
    - no subject node: returns ``mem.to_text()``;
    - no topic node reachable from the subject: returns the subject content.

    Args:
        mem: Memory object.
        include_metadata: Whether to append a "[类型:..., 重要性:...]" suffix.

    Returns:
        The formatted text.
    """
    subject_node = mem.get_subject_node()
    if not subject_node:
        return mem.to_text()
    subject = subject_node.content

    # The topic is the target of the first "记忆类型" edge leaving the subject.
    topic_node = None
    for edge in mem.edges:
        if _edge_type_name(edge) == "记忆类型" and edge.source_id == mem.subject_id:
            topic_node = mem.get_node_by_id(edge.target_id)
            break
    if not topic_node:
        return subject

    base = f"{subject}-{topic_node.content}"

    # Collect objects (targets of "核心关系" edges leaving the topic) and attributes.
    attr_parts = []
    for edge in mem.edges:
        if _edge_type_name(edge) != "核心关系" or edge.source_id != topic_node.id:
            continue
        obj_node = mem.get_node_by_id(edge.target_id)
        if not obj_node:
            continue
        # Prefix with the relation name when one is set and it is not "未知".
        if edge.relation and edge.relation != "未知":
            attr_parts.append(f"{edge.relation}={obj_node.content}")
        else:
            attr_parts.append(obj_node.content)

    # Attribute node content is appended as-is (may be "key=value" or "value").
    attr_parts.extend(node.content for node in mem.nodes if node.node_type == NodeType.ATTRIBUTE)

    result = f"{base}({', '.join(attr_parts)})" if attr_parts else base

    # Optional metadata suffix.
    if include_metadata:
        metadata_parts = []
        if mem.memory_type:
            type_value = mem.memory_type.value if hasattr(mem.memory_type, "value") else str(mem.memory_type)
            metadata_parts.append(f"类型:{get_memory_type_label(type_value)}")
        if mem.importance > 0:
            metadata_parts.append(f"重要性:{mem.importance:.2f}")
        if metadata_parts:
            result = f"{result} [{', '.join(metadata_parts)}]"
    return result


def _edge_type_name(edge):
    """Return the edge's type as its ``.value`` when it has one, else ``str()`` of it."""
    edge_type = edge.edge_type
    return edge_type.value if hasattr(edge_type, "value") else str(edge_type)
def format_memories_block(
    memories: list[Memory | ShortTermMemory],
    title: str = "相关记忆",
    max_count: int = 10,
    include_metadata: bool = False,
) -> str:
    """
    Render several memories as one prompt block.

    The block is a "### 🧠 <title>" heading, a blank line, then one
    "- <memory>" bullet for each of the first ``max_count`` memories whose
    formatted text is non-empty.

    Args:
        memories: Memories to render.
        title: Heading text of the block.
        max_count: Maximum number of memories taken from the front of the list.
        include_metadata: Whether each memory is formatted with metadata.

    Returns:
        The formatted block, or an empty string when ``memories`` is empty.
    """
    if not memories:
        return ""

    rendered = (
        format_memory_for_prompt(entry, include_metadata=include_metadata)
        for entry in memories[:max_count]
    )
    # Memories that format to an empty string get no bullet.
    bullets = [f"- {text}" for text in rendered if text]
    return "\n".join([f"### 🧠 {title}", "", *bullets])

View File

@@ -22,7 +22,7 @@ class ThreeTierMemoryFormatter:
"""初始化格式化器"""
pass
def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
async def format_perceptual_memory(self, blocks: list[MemoryBlock]) -> str:
"""
格式化感知记忆为提示词
@@ -53,7 +53,7 @@ class ThreeTierMemoryFormatter:
for block in blocks:
# 提取时间和聊天流信息
time_str = self._extract_time_from_block(block)
stream_name = self._extract_stream_name_from_block(block)
stream_name = await self._extract_stream_name_from_block(block)
# 添加块标题
lines.append(f"- 【{time_str} ({stream_name})】")
@@ -122,7 +122,7 @@ class ThreeTierMemoryFormatter:
return "\n".join(lines)
def format_all_tiers(
async def format_all_tiers(
self,
perceptual_blocks: list[MemoryBlock],
short_term_memories: list[ShortTermMemory],
@@ -142,7 +142,7 @@ class ThreeTierMemoryFormatter:
sections = []
# 感知记忆
perceptual_text = self.format_perceptual_memory(perceptual_blocks)
perceptual_text = await self.format_perceptual_memory(perceptual_blocks)
if perceptual_text:
sections.append("### 感知记忆(即时对话)")
sections.append(perceptual_text)
@@ -198,7 +198,7 @@ class ThreeTierMemoryFormatter:
return "未知时间"
def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
async def _extract_stream_name_from_block(self, block: MemoryBlock) -> str:
"""
从记忆块中提取聊天流名称
@@ -208,18 +208,31 @@ class ThreeTierMemoryFormatter:
Returns:
聊天流名称
"""
# 尝试从元数据中获取
if block.metadata:
stream_name = block.metadata.get("stream_name") or block.metadata.get("chat_stream")
if stream_name:
return stream_name
stream_id = None
# 尝试从消息中提取
if block.messages:
# 首先尝试从元数据中获取 stream_id
if block.metadata:
stream_id = block.metadata.get("stream_id")
# 如果从元数据中没找到,尝试从消息中提取
if not stream_id and block.messages:
first_msg = block.messages[0]
stream_name = first_msg.get("stream_name") or first_msg.get("chat_stream")
if stream_name:
return stream_name
stream_id = first_msg.get("stream_id") or first_msg.get("chat_id")
# 如果有 stream_id尝试获取实际的流名称
if stream_id:
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
actual_name = await chat_manager.get_stream_name(stream_id)
if actual_name:
return actual_name
else:
# 如果获取不到名称,返回 stream_id 的截断版本
return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id
except Exception:
# 如果获取失败,返回 stream_id 的截断版本
return stream_id[:12] + "..." if len(stream_id) > 12 else stream_id
return "默认聊天"
@@ -375,7 +388,7 @@ class ThreeTierMemoryFormatter:
return type_mapping.get(type_value, "事实")
def format_for_context_injection(
async def format_for_context_injection(
self,
query: str,
perceptual_blocks: list[MemoryBlock],
@@ -407,7 +420,7 @@ class ThreeTierMemoryFormatter:
limited_short_term = short_term_memories[:max_short_term]
limited_long_term = long_term_memories[:max_long_term]
all_tiers_text = self.format_all_tiers(
all_tiers_text = await self.format_all_tiers(
limited_perceptual,
limited_short_term,
limited_long_term