feat(memory): 增强记忆构建上下文处理能力并优化兴趣度批量更新机制
- 在记忆构建过程中允许检索历史记忆作为上下文补充 - 改进LLM响应解析逻辑,增强JSON提取兼容性 - 优化消息兴趣度计算和批量更新机制,减少数据库写入频率 - 添加构建状态管理,支持在BUILDING状态下进行记忆检索 - 修复stream_id拼写错误处理和历史消息获取逻辑
This commit is contained in:
@@ -12,12 +12,11 @@ from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
||||
from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||
from src.chat.memory_system.memory_builder import MemoryBuilder
|
||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
||||
@@ -64,7 +63,6 @@ class MemorySystemConfig:
|
||||
@classmethod
|
||||
def from_global_config(cls):
|
||||
"""从全局配置创建配置实例"""
|
||||
from src.config.config import global_config
|
||||
|
||||
return cls(
|
||||
# 记忆构建配置
|
||||
@@ -130,7 +128,7 @@ class EnhancedMemorySystem:
|
||||
task_config = (
|
||||
self.llm_model.model_for_task
|
||||
if self.llm_model is not None
|
||||
else model_config.model_task_config.utils
|
||||
else model_config.model_task_config.utils_small
|
||||
)
|
||||
|
||||
self.value_assessment_model = LLMRequest(
|
||||
@@ -173,6 +171,49 @@ class EnhancedMemorySystem:
|
||||
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def retrieve_memories_for_building(
|
||||
self,
|
||||
query_text: str,
|
||||
user_id: str,
|
||||
context: Optional[Dict[str, Any]] = None,
|
||||
limit: int = 5
|
||||
) -> List[MemoryChunk]:
|
||||
"""在构建记忆时检索相关记忆,允许在BUILDING状态下进行检索
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
user_id: 用户ID
|
||||
context: 上下文信息
|
||||
limit: 返回结果数量限制
|
||||
|
||||
Returns:
|
||||
相关记忆列表
|
||||
"""
|
||||
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
||||
logger.warning(f"记忆系统状态不允许检索: {self.status.value}")
|
||||
return []
|
||||
|
||||
try:
|
||||
# 临时切换到检索状态
|
||||
original_status = self.status
|
||||
self.status = MemorySystemStatus.RETRIEVING
|
||||
|
||||
# 执行检索
|
||||
memories = await self.vector_storage.search_similar_memories(
|
||||
query_text=query_text,
|
||||
user_id=user_id,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
# 恢复原始状态
|
||||
self.status = original_status
|
||||
|
||||
return memories
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"构建过程中检索记忆失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def build_memory_from_conversation(
|
||||
self,
|
||||
conversation_text: str,
|
||||
@@ -191,9 +232,10 @@ class EnhancedMemorySystem:
|
||||
Returns:
|
||||
构建的记忆块列表
|
||||
"""
|
||||
if self.status != MemorySystemStatus.READY:
|
||||
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
||||
raise RuntimeError("记忆系统未就绪")
|
||||
|
||||
original_status = self.status
|
||||
self.status = MemorySystemStatus.BUILDING
|
||||
start_time = time.time()
|
||||
|
||||
@@ -222,7 +264,7 @@ class EnhancedMemorySystem:
|
||||
build_marker_time = current_time
|
||||
self._last_memory_build_times[build_scope_key] = current_time
|
||||
|
||||
conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
|
||||
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
|
||||
|
||||
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
|
||||
|
||||
@@ -231,7 +273,7 @@ class EnhancedMemorySystem:
|
||||
|
||||
if value_score < self.config.memory_value_threshold:
|
||||
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.status = original_status
|
||||
return []
|
||||
|
||||
# 2. 构建记忆块
|
||||
@@ -244,7 +286,7 @@ class EnhancedMemorySystem:
|
||||
|
||||
if not memory_chunks:
|
||||
logger.debug("未提取到有效记忆块")
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.status = original_status
|
||||
return []
|
||||
|
||||
# 3. 记忆融合与去重
|
||||
@@ -262,7 +304,7 @@ class EnhancedMemorySystem:
|
||||
build_time = time.time() - start_time
|
||||
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒")
|
||||
|
||||
self.status = MemorySystemStatus.READY
|
||||
self.status = original_status
|
||||
return fused_chunks
|
||||
|
||||
except Exception as e:
|
||||
@@ -469,49 +511,79 @@ class EnhancedMemorySystem:
|
||||
|
||||
return context
|
||||
|
||||
def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
"""使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
|
||||
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
|
||||
if not context:
|
||||
return fallback_text
|
||||
|
||||
user_id = context.get("user_id")
|
||||
stream_id = context.get("stream_id") or context.get("stram_id")
|
||||
if not stream_id:
|
||||
return fallback_text
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
# 优先使用 stream_id 获取历史消息
|
||||
if stream_id:
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream or not hasattr(chat_stream, "context_manager"):
|
||||
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
|
||||
return fallback_text
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if chat_stream and hasattr(chat_stream, "context_manager"):
|
||||
history_limit = self._determine_history_limit(context)
|
||||
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
|
||||
if messages:
|
||||
transcript = self._format_history_messages(messages)
|
||||
if transcript:
|
||||
cleaned_fallback = (fallback_text or "").strip()
|
||||
if cleaned_fallback and cleaned_fallback not in transcript:
|
||||
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
|
||||
|
||||
history_limit = self._determine_history_limit(context)
|
||||
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
|
||||
if not messages:
|
||||
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
|
||||
return fallback_text
|
||||
logger.debug(
|
||||
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
|
||||
stream_id,
|
||||
len(messages),
|
||||
history_limit,
|
||||
)
|
||||
return transcript
|
||||
else:
|
||||
logger.debug(f"stream_id={stream_id} 历史消息格式化失败")
|
||||
else:
|
||||
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
|
||||
else:
|
||||
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
|
||||
|
||||
transcript = self._format_history_messages(messages)
|
||||
if not transcript:
|
||||
return fallback_text
|
||||
# 如果无法获取历史消息,尝试检索相关记忆作为上下文
|
||||
if user_id and fallback_text:
|
||||
try:
|
||||
relevant_memories = await self.retrieve_memories_for_building(
|
||||
query_text=fallback_text,
|
||||
user_id=user_id,
|
||||
context=context,
|
||||
limit=3
|
||||
)
|
||||
|
||||
cleaned_fallback = (fallback_text or "").strip()
|
||||
if cleaned_fallback and cleaned_fallback not in transcript:
|
||||
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
|
||||
if relevant_memories:
|
||||
memory_contexts = []
|
||||
for memory in relevant_memories:
|
||||
memory_contexts.append(f"[历史记忆] {memory.text_content}")
|
||||
|
||||
logger.debug(
|
||||
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
|
||||
stream_id,
|
||||
len(messages),
|
||||
history_limit,
|
||||
)
|
||||
return transcript
|
||||
memory_transcript = "\n".join(memory_contexts)
|
||||
cleaned_fallback = (fallback_text or "").strip()
|
||||
if cleaned_fallback and cleaned_fallback not in memory_transcript:
|
||||
memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
|
||||
return fallback_text
|
||||
logger.debug(
|
||||
"使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
|
||||
len(relevant_memories),
|
||||
user_id
|
||||
)
|
||||
return memory_transcript
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning(f"检索历史记忆作为上下文失败: {exc}", exc_info=True)
|
||||
|
||||
# 回退到传入文本
|
||||
return fallback_text
|
||||
|
||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||
"""确定用于节流控制的记忆构建作用域"""
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import re
|
||||
import time
|
||||
import orjson
|
||||
from typing import Dict, List, Optional, Tuple, Any, Set
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
@@ -16,7 +16,7 @@ from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.memory_system.memory_chunk import (
|
||||
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
|
||||
ContentStructure, MemoryMetadata, create_memory_chunk
|
||||
create_memory_chunk
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -334,6 +334,28 @@ class MemoryBuilder:
|
||||
|
||||
return prompt
|
||||
|
||||
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||
if not response:
|
||||
return None
|
||||
|
||||
stripped = response.strip()
|
||||
|
||||
# 优先处理Markdown代码块格式 ```json ... ```
|
||||
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
|
||||
if code_block_match:
|
||||
candidate = code_block_match.group(1).strip()
|
||||
if candidate:
|
||||
return candidate
|
||||
|
||||
# 回退到查找第一个 JSON 对象的大括号范围
|
||||
start = stripped.find("{")
|
||||
end = stripped.rfind("}")
|
||||
if start != -1 and end != -1 and end > start:
|
||||
return stripped[start:end + 1].strip()
|
||||
|
||||
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||
|
||||
def _parse_llm_response(
|
||||
self,
|
||||
response: str,
|
||||
@@ -345,7 +367,13 @@ class MemoryBuilder:
|
||||
memories = []
|
||||
|
||||
try:
|
||||
data = orjson.loads(response)
|
||||
# 提取JSON负载
|
||||
json_payload = self._extract_json_payload(response)
|
||||
if not json_payload:
|
||||
logger.error("未在响应中找到有效的JSON负载")
|
||||
return memories
|
||||
|
||||
data = orjson.loads(json_payload)
|
||||
memory_list = data.get("memories", [])
|
||||
|
||||
for mem_data in memory_list:
|
||||
@@ -375,7 +403,8 @@ class MemoryBuilder:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析LLM响应失败: {e}, 响应: {response}")
|
||||
preview = response[:200] if response else "空响应"
|
||||
logger.error(f"解析LLM响应失败: {e}, 响应片段: {preview}")
|
||||
|
||||
return memories
|
||||
|
||||
@@ -623,7 +652,7 @@ class MemoryBuilder:
|
||||
try:
|
||||
# 尝试解析回字典(如果原来是字典)
|
||||
memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
|
||||
except:
|
||||
except Exception:
|
||||
memory.content.object = obj_str
|
||||
|
||||
# 记录时间规范化操作
|
||||
|
||||
@@ -47,22 +47,20 @@ class SingleStreamContextManager:
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
skip_energy_update: 是否跳过能量更新
|
||||
skip_energy_update: 是否跳过能量更新(兼容参数,当前忽略)
|
||||
|
||||
Returns:
|
||||
bool: 是否成功添加
|
||||
"""
|
||||
try:
|
||||
self.context.add_message(message)
|
||||
interest_value = await self._calculate_message_interest(message)
|
||||
message.interest_value = interest_value
|
||||
# 推迟兴趣度计算到分发阶段
|
||||
message.interest_value = getattr(message, "interest_value", None)
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
if not skip_energy_update:
|
||||
await self._update_stream_energy()
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度待计算)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
@@ -80,8 +78,6 @@ class SingleStreamContextManager:
|
||||
"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
if "interest_value" in updates:
|
||||
await self._update_stream_energy()
|
||||
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -286,18 +282,16 @@ class SingleStreamContextManager:
|
||||
try:
|
||||
self.context.add_message(message)
|
||||
|
||||
interest_value = await self._calculate_message_interest_async(message)
|
||||
message.interest_value = interest_value
|
||||
# 推迟兴趣度计算到分发阶段
|
||||
message.interest_value = getattr(message, "interest_value", None)
|
||||
|
||||
self.total_messages += 1
|
||||
self.last_access_time = time.time()
|
||||
|
||||
if not skip_energy_update:
|
||||
await self._update_stream_energy()
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
# 启动流的循环任务(如果还未启动)
|
||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||
|
||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
|
||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
@@ -307,8 +301,6 @@ class SingleStreamContextManager:
|
||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
if "interest_value" in updates:
|
||||
await self._update_stream_energy()
|
||||
|
||||
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
|
||||
return True
|
||||
@@ -339,27 +331,31 @@ class SingleStreamContextManager:
|
||||
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def _update_stream_energy(self):
|
||||
async def refresh_focus_energy_from_history(self) -> None:
|
||||
"""基于历史消息刷新聚焦能量"""
|
||||
await self._update_stream_energy(include_unread=False)
|
||||
|
||||
async def _update_stream_energy(self, include_unread: bool = False) -> None:
|
||||
"""更新流能量"""
|
||||
try:
|
||||
# 获取所有消息
|
||||
all_messages = self.get_messages(self.max_context_size)
|
||||
unread_messages = self.get_unread_messages()
|
||||
combined_messages = all_messages + unread_messages
|
||||
history_messages = self.context.get_history_messages(limit=self.max_context_size)
|
||||
messages: List[DatabaseMessages] = list(history_messages)
|
||||
|
||||
# 获取用户ID
|
||||
if include_unread:
|
||||
messages.extend(self.get_unread_messages())
|
||||
|
||||
# 获取用户ID(优先使用最新历史消息)
|
||||
user_id = None
|
||||
if combined_messages:
|
||||
last_message = combined_messages[-1]
|
||||
user_id = last_message.user_info.user_id
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
if hasattr(last_message, "user_info") and last_message.user_info:
|
||||
user_id = last_message.user_info.user_id
|
||||
|
||||
# 计算能量
|
||||
energy = await energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id, messages=combined_messages, user_id=user_id
|
||||
await energy_manager.calculate_focus_energy(
|
||||
stream_id=self.stream_id,
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# 更新流循环管理器
|
||||
# 注意:能量更新会通过energy_manager自动同步到流循环管理器
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")
|
||||
|
||||
@@ -246,6 +246,7 @@ class StreamLoopManager:
|
||||
success = results.get("success", False)
|
||||
|
||||
if success:
|
||||
await self._refresh_focus_energy(stream_id)
|
||||
process_time = time.time() - start_time
|
||||
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
|
||||
else:
|
||||
@@ -339,6 +340,20 @@ class StreamLoopManager:
|
||||
"max_concurrent_streams": self.max_concurrent_streams,
|
||||
}
|
||||
|
||||
async def _refresh_focus_energy(self, stream_id: str) -> None:
|
||||
"""分发完成后基于历史消息刷新能量值"""
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
|
||||
return
|
||||
|
||||
await chat_stream.context_manager.refresh_focus_energy_from_history()
|
||||
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
|
||||
except Exception as e:
|
||||
logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}")
|
||||
|
||||
|
||||
# 全局流循环管理器实例
|
||||
stream_loop_manager = StreamLoopManager()
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING, List
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
@@ -125,6 +125,44 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
|
||||
"""批量更新消息信息,降低更新频率"""
|
||||
if not updates:
|
||||
return 0
|
||||
|
||||
try:
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = chat_manager.get_stream(stream_id)
|
||||
if not chat_stream:
|
||||
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
|
||||
return 0
|
||||
|
||||
updated_count = 0
|
||||
for item in updates:
|
||||
message_id = item.get("message_id")
|
||||
if not message_id:
|
||||
continue
|
||||
|
||||
payload = {
|
||||
key: value
|
||||
for key, value in item.items()
|
||||
if key != "message_id" and value is not None
|
||||
}
|
||||
|
||||
if not payload:
|
||||
continue
|
||||
|
||||
success = await chat_stream.context_manager.update_message(message_id, payload)
|
||||
if success:
|
||||
updated_count += 1
|
||||
|
||||
if updated_count:
|
||||
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
|
||||
return updated_count
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
|
||||
return 0
|
||||
|
||||
async def add_action(self, stream_id: str, message_id: str, action: str):
|
||||
"""添加动作到消息"""
|
||||
try:
|
||||
|
||||
@@ -219,7 +219,11 @@ class MessageStorage:
|
||||
return match.group(0)
|
||||
|
||||
@staticmethod
|
||||
async def update_message_interest_value(message_id: str, interest_value: float) -> None:
|
||||
async def update_message_interest_value(
|
||||
message_id: str,
|
||||
interest_value: float,
|
||||
should_reply: bool | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
更新数据库中消息的interest_value字段
|
||||
|
||||
@@ -230,7 +234,11 @@ class MessageStorage:
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
# 更新消息的interest_value字段
|
||||
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
|
||||
values = {"interest_value": interest_value}
|
||||
if should_reply is not None:
|
||||
values["should_reply"] = should_reply
|
||||
|
||||
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
@@ -243,6 +251,31 @@ class MessageStorage:
|
||||
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def bulk_update_interest_values(
|
||||
interest_map: dict[str, float],
|
||||
reply_map: dict[str, bool] | None = None,
|
||||
) -> None:
|
||||
"""批量更新消息的兴趣度与回复标记"""
|
||||
if not interest_map:
|
||||
return
|
||||
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
for message_id, interest_value in interest_map.items():
|
||||
values = {"interest_value": interest_value}
|
||||
if reply_map and message_id in reply_map:
|
||||
values["should_reply"] = reply_map[message_id]
|
||||
|
||||
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
|
||||
await session.execute(stmt)
|
||||
|
||||
await session.commit()
|
||||
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新消息兴趣度失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
||||
"""
|
||||
|
||||
@@ -728,8 +728,6 @@ async def initialize_database():
|
||||
"autocommit": config.mysql_autocommit,
|
||||
"charset": config.mysql_charset,
|
||||
"connect_timeout": config.connection_timeout,
|
||||
"read_timeout": 30,
|
||||
"write_timeout": 30,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
from dataclasses import asdict
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||
@@ -104,45 +104,33 @@ class ChatterActionPlanner:
|
||||
score = 0.0
|
||||
should_reply = False
|
||||
reply_not_available = False
|
||||
interest_updates: List[Dict[str, Any]] = []
|
||||
|
||||
if unread_messages:
|
||||
# 为每条消息计算兴趣度
|
||||
# 为每条消息计算兴趣度,并延迟提交数据库更新
|
||||
for message in unread_messages:
|
||||
try:
|
||||
# 使用插件内部的兴趣度评分系统计算
|
||||
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||
message=message,
|
||||
bot_nickname=global_config.bot.nickname
|
||||
bot_nickname=global_config.bot.nickname,
|
||||
)
|
||||
message_interest = interest_score.total_score
|
||||
|
||||
# 更新消息的兴趣度
|
||||
message.interest_value = message_interest
|
||||
|
||||
# 简单的回复决策逻辑:兴趣度超过阈值则回复
|
||||
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
|
||||
logger.info(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
|
||||
interest_updates.append(
|
||||
{
|
||||
"message_id": message.message_id,
|
||||
"interest_value": message_interest,
|
||||
"should_reply": message.should_reply,
|
||||
}
|
||||
)
|
||||
|
||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
||||
if context:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
await message_manager.update_message(
|
||||
stream_id=self.chat_id,
|
||||
message_id=message.message_id,
|
||||
interest_value=message_interest,
|
||||
should_reply=message.should_reply
|
||||
)
|
||||
logger.debug(
|
||||
f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
|
||||
)
|
||||
|
||||
# 更新数据库中的消息记录
|
||||
try:
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
await MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
||||
logger.info(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
||||
except Exception as e:
|
||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
||||
|
||||
# 记录最高分
|
||||
if message_interest > score:
|
||||
score = message_interest
|
||||
if message.should_reply:
|
||||
@@ -152,9 +140,18 @@ class ChatterActionPlanner:
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
|
||||
# 设置默认值
|
||||
message.interest_value = 0.0
|
||||
message.should_reply = False
|
||||
interest_updates.append(
|
||||
{
|
||||
"message_id": message.message_id,
|
||||
"interest_value": 0.0,
|
||||
"should_reply": False,
|
||||
}
|
||||
)
|
||||
|
||||
if interest_updates:
|
||||
await self._commit_interest_updates(interest_updates)
|
||||
|
||||
# 检查兴趣度是否达到非回复动作阈值
|
||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||
@@ -194,6 +191,33 @@ class ChatterActionPlanner:
|
||||
self.planner_stats["failed_plans"] += 1
|
||||
return [], None
|
||||
|
||||
async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None:
|
||||
"""统一更新消息兴趣度,减少数据库写入次数"""
|
||||
if not updates:
|
||||
return
|
||||
|
||||
try:
|
||||
from src.chat.message_manager.message_manager import message_manager
|
||||
|
||||
await message_manager.bulk_update_messages(self.chat_id, updates)
|
||||
except Exception as e:
|
||||
logger.warning(f"批量更新上下文消息兴趣度失败: {e}")
|
||||
|
||||
try:
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
interest_map = {item["message_id"]: item["interest_value"] for item in updates if "interest_value" in item}
|
||||
reply_map = {item["message_id"]: item["should_reply"] for item in updates if "should_reply" in item}
|
||||
|
||||
await MessageStorage.bulk_update_interest_values(
|
||||
interest_map=interest_map,
|
||||
reply_map=reply_map if reply_map else None,
|
||||
)
|
||||
|
||||
logger.debug(f"已批量更新 {len(interest_map)} 条消息的兴趣度")
|
||||
except Exception as e:
|
||||
logger.warning(f"批量更新数据库兴趣度失败: {e}")
|
||||
|
||||
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||
"""根据执行结果更新规划器统计"""
|
||||
if not execution_result:
|
||||
|
||||
Reference in New Issue
Block a user