feat(memory): 增强记忆构建上下文处理能力并优化兴趣度批量更新机制
- 在记忆构建过程中允许检索历史记忆作为上下文补充 - 改进LLM响应解析逻辑,增强JSON提取兼容性 - 优化消息兴趣度计算和批量更新机制,减少数据库写入频率 - 添加构建状态管理,支持在BUILDING状态下进行记忆检索 - 修复stream_id拼写错误处理和历史消息获取逻辑
This commit is contained in:
@@ -47,22 +47,20 @@ class SingleStreamContextManager:
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
skip_energy_update: 是否跳过能量更新
|
||||
skip_energy_update: 是否跳过能量更新(兼容参数,当前忽略)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
try:
|
||||
self.context.add_message(message)
|
||||
interest_value = await self._calculate_message_interest(message)
|
||||
message.interest_value = interest_value
|
||||
# 推迟兴趣度计算到分发阶段
|
||||
message.interest_value = getattr(message, "interest_value", None)
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
if not skip_energy_update:
|
||||
await self._update_stream_energy()
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度待计算)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
@@ -80,8 +78,6 @@ class SingleStreamContextManager:
|
||||
"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
if "interest_value" in updates:
|
||||
await self._update_stream_energy()
|
||||
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -286,18 +282,16 @@ class SingleStreamContextManager:
|
||||
try:
|
||||
self.context.add_message(message)
|
||||
|
||||
interest_value = await self._calculate_message_interest_async(message)
|
||||
message.interest_value = interest_value
|
||||
# 推迟兴趣度计算到分发阶段
|
||||
message.interest_value = getattr(message, "interest_value", None)
|
||||
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
if not skip_energy_update:
|
||||
await self._update_stream_energy()
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
|
||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
|
||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
@@ -307,8 +301,6 @@ class SingleStreamContextManager:
|
||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
if "interest_value" in updates:
|
||||
await self._update_stream_energy()
|
||||
|
||||
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
|
||||
return True
|
||||
@@ -339,27 +331,31 @@ class SingleStreamContextManager:
|
||||
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _update_stream_energy(self):
|
||||
async def refresh_focus_energy_from_history(self) -> None:
|
||||
"""基于历史消息刷新聚焦能量"""
|
||||
await self._update_stream_energy(include_unread=False)
|
||||
|
||||
async def _update_stream_energy(self, include_unread: bool = False) -> None:
|
||||
"""更新流能量"""
|
||||
try:
|
||||
# 获取所有消息
|
||||
all_messages = self.get_messages(self.max_context_size)
|
||||
unread_messages = self.get_unread_messages()
|
||||
combined_messages = all_messages + unread_messages
|
||||
history_messages = self.context.get_history_messages(limit=self.max_context_size)
|
||||
messages: List[DatabaseMessages] = list(history_messages)
|
||||
|
||||
# 获取用户ID
|
||||
if include_unread:
|
||||
messages.extend(self.get_unread_messages())
|
||||
|
||||
# 获取用户ID(优先使用最新历史消息)
|
||||
user_id = None
|
||||
if combined_messages:
|
||||
last_message = combined_messages[-1]
|
||||
user_id = last_message.user_info.user_id
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
if hasattr(last_message, "user_info") and last_message.user_info:
|
||||
user_id = last_message.user_info.user_id
|
||||
|
||||
# 计算能量
|
||||
energy = await energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id, messages=combined_messages, user_id=user_id
|
||||
await energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id,
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 更新流循环管理器
|
||||
# 注意:能量更新会通过energy_manager自动同步到流循环管理器
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")
|
||||
|
||||
@@ -246,6 +246,7 @@ class StreamLoopManager:
|
||||
success = results.get("success", False)
|
||||
|
||||
if success:
|
||||
await self._refresh_focus_energy(stream_id)
|
||||
process_time = time.time() - start_time
|
||||
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
|
||||
else:
|
||||
@@ -339,6 +340,20 @@ class StreamLoopManager:
|
||||
"max_concurrent_streams": self.max_concurrent_streams,
|
||||
}
|
||||
|
||||
async def _refresh_focus_energy(self, stream_id: str) -> None:
|
||||
"""分发完成后基于历史消息刷新能量值"""
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
|
||||
return
|
||||
|
||||
await chat_stream.context_manager.refresh_focus_energy_from_history()
|
||||
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
|
||||
except Exception as e:
|
||||
logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}")
|
||||
|
||||
|
||||
# 全局流循环管理器实例
|
||||
stream_loop_manager = StreamLoopManager()
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -125,6 +125,44 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
|
||||
"""批量更新消息信息,降低更新频率"""
|
||||
if not updates:
|
||||
return 0
|
||||
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
|
||||
return 0
|
||||
|
||||
updated_count = 0
|
||||
for item in updates:
|
||||
message_id = item.get("message_id")
|
||||
if not message_id:
|
||||
continue
|
||||
|
||||
payload = {
|
||||
key: value
|
||||
for key, value in item.items()
|
||||
if key != "message_id" and value is not None
|
||||
}
|
||||
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
success = await chat_stream.context_manager.update_message(message_id, payload)
|
||||
if success:
|
||||
updated_count += 1
|
||||
|
||||
if updated_count:
|
||||
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
|
||||
return updated_count
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
|
||||
return 0
|
||||
|
||||
async def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user