feat(memory): enhance memory-building context handling and batch interest-value updates

- Allow retrieving historical memories as supplementary context while building new memories
- Improve LLM response parsing with more tolerant JSON extraction
- Optimize message interest scoring and batch the updates to reduce database write frequency
- Add build-status management so memory retrieval is allowed in the BUILDING state
- Fix handling of the stream_id spelling error (stram_id) and the history-message fetching logic
Windpicker-owo
2025-09-30 14:22:26 +08:00
parent 1ccf50f3c7
commit 0a3c908654
8 changed files with 317 additions and 112 deletions

View File

@@ -12,12 +12,11 @@ from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
 from datetime import datetime, timedelta
 from dataclasses import dataclass, asdict
 from enum import Enum
-import numpy as np
 from src.common.logger import get_logger
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import model_config, global_config
-from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
+from src.chat.memory_system.memory_chunk import MemoryChunk
 from src.chat.memory_system.memory_builder import MemoryBuilder
 from src.chat.memory_system.memory_fusion import MemoryFusionEngine
 from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
@@ -64,7 +63,6 @@ class MemorySystemConfig:
     @classmethod
     def from_global_config(cls):
         """从全局配置创建配置实例"""
-        from src.config.config import global_config
         return cls(
             # 记忆构建配置
@@ -130,7 +128,7 @@ class EnhancedMemorySystem:
         task_config = (
             self.llm_model.model_for_task
             if self.llm_model is not None
-            else model_config.model_task_config.utils
+            else model_config.model_task_config.utils_small
         )
         self.value_assessment_model = LLMRequest(
@@ -173,6 +171,49 @@ class EnhancedMemorySystem:
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True) logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
raise raise
async def retrieve_memories_for_building(
self,
query_text: str,
user_id: str,
context: Optional[Dict[str, Any]] = None,
limit: int = 5
) -> List[MemoryChunk]:
"""在构建记忆时检索相关记忆允许在BUILDING状态下进行检索
Args:
query_text: 查询文本
user_id: 用户ID
context: 上下文信息
limit: 返回结果数量限制
Returns:
相关记忆列表
"""
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
logger.warning(f"记忆系统状态不允许检索: {self.status.value}")
return []
try:
# 临时切换到检索状态
original_status = self.status
self.status = MemorySystemStatus.RETRIEVING
# 执行检索
memories = await self.vector_storage.search_similar_memories(
query_text=query_text,
user_id=user_id,
limit=limit
)
# 恢复原始状态
self.status = original_status
return memories
except Exception as e:
logger.error(f"构建过程中检索记忆失败: {e}", exc_info=True)
return []
async def build_memory_from_conversation( async def build_memory_from_conversation(
self, self,
conversation_text: str, conversation_text: str,
@@ -191,9 +232,10 @@ class EnhancedMemorySystem:
         Returns:
             构建的记忆块列表
         """
-        if self.status != MemorySystemStatus.READY:
+        if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
             raise RuntimeError("记忆系统未就绪")

+        original_status = self.status
         self.status = MemorySystemStatus.BUILDING
         start_time = time.time()
@@ -222,7 +264,7 @@ class EnhancedMemorySystem:
             build_marker_time = current_time
             self._last_memory_build_times[build_scope_key] = current_time

-        conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
+        conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)

         logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
@@ -231,7 +273,7 @@ class EnhancedMemorySystem:
             if value_score < self.config.memory_value_threshold:
                 logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
-                self.status = MemorySystemStatus.READY
+                self.status = original_status
                 return []

             # 2. 构建记忆块
@@ -244,7 +286,7 @@ class EnhancedMemorySystem:
             if not memory_chunks:
                 logger.debug("未提取到有效记忆块")
-                self.status = MemorySystemStatus.READY
+                self.status = original_status
                 return []

             # 3. 记忆融合与去重
@@ -262,7 +304,7 @@ class EnhancedMemorySystem:
             build_time = time.time() - start_time
             logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}")

-            self.status = MemorySystemStatus.READY
+            self.status = original_status
             return fused_chunks

         except Exception as e:
@@ -469,49 +511,79 @@ class EnhancedMemorySystem:
         return context

-    def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
-        """使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
+    async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
+        """使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
         if not context:
             return fallback_text

+        user_id = context.get("user_id")
         stream_id = context.get("stream_id") or context.get("stram_id")
-        if not stream_id:
-            return fallback_text

-        try:
-            from src.chat.message_receive.chat_stream import get_chat_manager
-
-            chat_manager = get_chat_manager()
-            chat_stream = chat_manager.get_stream(stream_id)
-            if not chat_stream or not hasattr(chat_stream, "context_manager"):
-                logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
-                return fallback_text
-
-            history_limit = self._determine_history_limit(context)
-            messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
-            if not messages:
-                logger.debug(f"stream_id={stream_id} 未获取到历史消息")
-                return fallback_text
-
-            transcript = self._format_history_messages(messages)
-            if not transcript:
-                return fallback_text
-
-            cleaned_fallback = (fallback_text or "").strip()
-            if cleaned_fallback and cleaned_fallback not in transcript:
-                transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
-
-            logger.debug(
-                "使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
-                stream_id,
-                len(messages),
-                history_limit,
-            )
-            return transcript
-        except Exception as exc:
-            logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
-            return fallback_text
+        # 优先使用 stream_id 获取历史消息
+        if stream_id:
+            try:
+                from src.chat.message_receive.chat_stream import get_chat_manager
+
+                chat_manager = get_chat_manager()
+                chat_stream = chat_manager.get_stream(stream_id)
+                if chat_stream and hasattr(chat_stream, "context_manager"):
+                    history_limit = self._determine_history_limit(context)
+                    messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
+                    if messages:
+                        transcript = self._format_history_messages(messages)
+                        if transcript:
+                            cleaned_fallback = (fallback_text or "").strip()
+                            if cleaned_fallback and cleaned_fallback not in transcript:
+                                transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
+
+                            logger.debug(
+                                "使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
+                                stream_id,
+                                len(messages),
+                                history_limit,
+                            )
+                            return transcript
+                        else:
+                            logger.debug(f"stream_id={stream_id} 历史消息格式化失败")
+                    else:
+                        logger.debug(f"stream_id={stream_id} 未获取到历史消息")
+                else:
+                    logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
+            except Exception as exc:
+                logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
+
+        # 如果无法获取历史消息,尝试检索相关记忆作为上下文
+        if user_id and fallback_text:
+            try:
+                relevant_memories = await self.retrieve_memories_for_building(
+                    query_text=fallback_text,
+                    user_id=user_id,
+                    context=context,
+                    limit=3
+                )
+
+                if relevant_memories:
+                    memory_contexts = []
+                    for memory in relevant_memories:
+                        memory_contexts.append(f"[历史记忆] {memory.text_content}")
+
+                    memory_transcript = "\n".join(memory_contexts)
+                    cleaned_fallback = (fallback_text or "").strip()
+                    if cleaned_fallback and cleaned_fallback not in memory_transcript:
+                        memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"
+
+                    logger.debug(
+                        "使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
+                        len(relevant_memories),
+                        user_id
+                    )
+                    return memory_transcript
+            except Exception as exc:
+                logger.warning(f"检索历史记忆作为上下文失败: {exc}", exc_info=True)
+
+        # 回退到传入文本
+        return fallback_text

     def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
         """确定用于节流控制的记忆构建作用域"""

View File

@@ -7,7 +7,7 @@
 import re
 import time
 import orjson
-from typing import Dict, List, Optional, Tuple, Any, Set
+from typing import Dict, List, Optional, Any
 from datetime import datetime
 from dataclasses import dataclass
 from enum import Enum
@@ -16,7 +16,7 @@ from src.common.logger import get_logger
 from src.llm_models.utils_model import LLMRequest
 from src.chat.memory_system.memory_chunk import (
     MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
-    ContentStructure, MemoryMetadata, create_memory_chunk
+    create_memory_chunk
 )

 logger = get_logger(__name__)
@@ -334,6 +334,28 @@ class MemoryBuilder:
         return prompt

+    def _extract_json_payload(self, response: str) -> Optional[str]:
+        """从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
+        if not response:
+            return None
+
+        stripped = response.strip()
+
+        # 优先处理Markdown代码块格式 ```json ... ```
+        code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
+        if code_block_match:
+            candidate = code_block_match.group(1).strip()
+            if candidate:
+                return candidate
+
+        # 回退到查找第一个 JSON 对象的大括号范围
+        start = stripped.find("{")
+        end = stripped.rfind("}")
+        if start != -1 and end != -1 and end > start:
+            return stripped[start:end + 1].strip()
+
+        return stripped if stripped.startswith("{") and stripped.endswith("}") else None
+
     def _parse_llm_response(
         self,
         response: str,
@@ -345,7 +367,13 @@ class MemoryBuilder:
         memories = []

         try:
-            data = orjson.loads(response)
+            # 提取JSON负载
+            json_payload = self._extract_json_payload(response)
+            if not json_payload:
+                logger.error("未在响应中找到有效的JSON负载")
+                return memories
+
+            data = orjson.loads(json_payload)
             memory_list = data.get("memories", [])

             for mem_data in memory_list:
@@ -375,7 +403,8 @@ class MemoryBuilder:
                     continue

         except Exception as e:
-            logger.error(f"解析LLM响应失败: {e}, 响应: {response}")
+            preview = response[:200] if response else "空响应"
+            logger.error(f"解析LLM响应失败: {e}, 响应片段: {preview}")

         return memories
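For reference, a standalone sketch of how this kind of tolerant JSON extraction behaves on a fenced model response. The helper mirrors the one added above; the sample response string is invented, and json stands in for orjson so the sketch has no extra dependency:

import re
import json  # the real code uses orjson; json keeps this sketch dependency-free

def extract_json_payload(response: str) -> str | None:
    stripped = response.strip()
    block = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
    if block and block.group(1).strip():
        return block.group(1).strip()       # fenced ```json ... ``` payload
    start, end = stripped.find("{"), stripped.rfind("}")
    if start != -1 and end > start:
        return stripped[start:end + 1].strip()  # first-to-last brace fallback
    return None

raw = '模型说明文字。\n```json\n{"memories": [{"text": "示例记忆"}]}\n```'
print(json.loads(extract_json_payload(raw))["memories"][0]["text"])  # -> 示例记忆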
@@ -623,7 +652,7 @@ class MemoryBuilder:
             try:
                 # 尝试解析回字典(如果原来是字典)
                 memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
-            except:
+            except Exception:
                 memory.content.object = obj_str

             # 记录时间规范化操作

View File

@@ -47,22 +47,20 @@ class SingleStreamContextManager:
         Args:
             message: 消息对象
-            skip_energy_update: 是否跳过能量更新
+            skip_energy_update: 是否跳过能量更新(兼容参数,当前忽略)

         Returns:
             bool: 是否成功添加
         """
         try:
             self.context.add_message(message)

-            interest_value = await self._calculate_message_interest(message)
-            message.interest_value = interest_value
+            # 推迟兴趣度计算到分发阶段
+            message.interest_value = getattr(message, "interest_value", None)

             self.total_messages += 1
             self.last_access_time = time.time()

-            if not skip_energy_update:
-                await self._update_stream_energy()
-
             # 启动流的循环任务(如果还未启动)
             await stream_loop_manager.start_stream_loop(self.stream_id)

-            logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
+            logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度待计算)")
             return True
         except Exception as e:
             logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
@@ -80,8 +78,6 @@ class SingleStreamContextManager:
""" """
try: try:
self.context.update_message_info(message_id, **updates) self.context.update_message_info(message_id, **updates)
if "interest_value" in updates:
await self._update_stream_energy()
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}") logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
return True return True
except Exception as e: except Exception as e:
@@ -286,18 +282,16 @@ class SingleStreamContextManager:
         try:
             self.context.add_message(message)

-            interest_value = await self._calculate_message_interest_async(message)
-            message.interest_value = interest_value
+            # 推迟兴趣度计算到分发阶段
+            message.interest_value = getattr(message, "interest_value", None)

             self.total_messages += 1
             self.last_access_time = time.time()

-            if not skip_energy_update:
-                await self._update_stream_energy()
-
             # 启动流的循环任务(如果还未启动)
             await stream_loop_manager.start_stream_loop(self.stream_id)

-            logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
+            logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
             return True
         except Exception as e:
             logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
@@ -307,8 +301,6 @@ class SingleStreamContextManager:
"""异步实现的 update_message更新消息并在需要时 await 能量更新。""" """异步实现的 update_message更新消息并在需要时 await 能量更新。"""
try: try:
self.context.update_message_info(message_id, **updates) self.context.update_message_info(message_id, **updates)
if "interest_value" in updates:
await self._update_stream_energy()
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}") logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
return True return True
@@ -339,27 +331,31 @@ class SingleStreamContextManager:
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True) logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
return False return False
async def _update_stream_energy(self): async def refresh_focus_energy_from_history(self) -> None:
"""基于历史消息刷新聚焦能量"""
await self._update_stream_energy(include_unread=False)
async def _update_stream_energy(self, include_unread: bool = False) -> None:
"""更新流能量""" """更新流能量"""
try: try:
# 获取所有消息 history_messages = self.context.get_history_messages(limit=self.max_context_size)
all_messages = self.get_messages(self.max_context_size) messages: List[DatabaseMessages] = list(history_messages)
unread_messages = self.get_unread_messages()
combined_messages = all_messages + unread_messages
# 获取用户ID if include_unread:
messages.extend(self.get_unread_messages())
# 获取用户ID优先使用最新历史消息
user_id = None user_id = None
if combined_messages: if messages:
last_message = combined_messages[-1] last_message = messages[-1]
user_id = last_message.user_info.user_id if hasattr(last_message, "user_info") and last_message.user_info:
user_id = last_message.user_info.user_id
# 计算能量 await energy_manager.calculate_focus_energy(
energy = await energy_manager.calculate_focus_energy( stream_id=self.stream_id,
stream_id=self.stream_id, messages=combined_messages, user_id=user_id messages=messages,
user_id=user_id,
) )
# 更新流循环管理器
# 注意能量更新会通过energy_manager自动同步到流循环管理器
except Exception as e: except Exception as e:
logger.error(f"更新单流能量失败 {self.stream_id}: {e}") logger.error(f"更新单流能量失败 {self.stream_id}: {e}")

View File

@@ -246,6 +246,7 @@ class StreamLoopManager:
             success = results.get("success", False)

             if success:
+                await self._refresh_focus_energy(stream_id)
                 process_time = time.time() - start_time
                 logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
             else:
@@ -339,6 +340,20 @@ class StreamLoopManager:
"max_concurrent_streams": self.max_concurrent_streams, "max_concurrent_streams": self.max_concurrent_streams,
} }
async def _refresh_focus_energy(self, stream_id: str) -> None:
"""分发完成后基于历史消息刷新能量值"""
try:
chat_manager = get_chat_manager()
chat_stream = chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return
await chat_stream.context_manager.refresh_focus_energy_from_history()
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
except Exception as e:
logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}")
# 全局流循环管理器实例 # 全局流循环管理器实例
stream_loop_manager = StreamLoopManager() stream_loop_manager = StreamLoopManager()
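A toy sketch of the "refresh once per dispatch cycle" hook introduced here: energy is recomputed after a successful processing pass rather than on every add_message or update_message. All names and values below are invented:

import asyncio

async def refresh_focus_energy(stream_id: str) -> None:
    # stand-in for context_manager.refresh_focus_energy_from_history()
    print(f"recomputed focus energy for {stream_id} from history")

async def handle_stream(stream_id: str) -> None:
    success = True  # stand-in for the dispatch result
    if success:
        # one refresh per processing cycle instead of one per message
        await refresh_focus_energy(stream_id)

asyncio.run(handle_stream("stream_42"))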

View File

@@ -6,7 +6,7 @@
 import asyncio
 import random
 import time
-from typing import Dict, Optional, Any, TYPE_CHECKING
+from typing import Dict, Optional, Any, TYPE_CHECKING, List

 from src.common.logger import get_logger
 from src.common.data_models.database_data_model import DatabaseMessages
@@ -125,6 +125,44 @@ class MessageManager:
         except Exception as e:
             logger.error(f"更新消息 {message_id} 时发生错误: {e}")

+    async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
+        """批量更新消息信息,降低更新频率"""
+        if not updates:
+            return 0
+
+        try:
+            chat_manager = get_chat_manager()
+            chat_stream = chat_manager.get_stream(stream_id)
+            if not chat_stream:
+                logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
+                return 0
+
+            updated_count = 0
+            for item in updates:
+                message_id = item.get("message_id")
+                if not message_id:
+                    continue
+
+                payload = {
+                    key: value
+                    for key, value in item.items()
+                    if key != "message_id" and value is not None
+                }
+                if not payload:
+                    continue
+
+                success = await chat_stream.context_manager.update_message(message_id, payload)
+                if success:
+                    updated_count += 1
+
+            if updated_count:
+                logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
+            return updated_count
+        except Exception as e:
+            logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
+            return 0
+
     async def add_action(self, stream_id: str, message_id: str, action: str):
         """添加动作到消息"""
         try:
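A minimal illustration of the payload filtering bulk_update_messages performs on each item: the message_id key and any None-valued fields are dropped before the per-message update is applied. The sample item is invented:

item = {"message_id": "msg_001", "interest_value": 0.82, "should_reply": None}
payload = {k: v for k, v in item.items() if k != "message_id" and v is not None}
print(payload)  # -> {'interest_value': 0.82}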

View File

@@ -212,7 +212,11 @@ class MessageStorage:
         return match.group(0)

     @staticmethod
-    async def update_message_interest_value(message_id: str, interest_value: float) -> None:
+    async def update_message_interest_value(
+        message_id: str,
+        interest_value: float,
+        should_reply: bool | None = None,
+    ) -> None:
         """
         更新数据库中消息的interest_value字段
@@ -223,7 +227,11 @@ class MessageStorage:
         try:
             async with get_db_session() as session:
                 # 更新消息的interest_value字段
-                stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
+                values = {"interest_value": interest_value}
+                if should_reply is not None:
+                    values["should_reply"] = should_reply
+
+                stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
                 result = await session.execute(stmt)
                 await session.commit()
@@ -236,6 +244,31 @@ class MessageStorage:
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}") logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
raise raise
@staticmethod
async def bulk_update_interest_values(
interest_map: dict[str, float],
reply_map: dict[str, bool] | None = None,
) -> None:
"""批量更新消息的兴趣度与回复标记"""
if not interest_map:
return
try:
async with get_db_session() as session:
for message_id, interest_value in interest_map.items():
values = {"interest_value": interest_value}
if reply_map and message_id in reply_map:
values["should_reply"] = reply_map[message_id]
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
await session.execute(stmt)
await session.commit()
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
except Exception as e:
logger.error(f"批量更新消息兴趣度失败: {e}")
raise
@staticmethod @staticmethod
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int: async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
""" """

View File

@@ -728,8 +728,6 @@ async def initialize_database():
"autocommit": config.mysql_autocommit, "autocommit": config.mysql_autocommit,
"charset": config.mysql_charset, "charset": config.mysql_charset,
"connect_timeout": config.connection_timeout, "connect_timeout": config.connection_timeout,
"read_timeout": 30,
"write_timeout": 30,
}, },
} }
) )

View File

@@ -5,7 +5,7 @@
 from dataclasses import asdict
 import time
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

 from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
 from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
@@ -104,45 +104,33 @@ class ChatterActionPlanner:
         score = 0.0
         should_reply = False
         reply_not_available = False
+        interest_updates: List[Dict[str, Any]] = []

         if unread_messages:
-            # 为每条消息计算兴趣度
+            # 为每条消息计算兴趣度,并延迟提交数据库更新
             for message in unread_messages:
                 try:
-                    # 使用插件内部的兴趣度评分系统计算
                     interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
                         message=message,
-                        bot_nickname=global_config.bot.nickname
+                        bot_nickname=global_config.bot.nickname,
                     )
                     message_interest = interest_score.total_score

-                    # 更新消息的兴趣度
                     message.interest_value = message_interest
-                    # 简单的回复决策逻辑:兴趣度超过阈值则回复
                     message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold

-                    logger.info(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
-
-                    # 更新StreamContext中的消息信息并刷新focus_energy
-                    if context:
-                        from src.chat.message_manager.message_manager import message_manager
-                        await message_manager.update_message(
-                            stream_id=self.chat_id,
-                            message_id=message.message_id,
-                            interest_value=message_interest,
-                            should_reply=message.should_reply
-                        )
-
-                    # 更新数据库中的消息记录
-                    try:
-                        from src.chat.message_receive.storage import MessageStorage
-                        await MessageStorage.update_message_interest_value(message.message_id, message_interest)
-                        logger.info(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
-                    except Exception as e:
-                        logger.warning(f"更新数据库消息兴趣度失败: {e}")
-
-                    # 记录最高分
+                    interest_updates.append(
+                        {
+                            "message_id": message.message_id,
+                            "interest_value": message_interest,
+                            "should_reply": message.should_reply,
+                        }
+                    )
+
+                    logger.debug(
+                        f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
+                    )
+
                     if message_interest > score:
                         score = message_interest
                     if message.should_reply:
@@ -152,9 +140,18 @@ class ChatterActionPlanner:
                 except Exception as e:
                     logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
-                    # 设置默认值
                     message.interest_value = 0.0
                     message.should_reply = False
+                    interest_updates.append(
+                        {
+                            "message_id": message.message_id,
+                            "interest_value": 0.0,
+                            "should_reply": False,
+                        }
+                    )
+
+        if interest_updates:
+            await self._commit_interest_updates(interest_updates)

         # 检查兴趣度是否达到非回复动作阈值
         non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
@@ -194,6 +191,33 @@ class ChatterActionPlanner:
             self.planner_stats["failed_plans"] += 1
             return [], None

+    async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None:
+        """统一更新消息兴趣度,减少数据库写入次数"""
+        if not updates:
+            return
+
+        try:
+            from src.chat.message_manager.message_manager import message_manager
+
+            await message_manager.bulk_update_messages(self.chat_id, updates)
+        except Exception as e:
+            logger.warning(f"批量更新上下文消息兴趣度失败: {e}")
+
+        try:
+            from src.chat.message_receive.storage import MessageStorage
+
+            interest_map = {item["message_id"]: item["interest_value"] for item in updates if "interest_value" in item}
+            reply_map = {item["message_id"]: item["should_reply"] for item in updates if "should_reply" in item}
+            await MessageStorage.bulk_update_interest_values(
+                interest_map=interest_map,
+                reply_map=reply_map if reply_map else None,
+            )
+            logger.debug(f"已批量更新 {len(interest_map)} 条消息的兴趣度")
+        except Exception as e:
+            logger.warning(f"批量更新数据库兴趣度失败: {e}")
+
     def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
         """根据执行结果更新规划器统计"""
         if not execution_result:
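Taken together with the storage changes, the planner now follows a "score in memory, write once" pattern: per-message results are collected in a list and collapsed into the two maps the storage layer expects, instead of issuing one database write per message. A self-contained sketch of that collection step, with invented message ids, scores, and threshold:

THRESHOLD = 0.4  # stand-in for non_reply_action_interest_threshold
scores = {"msg_001": 0.82, "msg_002": 0.15, "msg_003": 0.55}

interest_updates = [
    {"message_id": mid, "interest_value": s, "should_reply": s > THRESHOLD}
    for mid, s in scores.items()
]

# collapse the per-message results into the two maps used by the batched update
interest_map = {u["message_id"]: u["interest_value"] for u in interest_updates}
reply_map = {u["message_id"]: u["should_reply"] for u in interest_updates}
print(interest_map)  # one batched write instead of len(scores) separate commits
print(reply_map)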