feat(memory): enhance context handling during memory building and optimize batched interest-value updates

- Allow retrieving historical memories as supplementary context while building new memories
- Improve the LLM response parsing logic with more tolerant JSON extraction
- Optimize message interest scoring and batch the updates to reduce database write frequency
- Add build-status management so memories can be retrieved while the system is in the BUILDING state
- Fix the stream_id typo handling and the history-message fetching logic
Author: Windpicker-owo
Date: 2025-09-30 14:22:26 +08:00
Parent: a997a538b0
Commit: 41c45889c5
8 changed files with 317 additions and 112 deletions

View File

@@ -12,12 +12,11 @@ from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
from datetime import datetime, timedelta
from dataclasses import dataclass, asdict
from enum import Enum
import numpy as np
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
- from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
+ from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_builder import MemoryBuilder
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
@@ -64,7 +63,6 @@ class MemorySystemConfig:
@classmethod
def from_global_config(cls):
"""从全局配置创建配置实例"""
- from src.config.config import global_config
return cls(
# 记忆构建配置
@@ -130,7 +128,7 @@ class EnhancedMemorySystem:
task_config = (
self.llm_model.model_for_task
if self.llm_model is not None
- else model_config.model_task_config.utils
+ else model_config.model_task_config.utils_small
)
self.value_assessment_model = LLMRequest(
@@ -173,6 +171,49 @@ class EnhancedMemorySystem:
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
raise
async def retrieve_memories_for_building(
self,
query_text: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
) -> List[MemoryChunk]:
"""在构建记忆时检索相关记忆允许在BUILDING状态下进行检索
Args:
query_text: 查询文本
user_id: 用户ID
context: 上下文信息
limit: 返回结果数量限制
Returns:
相关记忆列表
"""
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
logger.warning(f"记忆系统状态不允许检索: {self.status.value}")
return []
try:
# 临时切换到检索状态
original_status = self.status
self.status = MemorySystemStatus.RETRIEVING
# 执行检索
memories = await self.vector_storage.search_similar_memories(
query_text=query_text,
user_id=user_id,
limit=limit
)
# 恢复原始状态
self.status = original_status
return memories
except Exception as e:
logger.error(f"构建过程中检索记忆失败: {e}", exc_info=True)
return []
async def build_memory_from_conversation(
self,
conversation_text: str,
@@ -191,9 +232,10 @@ class EnhancedMemorySystem:
Returns:
构建的记忆块列表
"""
- if self.status != MemorySystemStatus.READY:
+ if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
raise RuntimeError("记忆系统未就绪")
+ original_status = self.status
self.status = MemorySystemStatus.BUILDING
start_time = time.time()
@@ -222,7 +264,7 @@ class EnhancedMemorySystem:
build_marker_time = current_time
self._last_memory_build_times[build_scope_key] = current_time
- conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
+ conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
@@ -231,7 +273,7 @@ class EnhancedMemorySystem:
if value_score < self.config.memory_value_threshold:
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
- self.status = MemorySystemStatus.READY
+ self.status = original_status
return []
# 2. 构建记忆块
@@ -244,7 +286,7 @@ class EnhancedMemorySystem:
if not memory_chunks:
logger.debug("未提取到有效记忆块")
- self.status = MemorySystemStatus.READY
+ self.status = original_status
return []
# 3. 记忆融合与去重
@@ -262,7 +304,7 @@ class EnhancedMemorySystem:
build_time = time.time() - start_time
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}")
- self.status = MemorySystemStatus.READY
+ self.status = original_status
return fused_chunks
except Exception as e:
@@ -469,49 +511,79 @@ class EnhancedMemorySystem:
return context
- def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
- """使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
+ async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
+ """使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
if not context:
return fallback_text
+ user_id = context.get("user_id")
stream_id = context.get("stream_id") or context.get("stram_id")
- if not stream_id:
- return fallback_text
- try:
- from src.chat.message_receive.chat_stream import get_chat_manager
+ # 优先使用 stream_id 获取历史消息
+ if stream_id:
+ try:
+ from src.chat.message_receive.chat_stream import get_chat_manager
+ chat_manager = get_chat_manager()
+ chat_stream = chat_manager.get_stream(stream_id)
- if not chat_stream or not hasattr(chat_stream, "context_manager"):
- logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
- return fallback_text
- chat_manager = get_chat_manager()
- chat_stream = chat_manager.get_stream(stream_id)
+ if chat_stream and hasattr(chat_stream, "context_manager"):
+ history_limit = self._determine_history_limit(context)
+ messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
+ if messages:
+ transcript = self._format_history_messages(messages)
+ if transcript:
+ cleaned_fallback = (fallback_text or "").strip()
+ if cleaned_fallback and cleaned_fallback not in transcript:
+ transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
- history_limit = self._determine_history_limit(context)
- messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
- if not messages:
- logger.debug(f"stream_id={stream_id} 未获取到历史消息")
- return fallback_text
+ logger.debug(
+ "使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
+ stream_id,
+ len(messages),
+ history_limit,
+ )
+ return transcript
+ else:
+ logger.debug(f"stream_id={stream_id} 历史消息格式化失败")
+ else:
+ logger.debug(f"stream_id={stream_id} 未获取到历史消息")
+ else:
+ logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
+ except Exception as exc:
+ logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
- transcript = self._format_history_messages(messages)
- if not transcript:
- return fallback_text
+ # 如果无法获取历史消息,尝试检索相关记忆作为上下文
+ if user_id and fallback_text:
+ try:
+ relevant_memories = await self.retrieve_memories_for_building(
+ query_text=fallback_text,
+ user_id=user_id,
+ context=context,
+ limit=3
+ )
- cleaned_fallback = (fallback_text or "").strip()
- if cleaned_fallback and cleaned_fallback not in transcript:
- transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
+ if relevant_memories:
+ memory_contexts = []
+ for memory in relevant_memories:
+ memory_contexts.append(f"[历史记忆] {memory.text_content}")
- logger.debug(
- "使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
- stream_id,
- len(messages),
- history_limit,
- )
- return transcript
+ memory_transcript = "\n".join(memory_contexts)
+ cleaned_fallback = (fallback_text or "").strip()
+ if cleaned_fallback and cleaned_fallback not in memory_transcript:
+ memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"
- except Exception as exc:
- logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
- return fallback_text
+ logger.debug(
+ "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
+ len(relevant_memories),
+ user_id
+ )
+ return memory_transcript
+ except Exception as exc:
+ logger.warning(f"检索历史记忆作为上下文失败: {exc}", exc_info=True)
+ # 回退到传入文本
return fallback_text
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
"""确定用于节流控制的记忆构建作用域"""

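The rewritten _resolve_conversation_context above establishes a three-level fallback: stream history first, then retrieved memories, then the raw message text. A minimal standalone sketch of that priority order (illustration only, not part of this commit's diff; fetch_history, search_memories, and the tag strings are hypothetical stand-ins for the project's own helpers):

from typing import Callable, List

def resolve_context(
    fallback_text: str,
    fetch_history: Callable[[], List[str]],       # stand-in for the stream history lookup
    search_memories: Callable[[str], List[str]],  # stand-in for memory retrieval
) -> str:
    """Prefer chat history, then retrieved memories, then the raw message."""
    current = (fallback_text or "").strip()

    # 1. Stream history, with the current message appended if it is not already present.
    history = fetch_history()
    if history:
        transcript = "\n".join(history)
        if current and current not in transcript:
            transcript += f"\n[当前消息] {current}"
        return transcript

    # 2. Retrieved memories as a substitute context.
    memories = search_memories(current) if current else []
    if memories:
        transcript = "\n".join(f"[历史记忆] {m}" for m in memories)
        if current and current not in transcript:
            transcript += f"\n[当前消息] {current}"
        return transcript

    # 3. Nothing else available: fall back to the raw text.
    return fallback_text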
View File

@@ -7,7 +7,7 @@
import re
import time
import orjson
- from typing import Dict, List, Optional, Tuple, Any, Set
+ from typing import Dict, List, Optional, Any
from datetime import datetime
from dataclasses import dataclass
from enum import Enum
@@ -16,7 +16,7 @@ from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.chat.memory_system.memory_chunk import (
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
- ContentStructure, MemoryMetadata, create_memory_chunk
+ create_memory_chunk
)
logger = get_logger(__name__)
@@ -334,6 +334,28 @@ class MemoryBuilder:
return prompt
def _extract_json_payload(self, response: str) -> Optional[str]:
"""从模型响应中提取JSON部分兼容Markdown代码块等格式"""
if not response:
return None
stripped = response.strip()
# 优先处理Markdown代码块格式 ```json ... ```
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
if code_block_match:
candidate = code_block_match.group(1).strip()
if candidate:
return candidate
# 回退到查找第一个 JSON 对象的大括号范围
start = stripped.find("{")
end = stripped.rfind("}")
if start != -1 and end != -1 and end > start:
return stripped[start:end + 1].strip()
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
def _parse_llm_response(
self,
response: str,
@@ -345,7 +367,13 @@ class MemoryBuilder:
memories = []
try:
- data = orjson.loads(response)
+ # 提取JSON负载
+ json_payload = self._extract_json_payload(response)
+ if not json_payload:
+ logger.error("未在响应中找到有效的JSON负载")
+ return memories
+ data = orjson.loads(json_payload)
memory_list = data.get("memories", [])
for mem_data in memory_list:
@@ -375,7 +403,8 @@ class MemoryBuilder:
continue
except Exception as e:
- logger.error(f"解析LLM响应失败: {e}, 响应: {response}")
+ preview = response[:200] if response else "空响应"
+ logger.error(f"解析LLM响应失败: {e}, 响应片段: {preview}")
return memories
@@ -623,7 +652,7 @@ class MemoryBuilder:
try:
# 尝试解析回字典(如果原来是字典)
memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
- except:
+ except Exception:
memory.content.object = obj_str
# 记录时间规范化操作

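The _extract_json_payload helper added above can be exercised in isolation; a minimal standalone sketch of the same strategy (fenced-block regex first, then a brace scan), shown here with sample responses (illustration only, not part of this commit's diff):

import re
from typing import Optional

def extract_json_payload(response: str) -> Optional[str]:
    """Pull the JSON portion out of an LLM reply, tolerating Markdown code fences."""
    if not response:
        return None
    stripped = response.strip()
    # Prefer a fenced block such as ```json ... ```
    match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
    if match:
        candidate = match.group(1).strip()
        if candidate:
            return candidate
    # Fall back to the span between the first '{' and the last '}'
    start, end = stripped.find("{"), stripped.rfind("}")
    if start != -1 and end != -1 and end > start:
        return stripped[start:end + 1].strip()
    return stripped if stripped.startswith("{") and stripped.endswith("}") else None

print(extract_json_payload('```json\n{"memories": []}\n```'))   # {"memories": []}
print(extract_json_payload('前置说明 {"memories": []} 后置说明'))  # {"memories": []}
print(extract_json_payload("no json here"))                       # None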
View File

@@ -47,22 +47,20 @@ class SingleStreamContextManager:
Args:
message: 消息对象
- skip_energy_update: 是否跳过能量更新
+ skip_energy_update: 是否跳过能量更新(兼容参数,当前忽略)
Returns:
bool: 是否成功添加
"""
try:
self.context.add_message(message)
- interest_value = await self._calculate_message_interest(message)
- message.interest_value = interest_value
+ # 推迟兴趣度计算到分发阶段
+ message.interest_value = getattr(message, "interest_value", None)
self.total_messages += 1
self.last_access_time = time.time()
- if not skip_energy_update:
- await self._update_stream_energy()
- # 启动流的循环任务(如果还未启动)
- await stream_loop_manager.start_stream_loop(self.stream_id)
- logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
+ # 启动流的循环任务(如果还未启动)
+ await stream_loop_manager.start_stream_loop(self.stream_id)
+ logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度待计算)")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
@@ -80,8 +78,6 @@ class SingleStreamContextManager:
"""
try:
self.context.update_message_info(message_id, **updates)
- if "interest_value" in updates:
- await self._update_stream_energy()
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True
except Exception as e:
@@ -286,18 +282,16 @@ class SingleStreamContextManager:
try:
self.context.add_message(message)
- interest_value = await self._calculate_message_interest_async(message)
- message.interest_value = interest_value
+ # 推迟兴趣度计算到分发阶段
+ message.interest_value = getattr(message, "interest_value", None)
self.total_messages += 1
self.last_access_time = time.time()
- if not skip_energy_update:
- await self._update_stream_energy()
- # 启动流的循环任务(如果还未启动)
- await stream_loop_manager.start_stream_loop(self.stream_id)
+ # 启动流的循环任务(如果还未启动)
+ await stream_loop_manager.start_stream_loop(self.stream_id)
- logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
+ logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
@@ -307,8 +301,6 @@ class SingleStreamContextManager:
"""异步实现的 update_message更新消息并在需要时 await 能量更新。"""
try:
self.context.update_message_info(message_id, **updates)
- if "interest_value" in updates:
- await self._update_stream_energy()
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
return True
@@ -339,27 +331,31 @@ class SingleStreamContextManager:
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False
- async def _update_stream_energy(self):
+ async def refresh_focus_energy_from_history(self) -> None:
+ """基于历史消息刷新聚焦能量"""
+ await self._update_stream_energy(include_unread=False)
+ async def _update_stream_energy(self, include_unread: bool = False) -> None:
"""更新流能量"""
try:
- # 获取所有消息
- all_messages = self.get_messages(self.max_context_size)
- unread_messages = self.get_unread_messages()
- combined_messages = all_messages + unread_messages
+ history_messages = self.context.get_history_messages(limit=self.max_context_size)
+ messages: List[DatabaseMessages] = list(history_messages)
- # 获取用户ID
+ if include_unread:
+ messages.extend(self.get_unread_messages())
+ # 获取用户ID优先使用最新历史消息
user_id = None
- if combined_messages:
- last_message = combined_messages[-1]
- user_id = last_message.user_info.user_id
+ if messages:
+ last_message = messages[-1]
+ if hasattr(last_message, "user_info") and last_message.user_info:
+ user_id = last_message.user_info.user_id
# 计算能量
- energy = await energy_manager.calculate_focus_energy(
- stream_id=self.stream_id, messages=combined_messages, user_id=user_id
+ await energy_manager.calculate_focus_energy(
+ stream_id=self.stream_id,
+ messages=messages,
+ user_id=user_id,
)
- # 更新流循环管理器
+ # 注意能量更新会通过energy_manager自动同步到流循环管理器
except Exception as e:
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")

View File

@@ -246,6 +246,7 @@ class StreamLoopManager:
success = results.get("success", False)
if success:
await self._refresh_focus_energy(stream_id)
process_time = time.time() - start_time
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
else:
@@ -339,6 +340,20 @@ class StreamLoopManager:
"max_concurrent_streams": self.max_concurrent_streams,
}
async def _refresh_focus_energy(self, stream_id: str) -> None:
"""分发完成后基于历史消息刷新能量值"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return
await chat_stream.context_manager.refresh_focus_energy_from_history()
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
except Exception as e:
logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}")
# 全局流循环管理器实例
stream_loop_manager = StreamLoopManager()

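Taken together, the two files above move focus-energy recomputation from once per incoming message to once per dispatch cycle. A minimal sketch of that deferral pattern (stand-in class and names, not the project's API):

import asyncio
from typing import List

class EnergyTracker:
    """Recompute an aggregate score once per batch instead of on every append."""

    def __init__(self) -> None:
        self.pending: List[float] = []
        self.energy: float = 0.0

    def add_message(self, interest: float) -> None:
        # Cheap append; no recomputation here (mirrors add_message deferring the work).
        self.pending.append(interest)

    async def refresh_from_history(self) -> None:
        # One pass after dispatch (mirrors refresh_focus_energy_from_history).
        if self.pending:
            self.energy = sum(self.pending) / len(self.pending)

async def main() -> None:
    tracker = EnergyTracker()
    for score in (0.2, 0.8, 0.5):
        tracker.add_message(score)
    await tracker.refresh_from_history()  # single recomputation for the whole batch
    print(f"focus energy ~ {tracker.energy:.2f}")

asyncio.run(main())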
View File

@@ -6,7 +6,7 @@
import asyncio
import random
import time
- from typing import Dict, Optional, Any, TYPE_CHECKING
+ from typing import Dict, Optional, Any, TYPE_CHECKING, List
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
@@ -125,6 +125,44 @@ class MessageManager:
except Exception as e:
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
"""批量更新消息信息,降低更新频率"""
if not updates:
return 0
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
return 0
updated_count = 0
for item in updates:
message_id = item.get("message_id")
if not message_id:
continue
payload = {
key: value
for key, value in item.items()
if key != "message_id" and value is not None
}
if not payload:
continue
success = await chat_stream.context_manager.update_message(message_id, payload)
if success:
updated_count += 1
if updated_count:
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
return updated_count
except Exception as e:
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
return 0
async def add_action(self, stream_id: str, message_id: str, action: str):
"""添加动作到消息"""
try:

View File

@@ -219,7 +219,11 @@ class MessageStorage:
return match.group(0)
@staticmethod
- async def update_message_interest_value(message_id: str, interest_value: float) -> None:
+ async def update_message_interest_value(
+ message_id: str,
+ interest_value: float,
+ should_reply: bool | None = None,
+ ) -> None:
"""
更新数据库中消息的interest_value字段
@@ -230,7 +234,11 @@ class MessageStorage:
try:
async with get_db_session() as session:
# 更新消息的interest_value字段
- stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
+ values = {"interest_value": interest_value}
+ if should_reply is not None:
+ values["should_reply"] = should_reply
+ stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
result = await session.execute(stmt)
await session.commit()
@@ -243,6 +251,31 @@ class MessageStorage:
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
raise
@staticmethod
async def bulk_update_interest_values(
interest_map: dict[str, float],
reply_map: dict[str, bool] | None = None,
) -> None:
"""批量更新消息的兴趣度与回复标记"""
if not interest_map:
return
try:
async with get_db_session() as session:
for message_id, interest_value in interest_map.items():
values = {"interest_value": interest_value}
if reply_map and message_id in reply_map:
values["should_reply"] = reply_map[message_id]
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
await session.execute(stmt)
await session.commit()
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
except Exception as e:
logger.error(f"批量更新消息兴趣度失败: {e}")
raise
@staticmethod
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
"""

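A usage sketch for the new bulk_update_interest_values API (illustration only; the message IDs and scores are hypothetical, the import path is the one used elsewhere in this commit, and the maps mirror what the planner's _commit_interest_updates builds in the last file below):

from src.chat.message_receive.storage import MessageStorage

async def flush_interest_scores() -> None:
    # One session and one commit cover all rows, instead of one round-trip per message.
    interest_map = {"msg-001": 0.82, "msg-002": 0.15}
    reply_map = {"msg-001": True, "msg-002": False}
    await MessageStorage.bulk_update_interest_values(
        interest_map=interest_map,
        reply_map=reply_map,
    )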
View File

@@ -728,8 +728,6 @@ async def initialize_database():
"autocommit": config.mysql_autocommit,
"charset": config.mysql_charset,
"connect_timeout": config.connection_timeout,
- "read_timeout": 30,
- "write_timeout": 30,
},
}
)

View File

@@ -5,7 +5,7 @@
from dataclasses import asdict
import time
- from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
@@ -104,45 +104,33 @@ class ChatterActionPlanner:
score = 0.0
should_reply = False
reply_not_available = False
+ interest_updates: List[Dict[str, Any]] = []
if unread_messages:
- # 为每条消息计算兴趣度
+ # 为每条消息计算兴趣度,并延迟提交数据库更新
for message in unread_messages:
try:
# 使用插件内部的兴趣度评分系统计算
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
message=message,
- bot_nickname=global_config.bot.nickname
+ bot_nickname=global_config.bot.nickname,
)
message_interest = interest_score.total_score
# 更新消息的兴趣度
message.interest_value = message_interest
# 简单的回复决策逻辑:兴趣度超过阈值则回复
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
- logger.info(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
+ interest_updates.append(
+ {
+ "message_id": message.message_id,
+ "interest_value": message_interest,
+ "should_reply": message.should_reply,
+ }
+ )
- # 更新StreamContext中的消息信息并刷新focus_energy
- if context:
- from src.chat.message_manager.message_manager import message_manager
- await message_manager.update_message(
- stream_id=self.chat_id,
- message_id=message.message_id,
- interest_value=message_interest,
- should_reply=message.should_reply
- )
+ logger.debug(
+ f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
+ )
- # 更新数据库中的消息记录
- try:
- from src.chat.message_receive.storage import MessageStorage
- await MessageStorage.update_message_interest_value(message.message_id, message_interest)
- logger.info(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
- except Exception as e:
- logger.warning(f"更新数据库消息兴趣度失败: {e}")
# 记录最高分
if message_interest > score:
score = message_interest
if message.should_reply:
@@ -152,9 +140,18 @@ class ChatterActionPlanner:
except Exception as e:
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
# 设置默认值
message.interest_value = 0.0
message.should_reply = False
interest_updates.append(
{
"message_id": message.message_id,
"interest_value": 0.0,
"should_reply": False,
}
)
if interest_updates:
await self._commit_interest_updates(interest_updates)
# 检查兴趣度是否达到非回复动作阈值
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
@@ -194,6 +191,33 @@ class ChatterActionPlanner:
self.planner_stats["failed_plans"] += 1
return [], None
async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None:
"""统一更新消息兴趣度,减少数据库写入次数"""
if not updates:
return
try:
from src.chat.message_manager.message_manager import message_manager
await message_manager.bulk_update_messages(self.chat_id, updates)
except Exception as e:
logger.warning(f"批量更新上下文消息兴趣度失败: {e}")
try:
from src.chat.message_receive.storage import MessageStorage
interest_map = {item["message_id"]: item["interest_value"] for item in updates if "interest_value" in item}
reply_map = {item["message_id"]: item["should_reply"] for item in updates if "should_reply" in item}
await MessageStorage.bulk_update_interest_values(
interest_map=interest_map,
reply_map=reply_map if reply_map else None,
)
logger.debug(f"已批量更新 {len(interest_map)} 条消息的兴趣度")
except Exception as e:
logger.warning(f"批量更新数据库兴趣度失败: {e}")
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
"""根据执行结果更新规划器统计"""
if not execution_result:
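The interest_map / reply_map split in _commit_interest_updates above is a simple reshaping of the per-message update list; a standalone sketch with sample data (illustration only, not part of this commit's diff):

from typing import Any, Dict, List, Optional, Tuple

def split_interest_updates(
    updates: List[Dict[str, Any]],
) -> Tuple[Dict[str, float], Optional[Dict[str, bool]]]:
    """Reshape per-message update dicts into the maps expected by the bulk storage call."""
    interest_map = {u["message_id"]: u["interest_value"] for u in updates if "interest_value" in u}
    reply_map = {u["message_id"]: u["should_reply"] for u in updates if "should_reply" in u}
    return interest_map, (reply_map or None)

updates = [
    {"message_id": "msg-001", "interest_value": 0.82, "should_reply": True},
    {"message_id": "msg-002", "interest_value": 0.15, "should_reply": False},
]
print(split_interest_updates(updates))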