feat(memory): 增强记忆构建上下文处理能力并优化兴趣度批量更新机制
- 在记忆构建过程中允许检索历史记忆作为上下文补充 - 改进LLM响应解析逻辑,增强JSON提取兼容性 - 优化消息兴趣度计算和批量更新机制,减少数据库写入频率 - 添加构建状态管理,支持在BUILDING状态下进行记忆检索 - 修复stream_id拼写错误处理和历史消息获取逻辑
This commit is contained in:
@@ -12,12 +12,11 @@ from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
||||
@@ -64,7 +63,6 @@ class MemorySystemConfig:
|
||||
@classmethod
|
||||
def from_global_config(cls):
|
||||
"""从全局配置创建配置实例"""
|
||||
from src.config.config import global_config
|
||||
|
||||
return cls(
|
||||
# 记忆构建配置
|
||||
@@ -130,7 +128,7 @@ class EnhancedMemorySystem:
|
||||
task_config = (
|
||||
self.llm_model.model_for_task
|
||||
if self.llm_model is not None
|
||||
else model_config.model_task_config.utils
|
||||
else model_config.model_task_config.utils_small
|
||||
)
|
||||
|
||||
self.value_assessment_model = LLMRequest(
|
||||
@@ -173,6 +171,49 @@ class EnhancedMemorySystem:
|
||||
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def retrieve_memories_for_building(
|
||||
self,
|
||||
query_text: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
limit: int = 5
|
||||
) -> List[MemoryChunk]:
|
||||
"""在构建记忆时检索相关记忆,允许在BUILDING状态下进行检索
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
user_id: 用户ID
|
||||
context: 上下文信息
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
相关记忆列表
|
||||
"""
|
||||
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
||||
logger.warning(f"记忆系统状态不允许检索: {self.status.value}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 临时切换到检索状态
|
||||
original_status = self.status
|
||||
self.status = MemorySystemStatus.RETRIEVING
|
||||
|
||||
# 执行检索
|
||||
memories = await self.vector_storage.search_similar_memories(
|
||||
query_text=query_text,
|
||||
user_id=user_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# 恢复原始状态
|
||||
self.status = original_status
|
||||
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建过程中检索记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self,
|
||||
conversation_text: str,
|
||||
@@ -191,9 +232,10 @@ class EnhancedMemorySystem:
|
||||
Returns:
|
||||
构建的记忆块列表
|
||||
"""
|
||||
if self.status != MemorySystemStatus.READY:
|
||||
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
||||
raise RuntimeError("记忆系统未就绪")
|
||||
|
||||
original_status = self.status
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
start_time = time.time()
|
||||
|
||||
@@ -222,7 +264,7 @@ class EnhancedMemorySystem:
|
||||
build_marker_time = current_time
|
||||
self._last_memory_build_times[build_scope_key] = current_time
|
||||
|
||||
conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
|
||||
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
|
||||
|
||||
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
|
||||
|
||||
@@ -231,7 +273,7 @@ class EnhancedMemorySystem:
|
||||
|
||||
if value_score < self.config.memory_value_threshold:
|
||||
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.status = original_status
|
||||
return []
|
||||
|
||||
# 2. 构建记忆块
|
||||
@@ -244,7 +286,7 @@ class EnhancedMemorySystem:
|
||||
|
||||
if not memory_chunks:
|
||||
logger.debug("未提取到有效记忆块")
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.status = original_status
|
||||
return []
|
||||
|
||||
# 3. 记忆融合与去重
|
||||
@@ -262,7 +304,7 @@ class EnhancedMemorySystem:
|
||||
build_time = time.time() - start_time
|
||||
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒")
|
||||
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.status = original_status
|
||||
return fused_chunks
|
||||
|
||||
except Exception as e:
|
||||
@@ -469,49 +511,79 @@ class EnhancedMemorySystem:
|
||||
|
||||
return context
|
||||
|
||||
def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
"""使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
|
||||
if not context:
|
||||
return fallback_text
|
||||
|
||||
user_id = context.get("user_id")
|
||||
stream_id = context.get("stream_id") or context.get("stram_id")
|
||||
if not stream_id:
|
||||
return fallback_text
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
# 优先使用 stream_id 获取历史消息
|
||||
if stream_id:
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream or not hasattr(chat_stream, "context_manager"):
|
||||
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
|
||||
return fallback_text
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if chat_stream and hasattr(chat_stream, "context_manager"):
|
||||
history_limit = self._determine_history_limit(context)
|
||||
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
|
||||
if messages:
|
||||
transcript = self._format_history_messages(messages)
|
||||
if transcript:
|
||||
cleaned_fallback = (fallback_text or "").strip()
|
||||
if cleaned_fallback and cleaned_fallback not in transcript:
|
||||
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
|
||||
|
||||
history_limit = self._determine_history_limit(context)
|
||||
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
|
||||
if not messages:
|
||||
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
|
||||
return fallback_text
|
||||
logger.debug(
|
||||
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
|
||||
stream_id,
|
||||
len(messages),
|
||||
history_limit,
|
||||
)
|
||||
return transcript
|
||||
else:
|
||||
logger.debug(f"stream_id={stream_id} 历史消息格式化失败")
|
||||
else:
|
||||
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
|
||||
else:
|
||||
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
|
||||
|
||||
transcript = self._format_history_messages(messages)
|
||||
if not transcript:
|
||||
return fallback_text
|
||||
# 如果无法获取历史消息,尝试检索相关记忆作为上下文
|
||||
if user_id and fallback_text:
|
||||
try:
|
||||
relevant_memories = await self.retrieve_memories_for_building(
|
||||
query_text=fallback_text,
|
||||
user_id=user_id,
|
||||
context=context,
|
||||
limit=3
|
||||
)
|
||||
|
||||
cleaned_fallback = (fallback_text or "").strip()
|
||||
if cleaned_fallback and cleaned_fallback not in transcript:
|
||||
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
|
||||
if relevant_memories:
|
||||
memory_contexts = []
|
||||
for memory in relevant_memories:
|
||||
memory_contexts.append(f"[历史记忆] {memory.text_content}")
|
||||
|
||||
logger.debug(
|
||||
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
|
||||
stream_id,
|
||||
len(messages),
|
||||
history_limit,
|
||||
)
|
||||
return transcript
|
||||
memory_transcript = "\n".join(memory_contexts)
|
||||
cleaned_fallback = (fallback_text or "").strip()
|
||||
if cleaned_fallback and cleaned_fallback not in memory_transcript:
|
||||
memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
|
||||
return fallback_text
|
||||
logger.debug(
|
||||
"使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
|
||||
len(relevant_memories),
|
||||
user_id
|
||||
)
|
||||
return memory_transcript
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"检索历史记忆作为上下文失败: {exc}", exc_info=True)
|
||||
|
||||
# 回退到传入文本
|
||||
return fallback_text
|
||||
|
||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||
"""确定用于节流控制的记忆构建作用域"""
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import re
|
||||
import time
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Tuple, Any, Set
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -16,7 +16,7 @@ from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.memory_system.memory_chunk import (
|
||||
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
|
||||
ContentStructure, MemoryMetadata, create_memory_chunk
|
||||
create_memory_chunk
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -334,6 +334,28 @@ class MemoryBuilder:
|
||||
|
||||
return prompt
|
||||
|
||||
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
if not response:
|
||||
return None
|
||||
|
||||
stripped = response.strip()
|
||||
|
||||
# 优先处理Markdown代码块格式 ```json ... ```
|
||||
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
|
||||
if code_block_match:
|
||||
candidate = code_block_match.group(1).strip()
|
||||
if candidate:
|
||||
return candidate
|
||||
|
||||
# 回退到查找第一个 JSON 对象的大括号范围
|
||||
start = stripped.find("{")
|
||||
end = stripped.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return stripped[start:end + 1].strip()
|
||||
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
def _parse_llm_response(
|
||||
self,
|
||||
response: str,
|
||||
@@ -345,7 +367,13 @@ class MemoryBuilder:
|
||||
memories = []
|
||||
|
||||
try:
|
||||
data = orjson.loads(response)
|
||||
# 提取JSON负载
|
||||
json_payload = self._extract_json_payload(response)
|
||||
if not json_payload:
|
||||
logger.error("未在响应中找到有效的JSON负载")
|
||||
return memories
|
||||
|
||||
data = orjson.loads(json_payload)
|
||||
memory_list = data.get("memories", [])
|
||||
|
||||
for mem_data in memory_list:
|
||||
@@ -375,7 +403,8 @@ class MemoryBuilder:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}, 响应: {response}")
|
||||
preview = response[:200] if response else "空响应"
|
||||
logger.error(f"解析LLM响应失败: {e}, 响应片段: {preview}")
|
||||
|
||||
return memories
|
||||
|
||||
@@ -623,7 +652,7 @@ class MemoryBuilder:
|
||||
try:
|
||||
# 尝试解析回字典(如果原来是字典)
|
||||
memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
|
||||
except:
|
||||
except Exception:
|
||||
memory.content.object = obj_str
|
||||
|
||||
# 记录时间规范化操作
|
||||
|
||||
Reference in New Issue
Block a user