feat(memory): 增强记忆构建上下文处理能力并优化兴趣度批量更新机制
- 在记忆构建过程中允许检索历史记忆作为上下文补充 - 改进LLM响应解析逻辑,增强JSON提取兼容性 - 优化消息兴趣度计算和批量更新机制,减少数据库写入频率 - 添加构建状态管理,支持在BUILDING状态下进行记忆检索 - 修复stream_id拼写错误处理和历史消息获取逻辑
This commit is contained in:
@@ -12,12 +12,11 @@ from typing import Dict, List, Optional, Set, Any, TYPE_CHECKING
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import dataclass, asdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config, global_config
|
||||||
from src.chat.memory_system.memory_chunk import MemoryChunk, MemoryType
|
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||||
from src.chat.memory_system.memory_builder import MemoryBuilder
|
from src.chat.memory_system.memory_builder import MemoryBuilder
|
||||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||||
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
from src.chat.memory_system.vector_storage import VectorStorageManager, VectorStorageConfig
|
||||||
@@ -64,7 +63,6 @@ class MemorySystemConfig:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_global_config(cls):
|
def from_global_config(cls):
|
||||||
"""从全局配置创建配置实例"""
|
"""从全局配置创建配置实例"""
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
# 记忆构建配置
|
# 记忆构建配置
|
||||||
@@ -130,7 +128,7 @@ class EnhancedMemorySystem:
|
|||||||
task_config = (
|
task_config = (
|
||||||
self.llm_model.model_for_task
|
self.llm_model.model_for_task
|
||||||
if self.llm_model is not None
|
if self.llm_model is not None
|
||||||
else model_config.model_task_config.utils
|
else model_config.model_task_config.utils_small
|
||||||
)
|
)
|
||||||
|
|
||||||
self.value_assessment_model = LLMRequest(
|
self.value_assessment_model = LLMRequest(
|
||||||
@@ -173,6 +171,49 @@ class EnhancedMemorySystem:
|
|||||||
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
|
logger.error(f"❌ 记忆系统初始化失败: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def retrieve_memories_for_building(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
user_id: str,
|
||||||
|
context: Optional[Dict[str, Any]] = None,
|
||||||
|
limit: int = 5
|
||||||
|
) -> List[MemoryChunk]:
|
||||||
|
"""在构建记忆时检索相关记忆,允许在BUILDING状态下进行检索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_text: 查询文本
|
||||||
|
user_id: 用户ID
|
||||||
|
context: 上下文信息
|
||||||
|
limit: 返回结果数量限制
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相关记忆列表
|
||||||
|
"""
|
||||||
|
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
||||||
|
logger.warning(f"记忆系统状态不允许检索: {self.status.value}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 临时切换到检索状态
|
||||||
|
original_status = self.status
|
||||||
|
self.status = MemorySystemStatus.RETRIEVING
|
||||||
|
|
||||||
|
# 执行检索
|
||||||
|
memories = await self.vector_storage.search_similar_memories(
|
||||||
|
query_text=query_text,
|
||||||
|
user_id=user_id,
|
||||||
|
limit=limit
|
||||||
|
)
|
||||||
|
|
||||||
|
# 恢复原始状态
|
||||||
|
self.status = original_status
|
||||||
|
|
||||||
|
return memories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"构建过程中检索记忆失败: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
async def build_memory_from_conversation(
|
async def build_memory_from_conversation(
|
||||||
self,
|
self,
|
||||||
conversation_text: str,
|
conversation_text: str,
|
||||||
@@ -191,9 +232,10 @@ class EnhancedMemorySystem:
|
|||||||
Returns:
|
Returns:
|
||||||
构建的记忆块列表
|
构建的记忆块列表
|
||||||
"""
|
"""
|
||||||
if self.status != MemorySystemStatus.READY:
|
if self.status not in [MemorySystemStatus.READY, MemorySystemStatus.BUILDING]:
|
||||||
raise RuntimeError("记忆系统未就绪")
|
raise RuntimeError("记忆系统未就绪")
|
||||||
|
|
||||||
|
original_status = self.status
|
||||||
self.status = MemorySystemStatus.BUILDING
|
self.status = MemorySystemStatus.BUILDING
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -222,7 +264,7 @@ class EnhancedMemorySystem:
|
|||||||
build_marker_time = current_time
|
build_marker_time = current_time
|
||||||
self._last_memory_build_times[build_scope_key] = current_time
|
self._last_memory_build_times[build_scope_key] = current_time
|
||||||
|
|
||||||
conversation_text = self._resolve_conversation_context(conversation_text, normalized_context)
|
conversation_text = await self._resolve_conversation_context(conversation_text, normalized_context)
|
||||||
|
|
||||||
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
|
logger.debug(f"开始为用户 {user_id} 构建记忆,文本长度: {len(conversation_text)}")
|
||||||
|
|
||||||
@@ -231,7 +273,7 @@ class EnhancedMemorySystem:
|
|||||||
|
|
||||||
if value_score < self.config.memory_value_threshold:
|
if value_score < self.config.memory_value_threshold:
|
||||||
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
|
logger.info(f"信息价值评分 {value_score:.2f} 低于阈值,跳过记忆构建")
|
||||||
self.status = MemorySystemStatus.READY
|
self.status = original_status
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 2. 构建记忆块
|
# 2. 构建记忆块
|
||||||
@@ -244,7 +286,7 @@ class EnhancedMemorySystem:
|
|||||||
|
|
||||||
if not memory_chunks:
|
if not memory_chunks:
|
||||||
logger.debug("未提取到有效记忆块")
|
logger.debug("未提取到有效记忆块")
|
||||||
self.status = MemorySystemStatus.READY
|
self.status = original_status
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 3. 记忆融合与去重
|
# 3. 记忆融合与去重
|
||||||
@@ -262,7 +304,7 @@ class EnhancedMemorySystem:
|
|||||||
build_time = time.time() - start_time
|
build_time = time.time() - start_time
|
||||||
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒")
|
logger.info(f"✅ 为用户 {user_id} 构建了 {len(fused_chunks)} 条记忆,耗时 {build_time:.2f}秒")
|
||||||
|
|
||||||
self.status = MemorySystemStatus.READY
|
self.status = original_status
|
||||||
return fused_chunks
|
return fused_chunks
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -469,49 +511,79 @@ class EnhancedMemorySystem:
|
|||||||
|
|
||||||
return context
|
return context
|
||||||
|
|
||||||
def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
async def _resolve_conversation_context(self, fallback_text: str, context: Optional[Dict[str, Any]]) -> str:
|
||||||
"""使用 stream_id 历史消息充实对话文本,默认回退到传入文本"""
|
"""使用 stream_id 历史消息和相关记忆充实对话文本,默认回退到传入文本"""
|
||||||
if not context:
|
if not context:
|
||||||
return fallback_text
|
return fallback_text
|
||||||
|
|
||||||
|
user_id = context.get("user_id")
|
||||||
stream_id = context.get("stream_id") or context.get("stram_id")
|
stream_id = context.get("stream_id") or context.get("stram_id")
|
||||||
if not stream_id:
|
|
||||||
return fallback_text
|
|
||||||
|
|
||||||
try:
|
# 优先使用 stream_id 获取历史消息
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
if stream_id:
|
||||||
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = chat_manager.get_stream(stream_id)
|
chat_stream = chat_manager.get_stream(stream_id)
|
||||||
if not chat_stream or not hasattr(chat_stream, "context_manager"):
|
if chat_stream and hasattr(chat_stream, "context_manager"):
|
||||||
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
|
history_limit = self._determine_history_limit(context)
|
||||||
return fallback_text
|
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
|
||||||
|
if messages:
|
||||||
|
transcript = self._format_history_messages(messages)
|
||||||
|
if transcript:
|
||||||
|
cleaned_fallback = (fallback_text or "").strip()
|
||||||
|
if cleaned_fallback and cleaned_fallback not in transcript:
|
||||||
|
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
|
||||||
|
|
||||||
history_limit = self._determine_history_limit(context)
|
logger.debug(
|
||||||
messages = chat_stream.context_manager.get_messages(limit=history_limit, include_unread=True)
|
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
|
||||||
if not messages:
|
stream_id,
|
||||||
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
|
len(messages),
|
||||||
return fallback_text
|
history_limit,
|
||||||
|
)
|
||||||
|
return transcript
|
||||||
|
else:
|
||||||
|
logger.debug(f"stream_id={stream_id} 历史消息格式化失败")
|
||||||
|
else:
|
||||||
|
logger.debug(f"stream_id={stream_id} 未获取到历史消息")
|
||||||
|
else:
|
||||||
|
logger.debug(f"未找到 stream_id={stream_id} 对应的聊天流或上下文管理器")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
|
||||||
|
|
||||||
transcript = self._format_history_messages(messages)
|
# 如果无法获取历史消息,尝试检索相关记忆作为上下文
|
||||||
if not transcript:
|
if user_id and fallback_text:
|
||||||
return fallback_text
|
try:
|
||||||
|
relevant_memories = await self.retrieve_memories_for_building(
|
||||||
|
query_text=fallback_text,
|
||||||
|
user_id=user_id,
|
||||||
|
context=context,
|
||||||
|
limit=3
|
||||||
|
)
|
||||||
|
|
||||||
cleaned_fallback = (fallback_text or "").strip()
|
if relevant_memories:
|
||||||
if cleaned_fallback and cleaned_fallback not in transcript:
|
memory_contexts = []
|
||||||
transcript = f"{transcript}\n[当前消息] {cleaned_fallback}"
|
for memory in relevant_memories:
|
||||||
|
memory_contexts.append(f"[历史记忆] {memory.text_content}")
|
||||||
|
|
||||||
logger.debug(
|
memory_transcript = "\n".join(memory_contexts)
|
||||||
"使用 stream_id=%s 的历史消息构建记忆上下文,消息数=%d,限制=%d",
|
cleaned_fallback = (fallback_text or "").strip()
|
||||||
stream_id,
|
if cleaned_fallback and cleaned_fallback not in memory_transcript:
|
||||||
len(messages),
|
memory_transcript = f"{memory_transcript}\n[当前消息] {cleaned_fallback}"
|
||||||
history_limit,
|
|
||||||
)
|
|
||||||
return transcript
|
|
||||||
|
|
||||||
except Exception as exc:
|
logger.debug(
|
||||||
logger.warning(f"获取 stream_id={stream_id} 的历史消息失败: {exc}", exc_info=True)
|
"使用检索到的历史记忆构建记忆上下文,记忆数=%d,用户=%s",
|
||||||
return fallback_text
|
len(relevant_memories),
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
return memory_transcript
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(f"检索历史记忆作为上下文失败: {exc}", exc_info=True)
|
||||||
|
|
||||||
|
# 回退到传入文本
|
||||||
|
return fallback_text
|
||||||
|
|
||||||
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
def _get_build_scope_key(self, context: Dict[str, Any], user_id: Optional[str]) -> Optional[str]:
|
||||||
"""确定用于节流控制的记忆构建作用域"""
|
"""确定用于节流控制的记忆构建作用域"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import orjson
|
import orjson
|
||||||
from typing import Dict, List, Optional, Tuple, Any, Set
|
from typing import Dict, List, Optional, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -16,7 +16,7 @@ from src.common.logger import get_logger
|
|||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.memory_system.memory_chunk import (
|
from src.chat.memory_system.memory_chunk import (
|
||||||
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
|
MemoryChunk, MemoryType, ConfidenceLevel, ImportanceLevel,
|
||||||
ContentStructure, MemoryMetadata, create_memory_chunk
|
create_memory_chunk
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -334,6 +334,28 @@ class MemoryBuilder:
|
|||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
def _extract_json_payload(self, response: str) -> Optional[str]:
|
||||||
|
"""从模型响应中提取JSON部分,兼容Markdown代码块等格式"""
|
||||||
|
if not response:
|
||||||
|
return None
|
||||||
|
|
||||||
|
stripped = response.strip()
|
||||||
|
|
||||||
|
# 优先处理Markdown代码块格式 ```json ... ```
|
||||||
|
code_block_match = re.search(r"```(?:json)?\s*(.*?)```", stripped, re.IGNORECASE | re.DOTALL)
|
||||||
|
if code_block_match:
|
||||||
|
candidate = code_block_match.group(1).strip()
|
||||||
|
if candidate:
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# 回退到查找第一个 JSON 对象的大括号范围
|
||||||
|
start = stripped.find("{")
|
||||||
|
end = stripped.rfind("}")
|
||||||
|
if start != -1 and end != -1 and end > start:
|
||||||
|
return stripped[start:end + 1].strip()
|
||||||
|
|
||||||
|
return stripped if stripped.startswith("{") and stripped.endswith("}") else None
|
||||||
|
|
||||||
def _parse_llm_response(
|
def _parse_llm_response(
|
||||||
self,
|
self,
|
||||||
response: str,
|
response: str,
|
||||||
@@ -345,7 +367,13 @@ class MemoryBuilder:
|
|||||||
memories = []
|
memories = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = orjson.loads(response)
|
# 提取JSON负载
|
||||||
|
json_payload = self._extract_json_payload(response)
|
||||||
|
if not json_payload:
|
||||||
|
logger.error("未在响应中找到有效的JSON负载")
|
||||||
|
return memories
|
||||||
|
|
||||||
|
data = orjson.loads(json_payload)
|
||||||
memory_list = data.get("memories", [])
|
memory_list = data.get("memories", [])
|
||||||
|
|
||||||
for mem_data in memory_list:
|
for mem_data in memory_list:
|
||||||
@@ -375,7 +403,8 @@ class MemoryBuilder:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"解析LLM响应失败: {e}, 响应: {response}")
|
preview = response[:200] if response else "空响应"
|
||||||
|
logger.error(f"解析LLM响应失败: {e}, 响应片段: {preview}")
|
||||||
|
|
||||||
return memories
|
return memories
|
||||||
|
|
||||||
@@ -623,7 +652,7 @@ class MemoryBuilder:
|
|||||||
try:
|
try:
|
||||||
# 尝试解析回字典(如果原来是字典)
|
# 尝试解析回字典(如果原来是字典)
|
||||||
memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
|
memory.content.object = eval(obj_str) if obj_str.startswith('{') else obj_str
|
||||||
except:
|
except Exception:
|
||||||
memory.content.object = obj_str
|
memory.content.object = obj_str
|
||||||
|
|
||||||
# 记录时间规范化操作
|
# 记录时间规范化操作
|
||||||
|
|||||||
@@ -47,22 +47,20 @@ class SingleStreamContextManager:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 消息对象
|
message: 消息对象
|
||||||
skip_energy_update: 是否跳过能量更新
|
skip_energy_update: 是否跳过能量更新(兼容参数,当前忽略)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 是否成功添加
|
bool: 是否成功添加
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.context.add_message(message)
|
self.context.add_message(message)
|
||||||
interest_value = await self._calculate_message_interest(message)
|
# 推迟兴趣度计算到分发阶段
|
||||||
message.interest_value = interest_value
|
message.interest_value = getattr(message, "interest_value", None)
|
||||||
self.total_messages += 1
|
self.total_messages += 1
|
||||||
self.last_access_time = time.time()
|
self.last_access_time = time.time()
|
||||||
if not skip_energy_update:
|
# 启动流的循环任务(如果还未启动)
|
||||||
await self._update_stream_energy()
|
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||||
# 启动流的循环任务(如果还未启动)
|
logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度待计算)")
|
||||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
|
||||||
logger.info(f"添加消息到单流上下文: {self.stream_id} (兴趣度: {interest_value:.3f})")
|
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||||
@@ -80,8 +78,6 @@ class SingleStreamContextManager:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
self.context.update_message_info(message_id, **updates)
|
self.context.update_message_info(message_id, **updates)
|
||||||
if "interest_value" in updates:
|
|
||||||
await self._update_stream_energy()
|
|
||||||
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
logger.debug(f"更新单流上下文消息: {self.stream_id}/{message_id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -286,18 +282,16 @@ class SingleStreamContextManager:
|
|||||||
try:
|
try:
|
||||||
self.context.add_message(message)
|
self.context.add_message(message)
|
||||||
|
|
||||||
interest_value = await self._calculate_message_interest_async(message)
|
# 推迟兴趣度计算到分发阶段
|
||||||
message.interest_value = interest_value
|
message.interest_value = getattr(message, "interest_value", None)
|
||||||
|
|
||||||
self.total_messages += 1
|
self.total_messages += 1
|
||||||
self.last_access_time = time.time()
|
self.last_access_time = time.time()
|
||||||
|
|
||||||
if not skip_energy_update:
|
# 启动流的循环任务(如果还未启动)
|
||||||
await self._update_stream_energy()
|
await stream_loop_manager.start_stream_loop(self.stream_id)
|
||||||
# 启动流的循环任务(如果还未启动)
|
|
||||||
await stream_loop_manager.start_stream_loop(self.stream_id)
|
|
||||||
|
|
||||||
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度: {interest_value:.3f})")
|
logger.debug(f"添加消息到单流上下文(异步): {self.stream_id} (兴趣度待计算)")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||||
@@ -307,8 +301,6 @@ class SingleStreamContextManager:
|
|||||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||||
try:
|
try:
|
||||||
self.context.update_message_info(message_id, **updates)
|
self.context.update_message_info(message_id, **updates)
|
||||||
if "interest_value" in updates:
|
|
||||||
await self._update_stream_energy()
|
|
||||||
|
|
||||||
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
|
logger.debug(f"更新单流上下文消息(异步): {self.stream_id}/{message_id}")
|
||||||
return True
|
return True
|
||||||
@@ -339,27 +331,31 @@ class SingleStreamContextManager:
|
|||||||
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
logger.error(f"清空单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def _update_stream_energy(self):
|
async def refresh_focus_energy_from_history(self) -> None:
|
||||||
|
"""基于历史消息刷新聚焦能量"""
|
||||||
|
await self._update_stream_energy(include_unread=False)
|
||||||
|
|
||||||
|
async def _update_stream_energy(self, include_unread: bool = False) -> None:
|
||||||
"""更新流能量"""
|
"""更新流能量"""
|
||||||
try:
|
try:
|
||||||
# 获取所有消息
|
history_messages = self.context.get_history_messages(limit=self.max_context_size)
|
||||||
all_messages = self.get_messages(self.max_context_size)
|
messages: List[DatabaseMessages] = list(history_messages)
|
||||||
unread_messages = self.get_unread_messages()
|
|
||||||
combined_messages = all_messages + unread_messages
|
|
||||||
|
|
||||||
# 获取用户ID
|
if include_unread:
|
||||||
|
messages.extend(self.get_unread_messages())
|
||||||
|
|
||||||
|
# 获取用户ID(优先使用最新历史消息)
|
||||||
user_id = None
|
user_id = None
|
||||||
if combined_messages:
|
if messages:
|
||||||
last_message = combined_messages[-1]
|
last_message = messages[-1]
|
||||||
user_id = last_message.user_info.user_id
|
if hasattr(last_message, "user_info") and last_message.user_info:
|
||||||
|
user_id = last_message.user_info.user_id
|
||||||
|
|
||||||
# 计算能量
|
await energy_manager.calculate_focus_energy(
|
||||||
energy = await energy_manager.calculate_focus_energy(
|
stream_id=self.stream_id,
|
||||||
stream_id=self.stream_id, messages=combined_messages, user_id=user_id
|
messages=messages,
|
||||||
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 更新流循环管理器
|
|
||||||
# 注意:能量更新会通过energy_manager自动同步到流循环管理器
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")
|
logger.error(f"更新单流能量失败 {self.stream_id}: {e}")
|
||||||
|
|||||||
@@ -246,6 +246,7 @@ class StreamLoopManager:
|
|||||||
success = results.get("success", False)
|
success = results.get("success", False)
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
|
await self._refresh_focus_energy(stream_id)
|
||||||
process_time = time.time() - start_time
|
process_time = time.time() - start_time
|
||||||
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
|
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
|
||||||
else:
|
else:
|
||||||
@@ -339,6 +340,20 @@ class StreamLoopManager:
|
|||||||
"max_concurrent_streams": self.max_concurrent_streams,
|
"max_concurrent_streams": self.max_concurrent_streams,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _refresh_focus_energy(self, stream_id: str) -> None:
|
||||||
|
"""分发完成后基于历史消息刷新能量值"""
|
||||||
|
try:
|
||||||
|
chat_manager = get_chat_manager()
|
||||||
|
chat_stream = chat_manager.get_stream(stream_id)
|
||||||
|
if not chat_stream:
|
||||||
|
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
|
||||||
|
return
|
||||||
|
|
||||||
|
await chat_stream.context_manager.refresh_focus_energy_from_history()
|
||||||
|
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
# 全局流循环管理器实例
|
# 全局流循环管理器实例
|
||||||
stream_loop_manager = StreamLoopManager()
|
stream_loop_manager = StreamLoopManager()
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Dict, Optional, Any, TYPE_CHECKING
|
from typing import Dict, Optional, Any, TYPE_CHECKING, List
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
@@ -125,6 +125,44 @@ class MessageManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||||
|
|
||||||
|
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
|
||||||
|
"""批量更新消息信息,降低更新频率"""
|
||||||
|
if not updates:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
chat_manager = get_chat_manager()
|
||||||
|
chat_stream = chat_manager.get_stream(stream_id)
|
||||||
|
if not chat_stream:
|
||||||
|
logger.warning(f"MessageManager.bulk_update_messages: 聊天流 {stream_id} 不存在")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
updated_count = 0
|
||||||
|
for item in updates:
|
||||||
|
message_id = item.get("message_id")
|
||||||
|
if not message_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
key: value
|
||||||
|
for key, value in item.items()
|
||||||
|
if key != "message_id" and value is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
if not payload:
|
||||||
|
continue
|
||||||
|
|
||||||
|
success = await chat_stream.context_manager.update_message(message_id, payload)
|
||||||
|
if success:
|
||||||
|
updated_count += 1
|
||||||
|
|
||||||
|
if updated_count:
|
||||||
|
logger.debug(f"批量更新消息 {updated_count} 条 (stream={stream_id})")
|
||||||
|
return updated_count
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量更新聊天流 {stream_id} 消息失败: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
async def add_action(self, stream_id: str, message_id: str, action: str):
|
async def add_action(self, stream_id: str, message_id: str, action: str):
|
||||||
"""添加动作到消息"""
|
"""添加动作到消息"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -212,7 +212,11 @@ class MessageStorage:
|
|||||||
return match.group(0)
|
return match.group(0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_message_interest_value(message_id: str, interest_value: float) -> None:
|
async def update_message_interest_value(
|
||||||
|
message_id: str,
|
||||||
|
interest_value: float,
|
||||||
|
should_reply: bool | None = None,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
更新数据库中消息的interest_value字段
|
更新数据库中消息的interest_value字段
|
||||||
|
|
||||||
@@ -223,7 +227,11 @@ class MessageStorage:
|
|||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 更新消息的interest_value字段
|
# 更新消息的interest_value字段
|
||||||
stmt = update(Messages).where(Messages.message_id == message_id).values(interest_value=interest_value)
|
values = {"interest_value": interest_value}
|
||||||
|
if should_reply is not None:
|
||||||
|
values["should_reply"] = should_reply
|
||||||
|
|
||||||
|
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@@ -236,6 +244,31 @@ class MessageStorage:
|
|||||||
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
|
logger.error(f"更新消息 {message_id} 的interest_value失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def bulk_update_interest_values(
|
||||||
|
interest_map: dict[str, float],
|
||||||
|
reply_map: dict[str, bool] | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""批量更新消息的兴趣度与回复标记"""
|
||||||
|
if not interest_map:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with get_db_session() as session:
|
||||||
|
for message_id, interest_value in interest_map.items():
|
||||||
|
values = {"interest_value": interest_value}
|
||||||
|
if reply_map and message_id in reply_map:
|
||||||
|
values["should_reply"] = reply_map[message_id]
|
||||||
|
|
||||||
|
stmt = update(Messages).where(Messages.message_id == message_id).values(**values)
|
||||||
|
await session.execute(stmt)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量更新消息兴趣度失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
async def fix_zero_interest_values(chat_id: str, since_time: float) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -728,8 +728,6 @@ async def initialize_database():
|
|||||||
"autocommit": config.mysql_autocommit,
|
"autocommit": config.mysql_autocommit,
|
||||||
"charset": config.mysql_charset,
|
"charset": config.mysql_charset,
|
||||||
"connect_timeout": config.connection_timeout,
|
"connect_timeout": config.connection_timeout,
|
||||||
"read_timeout": 30,
|
|
||||||
"write_timeout": 30,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||||
@@ -104,45 +104,33 @@ class ChatterActionPlanner:
|
|||||||
score = 0.0
|
score = 0.0
|
||||||
should_reply = False
|
should_reply = False
|
||||||
reply_not_available = False
|
reply_not_available = False
|
||||||
|
interest_updates: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
if unread_messages:
|
if unread_messages:
|
||||||
# 为每条消息计算兴趣度
|
# 为每条消息计算兴趣度,并延迟提交数据库更新
|
||||||
for message in unread_messages:
|
for message in unread_messages:
|
||||||
try:
|
try:
|
||||||
# 使用插件内部的兴趣度评分系统计算
|
|
||||||
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
interest_score = await chatter_interest_scoring_system._calculate_single_message_score(
|
||||||
message=message,
|
message=message,
|
||||||
bot_nickname=global_config.bot.nickname
|
bot_nickname=global_config.bot.nickname,
|
||||||
)
|
)
|
||||||
message_interest = interest_score.total_score
|
message_interest = interest_score.total_score
|
||||||
|
|
||||||
# 更新消息的兴趣度
|
|
||||||
message.interest_value = message_interest
|
message.interest_value = message_interest
|
||||||
|
|
||||||
# 简单的回复决策逻辑:兴趣度超过阈值则回复
|
|
||||||
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
message.should_reply = message_interest > global_config.affinity_flow.non_reply_action_interest_threshold
|
||||||
|
|
||||||
logger.info(f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}")
|
interest_updates.append(
|
||||||
|
{
|
||||||
|
"message_id": message.message_id,
|
||||||
|
"interest_value": message_interest,
|
||||||
|
"should_reply": message.should_reply,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 更新StreamContext中的消息信息并刷新focus_energy
|
logger.debug(
|
||||||
if context:
|
f"消息 {message.message_id} 兴趣度: {message_interest:.3f}, 应回复: {message.should_reply}"
|
||||||
from src.chat.message_manager.message_manager import message_manager
|
)
|
||||||
await message_manager.update_message(
|
|
||||||
stream_id=self.chat_id,
|
|
||||||
message_id=message.message_id,
|
|
||||||
interest_value=message_interest,
|
|
||||||
should_reply=message.should_reply
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新数据库中的消息记录
|
|
||||||
try:
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
|
||||||
await MessageStorage.update_message_interest_value(message.message_id, message_interest)
|
|
||||||
logger.info(f"已更新数据库中消息 {message.message_id} 的兴趣度为: {message_interest:.3f}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"更新数据库消息兴趣度失败: {e}")
|
|
||||||
|
|
||||||
# 记录最高分
|
|
||||||
if message_interest > score:
|
if message_interest > score:
|
||||||
score = message_interest
|
score = message_interest
|
||||||
if message.should_reply:
|
if message.should_reply:
|
||||||
@@ -152,9 +140,18 @@ class ChatterActionPlanner:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
|
logger.warning(f"计算消息 {message.message_id} 兴趣度失败: {e}")
|
||||||
# 设置默认值
|
|
||||||
message.interest_value = 0.0
|
message.interest_value = 0.0
|
||||||
message.should_reply = False
|
message.should_reply = False
|
||||||
|
interest_updates.append(
|
||||||
|
{
|
||||||
|
"message_id": message.message_id,
|
||||||
|
"interest_value": 0.0,
|
||||||
|
"should_reply": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if interest_updates:
|
||||||
|
await self._commit_interest_updates(interest_updates)
|
||||||
|
|
||||||
# 检查兴趣度是否达到非回复动作阈值
|
# 检查兴趣度是否达到非回复动作阈值
|
||||||
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
non_reply_action_interest_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
|
||||||
@@ -194,6 +191,33 @@ class ChatterActionPlanner:
|
|||||||
self.planner_stats["failed_plans"] += 1
|
self.planner_stats["failed_plans"] += 1
|
||||||
return [], None
|
return [], None
|
||||||
|
|
||||||
|
async def _commit_interest_updates(self, updates: List[Dict[str, Any]]) -> None:
|
||||||
|
"""统一更新消息兴趣度,减少数据库写入次数"""
|
||||||
|
if not updates:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.chat.message_manager.message_manager import message_manager
|
||||||
|
|
||||||
|
await message_manager.bulk_update_messages(self.chat_id, updates)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"批量更新上下文消息兴趣度失败: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
|
||||||
|
interest_map = {item["message_id"]: item["interest_value"] for item in updates if "interest_value" in item}
|
||||||
|
reply_map = {item["message_id"]: item["should_reply"] for item in updates if "should_reply" in item}
|
||||||
|
|
||||||
|
await MessageStorage.bulk_update_interest_values(
|
||||||
|
interest_map=interest_map,
|
||||||
|
reply_map=reply_map if reply_map else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"已批量更新 {len(interest_map)} 条消息的兴趣度")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"批量更新数据库兴趣度失败: {e}")
|
||||||
|
|
||||||
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
def _update_stats_from_execution_result(self, execution_result: Dict[str, any]):
|
||||||
"""根据执行结果更新规划器统计"""
|
"""根据执行结果更新规划器统计"""
|
||||||
if not execution_result:
|
if not execution_result:
|
||||||
|
|||||||
Reference in New Issue
Block a user