feat(memory): 增强记忆构建上下文处理能力并优化兴趣度批量更新机制

- 在记忆构建过程中允许检索历史记忆作为上下文补充
- 改进LLM响应解析逻辑,增强JSON提取兼容性
- 优化消息兴趣度计算和批量更新机制,减少数据库写入频率
- 添加构建状态管理,支持在BUILDING状态下进行记忆检索
- 修复stream_id拼写错误处理和历史消息获取逻辑
This commit is contained in:
Windpicker-owo
2025-09-30 14:22:26 +08:00
parent 1ccf50f3c7
commit 0a3c908654
8 changed files with 317 additions and 112 deletions

View File

@@ -12,12 +12,11 @@ from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
from enum import Enum
import numpy as np
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_builder import MemoryBuilder
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
@@ -64,7 +63,6 @@ class MemorySystemConfig:
@classmethod
def from_global_config(cls):
"""从全局配置创建配置实例"""
from src.config.config import global_config
return cls(
# 记忆构建配置
@@ -130,7 +128,7 @@ class EnhancedMemorySystem:
task_config = (
self.llm_model.model_for_task
if self.llm_model is not None
else model_config.model_task_config.utils
else model_config.model_task_config.utils_small
)
self.value_assessment_model = LLMRequest(
@@ -173,6 +171,49 @@ class EnhancedMemorySystem:
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
raise
async def retrieve_memories_for_building(
self,
query_text: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
) -> List[MemoryChunk]:
"""在构建记忆时检索相关记忆允许在BUILDING状态下进行检索
Args:
query_text: 查询文本
user_id: 用户ID
context: 上下文信息
limit: 返回结果数量限制
Returns:
相关记忆列表
"""
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
logger.warning(f"记忆系统状态不允许检索: {self.status.value}")
return []
try:
# 临时切换到检索状态
original_status = self.status
self.status = MemorySystemStatus.RETRIEVING
# 执行检索
memories = await self.vector_storage.search_similar_memories(
query_text=query_text,
user_id=user_id,
limit=limit
)
# 恢复原始状态
self.status = original_status
return memories
except Exception as e:
logger.error(f"构建过程中检索记忆失败: {e}", exc_info=True)
return []
async def build_memory_from_conversation(
self,
conversation_text: str,
@@ -191,9 +232,10 @@ class EnhancedMemorySystem:
Returns:
构建的记忆块列表
"""
if self.status != MemorySystemStatus.READY:
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
raise RuntimeError("记忆系统未就绪")
original_status = self.status
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
@@ -222,7 +264,7 @@ class EnhancedMemorySystem:
build_marker_time = current_time
self._last_memory_build_times[build_scope_key] = current_time
conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
@@ -231,7 +273,7 @@ class EnhancedMemorySystem:
if value_score < self.config.memory_value_threshold:
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
self.status = MemorySystemStatus.READY
self.status = original_status
return []
# 2. 构建记忆块
@@ -244,7 +286,7 @@ class EnhancedMemorySystem:
if not memory_chunks:
logger.debug("未提取到有效记忆块")
self.status = MemorySystemStatus.READY
self.status = original_status
return []
# 3. 记忆融合与去重
@@ -262,7 +304,7 @@ class EnhancedMemorySystem:
build_time = time.time() - start_time
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}")
self.status = MemorySystemStatus.READY
self.status = original_status
return fused_chunks
except Exception as e:
@@ -469,49 +511,79 @@ class EnhancedMemorySystem:
return context
def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
"""使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
if not context:
return fallback_text
user_id = context.get("user_id")
stream_id = context.get("stream_id") or context.get("stram_id")
if not stream_id:
return fallback_text
try:
from src.chat.message_receive.chat_stream import get_chat_manager
# 优先使用 stream_id 获取历史消息
if stream_id:
try:
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream or not hasattr(chat_stream, "context_manager"):
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
return fallback_text
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if chat_stream and hasattr(chat_stream, "context_manager"):
history_limit = self._determine_history_limit(context)
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
if messages:
transcript = self._format_history_messages(messages)
if transcript:
cleaned_fallback = (fallback_text or "").strip()
if cleaned_fallback and cleaned_fallback not in transcript:
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
history_limit = self._determine_history_limit(context)
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
if not messages:
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
return fallback_text
logger.debug(
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
stream_id,
len(messages),
history_limit,
)
return transcript
else:
logger.debug(f"stream_id={stream_id} 历史消息格式化失败")
else:
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
else:
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
except Exception as exc:
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
transcript = self._format_history_messages(messages)
if not transcript:
return fallback_text
# 如果无法获取历史消息,尝试检索相关记忆作为上下文
if user_id and fallback_text:
try:
relevant_memories = await self.retrieve_memories_for_building(
query_text=fallback_text,
user_id=user_id,
context=context,
limit=3
)
cleaned_fallback = (fallback_text or "").strip()
if cleaned_fallback and cleaned_fallback not in transcript:
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
if relevant_memories:
memory_contexts = []
for memory in relevant_memories:
memory_contexts.append(f"[历史记忆] {memory.text_content}")
logger.debug(
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
stream_id,
len(messages),
history_limit,
)
return transcript
memory_transcript = "\n".join(memory_contexts)
cleaned_fallback = (fallback_text or "").strip()
if cleaned_fallback and cleaned_fallback not in memory_transcript:
memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"
except Exception as exc:
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
return fallback_text
logger.debug(
"使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
len(relevant_memories),
user_id
)
return memory_transcript
except Exception as exc:
logger.warning(f"检索历史记忆作为上下文失败: {exc}", exc_info=True)
# 回退到传入文本
return fallback_text
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
"""确定用于节流控制的记忆构建作用域"""

View File

@@ -7,7 +7,7 @@
import re
import time
import orjson
from typing import Dict, List, Optional, Tuple, Any, Set
from typing import Dict, List, Optional, Any
from datetime import datetime
from dataclasses import dataclass
from enum import Enum
@@ -16,7 +16,7 @@ from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
ContentStructure, MemoryMetadata, create_memory_chunk
create_memory_chunk
)
logger = get_logger(__name__)
@@ -334,6 +334,28 @@ class MemoryBuilder:
return prompt
def _extract_json_payload(self, response: str) -> Optional[str]:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
if not response:
return None
stripped = response.strip()
# 优先处理Markdown代码块格式 ```json ... ```
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
if code_block_match:
candidate = code_block_match.group(1).strip()
if candidate:
return candidate
# 回退到查找第一个 JSON 对象的大括号范围
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1].strip()
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _parse_llm_response(
self,
response: str,
@@ -345,7 +367,13 @@ class MemoryBuilder:
memories = []
try:
data = orjson.loads(response)
# 提取JSON负载
json_payload = self._extract_json_payload(response)
if not json_payload:
logger.error("未在响应中找到有效的JSON负载")
return memories
data = orjson.loads(json_payload)
memory_list = data.get("memories", [])
for mem_data in memory_list:
@@ -375,7 +403,8 @@ class MemoryBuilder:
continue
except Exception as e:
logger.error(f"解析LLM响应失败: {e}, 响应: {response}")
preview = response[:200] if response else "空响应"
logger.error(f"解析LLM响应失败: {e}, 响应片段: {preview}")
return memories
@@ -623,7 +652,7 @@ class MemoryBuilder:
try:
# 尝试解析回字典(如果原来是字典)
memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
except:
except Exception:
memory.content.object = obj_str
# 记录时间规范化操作