feat: add a history message length limit; improve context management and message handling logic

This commit is contained in:
Windpicker-owo
2025-11-19 13:59:40 +08:00
parent edc1cd5555
commit 75b806cd41
7 changed files with 245 additions and 71 deletions

View File

@@ -232,7 +232,8 @@ class SingleStreamContextManager:
failed_ids = []
for message_id in message_ids:
try:
self.context.mark_message_as_read(message_id)
# Pass the maximum history size limit
self.context.mark_message_as_read(message_id, max_history_size=self.max_context_size)
marked_count += 1
except Exception as e:
failed_ids.append(str(message_id)[:8])
@@ -374,11 +375,11 @@ class SingleStreamContextManager:
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat
# Load history messages, capped at 2x max_context_size to enrich context
# Load history messages, capped at max_context_size
db_messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=self.max_context_size * 2,
limit=self.max_context_size,
)
if db_messages:
@@ -401,6 +402,12 @@ class SingleStreamContextManager:
logger.warning(f"转换历史消息失败 (message_id={msg_dict.get('message_id', 'unknown')}): {e}")
continue
# Enforce the history length limit
if len(self.context.history_messages) > self.max_context_size:
removed_count = len(self.context.history_messages) - self.max_context_size
self.context.history_messages = self.context.history_messages[-self.max_context_size:]
logger.debug(f"📝 [历史加载] 移除了 {removed_count} 条过旧的历史消息以保持上下文大小限制")
logger.info(f"✅ [历史加载] 成功加载 {loaded_count} 条历史消息到内存: {self.stream_id}")
else:
logger.debug(f"没有历史消息需要加载: {self.stream_id}")

View File

@@ -104,7 +104,19 @@ class HeartFCSender:
# Convert MessageSending to DatabaseMessages
db_message = await self._convert_to_database_message(message)
if db_message and message.chat_stream.context_manager:
message.chat_stream.context_manager.context.history_messages.append(db_message)
context = message.chat_stream.context_manager.context
# 应用历史消息长度限制
from src.config.config import global_config
max_context_size = getattr(global_config.chat, "max_context_size", 40)
if len(context.history_messages) >= max_context_size:
# 移除最旧的历史消息以保持长度限制
removed_count = 1
context.history_messages = context.history_messages[removed_count:]
logger.debug(f"[{chat_id}] Send API添加前移除了 {removed_count} 条历史消息以保持上下文大小限制")
context.history_messages.append(db_message)
logger.debug(f"[{chat_id}] Send API消息已添加到流上下文: {message_id}")
except Exception as context_error:
logger.warning(f"[{chat_id}] 将Send API消息添加到流上下文失败: {context_error}")

View File

@@ -607,12 +607,12 @@ class DefaultReplyer:
# Add perceptual memory (recent message blocks)
if perceptual_blocks:
memory_parts.append("#### 🌊 Perceptual Memory")
for block in perceptual_blocks[:2]: # show at most 2 blocks
for block in perceptual_blocks:
messages = block.messages if hasattr(block, 'messages') else []
if messages:
block_content = "\n".join([
f"{msg.get('sender_name', msg.get('sender_id', ''))}: {msg.get('content', '')[:30]}"
for msg in messages[:3]
for msg in messages
])
memory_parts.append(f"- {block_content}")
memory_parts.append("")
@@ -620,7 +620,7 @@ class DefaultReplyer:
# Add short-term memory (structured active memories)
if short_term_memories:
memory_parts.append("#### 💭 Short-term Memory")
for mem in short_term_memories[:3]: # show at most 3 entries
for mem in short_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")
@@ -629,7 +629,7 @@ class DefaultReplyer:
# Add long-term memory (graph memories)
if long_term_memories:
memory_parts.append("#### 🗄️ Long-term Memory")
for mem in long_term_memories[:3]: # show at most 3 entries
for mem in long_term_memories:
content = format_memory_for_prompt(mem, include_metadata=False)
if content:
memory_parts.append(f"- {content}")

View File

@@ -97,7 +97,7 @@ class StreamContext(BaseDataModel):
message.add_action(action)
break
def mark_message_as_read(self, message_id: str):
def mark_message_as_read(self, message_id: str, max_history_size: int | None = None):
"""标记消息为已读"""
# 先找到要标记的消息(处理 int/str 类型不匹配问题)
message_to_mark = None
@@ -110,6 +110,19 @@ class StreamContext(BaseDataModel):
# Then move it to the history messages
if message_to_mark:
message_to_mark.is_read = True
# Enforce the history length limit
if max_history_size is None:
# Fall back to the global config for the maximum history size
from src.config.config import global_config
max_history_size = getattr(global_config.chat, "max_context_size", 40)
# If history is already at capacity, evict the oldest messages
if len(self.history_messages) >= max_history_size:
# Drop the oldest history messages (first in, first out)
removed_count = len(self.history_messages) - max_history_size + 1
self.history_messages = self.history_messages[removed_count:]
self.history_messages.append(message_to_mark)
self.unread_messages.remove(message_to_mark)
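Putting the pieces of mark_message_as_read together: IDs are compared as strings to tolerate int/str mismatches, and history is trimmed just enough to admit one more entry. A self-contained sketch with illustrative types (not the real StreamContext):

# Hedged sketch of the bounded unread -> history move.
def mark_as_read(unread: list, history: list, message_id, max_size: int = 40):
    # Compare as strings, mirroring the int/str mismatch handling above.
    target = next((m for m in unread if str(m.message_id) == str(message_id)), None)
    if target is None:
        return
    target.is_read = True
    if len(history) >= max_size:
        # Evict just enough of the oldest entries to make room for one append.
        del history[: len(history) - max_size + 1]
    history.append(target)
    unread.remove(target)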

View File

@@ -492,43 +492,44 @@ class LongTermMemoryManager:
try:
success_count = 0
temp_id_map: dict[str, str] = {}
for op in operations:
try:
if op.operation_type == GraphOperationType.CREATE_MEMORY:
await self._execute_create_memory(op, source_stm)
await self._execute_create_memory(op, source_stm, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.UPDATE_MEMORY:
await self._execute_update_memory(op)
await self._execute_update_memory(op, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.MERGE_MEMORIES:
await self._execute_merge_memories(op, source_stm)
await self._execute_merge_memories(op, source_stm, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.CREATE_NODE:
await self._execute_create_node(op)
await self._execute_create_node(op, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.UPDATE_NODE:
await self._execute_update_node(op)
await self._execute_update_node(op, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.MERGE_NODES:
await self._execute_merge_nodes(op)
await self._execute_merge_nodes(op, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.CREATE_EDGE:
await self._execute_create_edge(op)
await self._execute_create_edge(op, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.UPDATE_EDGE:
await self._execute_update_edge(op)
await self._execute_update_edge(op, temp_id_map)
success_count += 1
elif op.operation_type == GraphOperationType.DELETE_EDGE:
await self._execute_delete_edge(op)
await self._execute_delete_edge(op, temp_id_map)
success_count += 1
else:
@@ -544,11 +545,64 @@ class LongTermMemoryManager:
logger.error(f"执行图操作失败: {e}", exc_info=True)
return False
@staticmethod
def _is_placeholder_id(candidate: str | None) -> bool:
if not candidate or not isinstance(candidate, str):
return False
lowered = candidate.strip().lower()
return lowered.startswith(("new_", "temp_"))
def _register_temp_id(
self, placeholder: str | None, actual_id: str, temp_id_map: dict[str, str]
) -> None:
if actual_id and placeholder and self._is_placeholder_id(placeholder):
temp_id_map[placeholder] = actual_id
def _resolve_id(self, raw_id: str | None, temp_id_map: dict[str, str]) -> str | None:
if raw_id is None:
return None
return temp_id_map.get(raw_id, raw_id)
def _resolve_value(self, value: Any, temp_id_map: dict[str, str]) -> Any:
if isinstance(value, str):
return self._resolve_id(value, temp_id_map)
if isinstance(value, list):
return [self._resolve_value(v, temp_id_map) for v in value]
if isinstance(value, dict):
return {k: self._resolve_value(v, temp_id_map) for k, v in value.items()}
return value
def _resolve_parameters(
self, params: dict[str, Any], temp_id_map: dict[str, str]
) -> dict[str, Any]:
return {k: self._resolve_value(v, temp_id_map) for k, v in params.items()}
def _register_aliases_from_params(
self, params: dict[str, Any], actual_id: str, temp_id_map: dict[str, str]
) -> None:
alias_keywords = ("alias", "placeholder", "temp_id", "register_as")
for key, value in params.items():
if isinstance(value, str):
lower_key = key.lower()
if any(keyword in lower_key for keyword in alias_keywords):
self._register_temp_id(value, actual_id, temp_id_map)
elif isinstance(value, list):
lower_key = key.lower()
if any(keyword in lower_key for keyword in alias_keywords):
for item in value:
if isinstance(item, str):
self._register_temp_id(item, actual_id, temp_id_map)
elif isinstance(value, dict):
self._register_aliases_from_params(value, actual_id, temp_id_map)
async def _execute_create_memory(
self, op: GraphOperation, source_stm: ShortTermMemory
self,
op: GraphOperation,
source_stm: ShortTermMemory,
temp_id_map: dict[str, str],
) -> None:
"""执行创建记忆操作"""
params = op.parameters
params = self._resolve_parameters(op.parameters, temp_id_map)
memory = await self.memory_manager.create_memory(
subject=params.get("subject", source_stm.subject or "unknown"),
@@ -565,17 +619,26 @@ class LongTermMemoryManager:
memory.metadata["transfer_time"] = datetime.now().isoformat()
logger.info(f"✅ 创建长期记忆: {memory.id} (来自短期记忆 {source_stm.id})")
self._register_temp_id(op.target_id, memory.id, temp_id_map)
self._register_aliases_from_params(op.parameters, memory.id, temp_id_map)
else:
logger.error(f"创建长期记忆失败: {op}")
async def _execute_update_memory(self, op: GraphOperation) -> None:
async def _execute_update_memory(
self, op: GraphOperation, temp_id_map: dict[str, str]
) -> None:
"""执行更新记忆操作"""
memory_id = op.target_id
memory_id = self._resolve_id(op.target_id, temp_id_map)
if not memory_id:
logger.error("更新操作缺少目标记忆ID")
return
updates = op.parameters.get("updated_fields", {})
updates_raw = op.parameters.get("updated_fields", {})
updates = (
self._resolve_parameters(updates_raw, temp_id_map)
if isinstance(updates_raw, dict)
else updates_raw
)
success = await self.memory_manager.update_memory(memory_id, **updates)
@@ -585,12 +648,16 @@ class LongTermMemoryManager:
logger.error(f"更新长期记忆失败: {memory_id}")
async def _execute_merge_memories(
self, op: GraphOperation, source_stm: ShortTermMemory
self,
op: GraphOperation,
source_stm: ShortTermMemory,
temp_id_map: dict[str, str],
) -> None:
"""执行合并记忆操作 (智能合并版)"""
source_ids = op.parameters.get("source_memory_ids", [])
merged_content = op.parameters.get("merged_content", "")
merged_importance = op.parameters.get("merged_importance", source_stm.importance)
params = self._resolve_parameters(op.parameters, temp_id_map)
source_ids = params.get("source_memory_ids", [])
merged_content = params.get("merged_content", "")
merged_importance = params.get("merged_importance", source_stm.importance)
if not source_ids:
logger.warning("合并操作缺少源记忆ID跳过")
@@ -626,9 +693,11 @@ class LongTermMemoryManager:
else:
logger.error(f"合并记忆失败: {source_ids}")
async def _execute_create_node(self, op: GraphOperation) -> None:
async def _execute_create_node(
self, op: GraphOperation, temp_id_map: dict[str, str]
) -> None:
"""执行创建节点操作"""
params = op.parameters
params = self._resolve_parameters(op.parameters, temp_id_map)
content = params.get("content")
node_type = params.get("node_type", "OBJECT")
memory_id = params.get("memory_id")
@@ -652,13 +721,17 @@ class LongTermMemoryManager:
# Try to generate an embedding for the new node (async)
asyncio.create_task(self._generate_node_embedding(node_id, content))
logger.info(f"✅ Created node: {content} ({node_type}) -> {memory_id}")
self._register_temp_id(op.target_id, node_id, temp_id_map)
self._register_aliases_from_params(op.parameters, node_id, temp_id_map)
else:
logger.error(f"创建节点失败: {op}")
async def _execute_update_node(self, op: GraphOperation) -> None:
async def _execute_update_node(
self, op: GraphOperation, temp_id_map: dict[str, str]
) -> None:
"""执行更新节点操作"""
node_id = op.target_id
params = op.parameters
node_id = self._resolve_id(op.target_id, temp_id_map)
params = self._resolve_parameters(op.parameters, temp_id_map)
updated_content = params.get("updated_content")
if not node_id:
@@ -675,9 +748,11 @@ class LongTermMemoryManager:
else:
logger.error(f"更新节点失败: {node_id}")
async def _execute_merge_nodes(self, op: GraphOperation) -> None:
async def _execute_merge_nodes(
self, op: GraphOperation, temp_id_map: dict[str, str]
) -> None:
"""执行合并节点操作"""
params = op.parameters
params = self._resolve_parameters(op.parameters, temp_id_map)
source_node_ids = params.get("source_node_ids", [])
merged_content = params.get("merged_content")
@@ -698,9 +773,11 @@ class LongTermMemoryManager:
logger.info(f"✅ 合并节点: {sources} -> {target_id}")
async def _execute_create_edge(self, op: GraphOperation) -> None:
async def _execute_create_edge(
self, op: GraphOperation, temp_id_map: dict[str, str]
) -> None:
"""执行创建边操作"""
params = op.parameters
params = self._resolve_parameters(op.parameters, temp_id_map)
source_id = params.get("source_node_id")
target_id = params.get("target_node_id")
relation = params.get("relation", "related")
@@ -725,10 +802,12 @@ class LongTermMemoryManager:
else:
logger.error(f"创建边失败: {op}")
async def _execute_update_edge(self, op: GraphOperation) -> None:
async def _execute_update_edge(
self, op: GraphOperation, temp_id_map: dict[str, str]
) -> None:
"""执行更新边操作"""
edge_id = op.target_id
params = op.parameters
edge_id = self._resolve_id(op.target_id, temp_id_map)
params = self._resolve_parameters(op.parameters, temp_id_map)
updated_relation = params.get("updated_relation")
updated_importance = params.get("updated_importance")
@@ -747,9 +826,11 @@ class LongTermMemoryManager:
else:
logger.error(f"更新边失败: {edge_id}")
async def _execute_delete_edge(self, op: GraphOperation) -> None:
async def _execute_delete_edge(
self, op: GraphOperation, temp_id_map: dict[str, str]
) -> None:
"""执行删除边操作"""
edge_id = op.target_id
edge_id = self._resolve_id(op.target_id, temp_id_map)
if not edge_id:
logger.warning("删除边失败: 缺少 edge_id")

View File

@@ -163,6 +163,7 @@ async def initialize_unified_memory_manager():
long_term_batch_size=getattr(config, "long_term_batch_size", 10),
long_term_search_top_k=getattr(config, "long_term_search_top_k", 5),
long_term_decay_factor=getattr(config, "long_term_decay_factor", 0.95),
long_term_auto_transfer_interval=getattr(config, "long_term_auto_transfer_interval", 600),
# Smart retrieval configuration
judge_confidence_threshold=getattr(config, "judge_confidence_threshold", 0.7),
)

View File

@@ -10,6 +10,7 @@
"""
import asyncio
import time
from datetime import datetime
from pathlib import Path
from typing import Any
@@ -47,6 +48,7 @@ class UnifiedMemoryManager:
long_term_batch_size: int = 10,
long_term_search_top_k: int = 5,
long_term_decay_factor: float = 0.95,
long_term_auto_transfer_interval: int = 600,
# Smart retrieval configuration
judge_confidence_threshold: float = 0.7,
):
@@ -65,6 +67,7 @@ class UnifiedMemoryManager:
long_term_batch_size: number of short-term memories processed per batch
long_term_search_top_k: number of similar memories to retrieve
long_term_decay_factor: decay factor for long-term memories
long_term_auto_transfer_interval: auto-transfer interval (seconds)
judge_confidence_threshold: confidence threshold for the judge model
"""
self.data_dir = data_dir or Path("data/memory_graph/three_tier")
@@ -104,6 +107,9 @@ class UnifiedMemoryManager:
# State
self._initialized = False
self._auto_transfer_task: asyncio.Task | None = None
self._auto_transfer_interval = max(10.0, float(long_term_auto_transfer_interval))
self._max_transfer_delay = min(max(30.0, self._auto_transfer_interval), 300.0)
self._transfer_wakeup_event: asyncio.Event | None = None
logger.info("统一记忆管理器已创建")
@@ -428,6 +434,31 @@ class UnifiedMemoryManager:
task.add_done_callback(_callback)
def _trigger_transfer_wakeup(self) -> None:
"""通知自动转移任务立即检查缓存"""
if self._transfer_wakeup_event and not self._transfer_wakeup_event.is_set():
self._transfer_wakeup_event.set()
def _calculate_auto_sleep_interval(self) -> float:
"""根据短期内存压力计算自适应等待间隔"""
base_interval = self._auto_transfer_interval
if not getattr(self, "short_term_manager", None):
return base_interval
max_memories = max(1, getattr(self.short_term_manager, "max_memories", 1))
occupancy = len(self.short_term_manager.memories) / max_memories
if occupancy >= 0.9:
return max(5.0, base_interval * 0.1)
if occupancy >= 0.75:
return max(10.0, base_interval * 0.2)
if occupancy >= 0.5:
return max(15.0, base_interval * 0.4)
if occupancy >= 0.3:
return max(20.0, base_interval * 0.6)
return base_interval
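# Worked example: with the default base_interval of 600 s, the tiers above give
#   occupancy >= 0.90 -> max(5.0,  600 * 0.1) = 60 s
#   occupancy >= 0.75 -> max(10.0, 600 * 0.2) = 120 s
#   occupancy >= 0.50 -> max(15.0, 600 * 0.4) = 240 s
#   occupancy >= 0.30 -> max(20.0, 600 * 0.6) = 360 s
#   otherwise         -> 600 s
# so a fuller short-term store is polled up to 10x more often.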
async def _transfer_blocks_to_short_term(self, blocks: list[MemoryBlock]) -> None:
"""实际转换逻辑在后台执行"""
logger.info(f"正在后台处理 {len(blocks)} 个感知记忆块")
@@ -438,6 +469,7 @@ class UnifiedMemoryManager:
continue
await self.perceptual_manager.remove_block(block.id)
self._trigger_transfer_wakeup()
logger.info(f"✓ 记忆块 {block.id} 已被转移到短期记忆 {stm.id}")
except Exception as exc:
logger.error(f"后台转移失败,记忆块 {block.id}: {exc}", exc_info=True)
@@ -519,64 +551,92 @@ class UnifiedMemoryManager:
logger.warning("自动转移任务已在运行")
return
if self._transfer_wakeup_event is None:
self._transfer_wakeup_event = asyncio.Event()
else:
self._transfer_wakeup_event.clear()
self._auto_transfer_task = asyncio.create_task(self._auto_transfer_loop())
logger.info("自动转移任务已启动")
async def _auto_transfer_loop(self) -> None:
"""自动转移循环(批量缓存模式)"""
transfer_cache = [] # 缓存待转移的短期记忆
cache_size_threshold = self._config["long_term"]["batch_size"] # 使用配置的批量大小
transfer_cache: list[ShortTermMemory] = []
cached_ids: set[str] = set()
cache_size_threshold = max(1, self._config["long_term"].get("batch_size", 1))
last_transfer_time = time.monotonic()
while True:
try:
# Check once every 10 minutes
await asyncio.sleep(600)
sleep_interval = self._calculate_auto_sleep_interval()
if self._transfer_wakeup_event is not None:
try:
await asyncio.wait_for(
self._transfer_wakeup_event.wait(),
timeout=sleep_interval,
)
self._transfer_wakeup_event.clear()
except asyncio.TimeoutError:
pass
else:
await asyncio.sleep(sleep_interval)
# Check whether any short-term memories need transferring
memories_to_transfer = self.short_term_manager.get_memories_for_transfer()
if memories_to_transfer:
# Add to the cache
transfer_cache.extend(memories_to_transfer)
logger.info(
f"缓存待转移记忆: 新增{len(memories_to_transfer)}条, "
f"缓存总数{len(transfer_cache)}/{cache_size_threshold}"
)
# Check whether the batch threshold is reached or short-term memory is full
added = 0
for memory in memories_to_transfer:
mem_id = getattr(memory, "id", None)
if mem_id and mem_id in cached_ids:
continue
transfer_cache.append(memory)
if mem_id:
cached_ids.add(mem_id)
added += 1
if added:
logger.info(
f"自动转移缓存: 新增{added}条, 当前缓存{len(transfer_cache)}/{cache_size_threshold}"
)
max_memories = max(1, getattr(self.short_term_manager, 'max_memories', 1))
occupancy_ratio = len(self.short_term_manager.memories) / max_memories
time_since_last_transfer = time.monotonic() - last_transfer_time
should_transfer = (
len(transfer_cache) >= cache_size_threshold or
len(self.short_term_manager.memories) >= self.short_term_manager.max_memories
len(transfer_cache) >= cache_size_threshold
or occupancy_ratio >= 0.85
or (transfer_cache and time_since_last_transfer >= self._max_transfer_delay)
or len(self.short_term_manager.memories) >= self.short_term_manager.max_memories
)
if should_transfer and transfer_cache:
logger.info(f"触发批量转移: {len(transfer_cache)}条短期记忆→长期记忆")
# 执行批量转移
result = await self.long_term_manager.transfer_from_short_term(
transfer_cache
logger.info(
f"准备批量转移: {len(transfer_cache)}条短期记忆到长期记忆 (占用率 {occupancy_ratio:.0%})"
)
# Clear the transferred memories
result = await self.long_term_manager.transfer_from_short_term(list(transfer_cache))
if result.get("transferred_memory_ids"):
await self.short_term_manager.clear_transferred_memories(
result["transferred_memory_ids"]
)
# Remove transferred entries from the cache
transferred_ids = set(result["transferred_memory_ids"])
transfer_cache = [
m for m in transfer_cache
if m.id not in transferred_ids
m
for m in transfer_cache
if getattr(m, "id", None) not in transferred_ids
]
cached_ids.difference_update(transferred_ids)
last_transfer_time = time.monotonic()
logger.info(f"✅ 批量转移完成: {result}")
except asyncio.CancelledError:
logger.info("自动转移任务已取消")
logger.info("自动转移循环被取消")
break
except Exception as e:
logger.error(f"自动转移任务错误: {e}", exc_info=True)
# 继续运行
logger.error(f"自动转移循环异常: {e}", exc_info=True)
async def manual_transfer(self) -> dict[str, Any]:
"""