添加全局notice管理器,将notice消息与普通消息分离处理。主要功能包括: - 创建 GlobalNoticeManager 单例类,支持公共和特定聊天流作用域 - 在 message_manager 中集成notice检测和处理逻辑 - 扩展数据库模型和消息类,添加notice相关字段 - 在提示词生成器中添加notice信息块展示 - 配置系统支持notice相关参数设置 - 适配器插件增强notice类型识别和配置 notice消息特点: - 默认不触发聊天流程,只记录到全局管理器 - 可在提示词中展示最近的系统通知 - 支持按类型设置不同的生存时间 - 支持公共notice(所有聊天可见)和流特定notice BREAKING CHANGE: 数据库消息表结构变更,需要添加 is_public_notice 和 notice_type 字段
924 lines
41 KiB
Python
924 lines
41 KiB
Python
import asyncio
|
||
import copy
|
||
import hashlib
|
||
import time
|
||
from typing import TYPE_CHECKING
|
||
|
||
from maim_message import GroupInfo, UserInfo
|
||
from rich.traceback import install
|
||
from sqlalchemy import select
|
||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||
|
||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||
from src.common.logger import get_logger
|
||
from src.config.config import global_config # 新增导入
|
||
|
||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||
if TYPE_CHECKING:
|
||
from .message import MessageRecv
|
||
|
||
|
||
# Enable rich's traceback handler for this process (3 extra context lines per frame).
install(extra_lines=3)
|
||
|
||
|
||
# Module-level logger shared by ChatStream and ChatManager, registered as "chat_stream".
logger = get_logger("chat_stream")
|
||
|
||
|
||
class ChatStream:
    """Chat stream object that stores the complete context of one conversation.

    A stream is keyed by ``stream_id`` and bundles the platform, the user and
    (optional) group it belongs to, activity timestamps, a ``StreamContext``
    and the ``SingleStreamContextManager`` that wraps that context.
    """

    def __init__(
        self,
        stream_id: str,
        platform: str,
        user_info: UserInfo,
        group_info: GroupInfo | None = None,
        data: dict | None = None,
    ):
        """Create a stream, optionally restoring persisted fields from ``data``.

        Args:
            stream_id: Unique identifier of the stream.
            platform: Platform identifier the stream lives on.
            user_info: User associated with the stream.
            group_info: Group the stream belongs to. When falsy the stream
                context is created with ``ChatType.PRIVATE``.
            data: Optional dict of previously stored values. Keys read here:
                ``create_time``, ``last_active_time``, ``sleep_pressure``,
                ``base_interest_energy`` and ``focus_energy``.
        """
        stored = data or {}

        self.stream_id = stream_id
        self.platform = platform
        self.user_info = user_info
        self.group_info = group_info
        self.create_time = stored.get("create_time", time.time())
        self.last_active_time = stored.get("last_active_time", self.create_time)
        self.sleep_pressure = stored.get("sleep_pressure", 0.0)
        self.saved = False

        # StreamContext is used in place of the former ChatMessageContext.
        from src.common.data_models.message_manager_data_model import StreamContext
        from src.plugin_system.base.component_types import ChatMode, ChatType

        self.stream_context: StreamContext = StreamContext(
            stream_id=stream_id,
            chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
            chat_mode=ChatMode.NORMAL,
        )

        # Per-stream context manager wrapping the context created above.
        from src.chat.message_manager.context_manager import SingleStreamContextManager

        self.context_manager: SingleStreamContextManager = SingleStreamContextManager(
            stream_id=stream_id, context=self.stream_context
        )

        # Energy parameters. to_dict() serialises both values, so restore them
        # from ``data`` when present instead of always resetting to 0.5.
        self.base_interest_energy = self._stored_or_default(stored, "base_interest_energy", 0.5)
        self._focus_energy = self._stored_or_default(stored, "focus_energy", 0.5)
        self.no_reply_consecutive = 0

    @staticmethod
    def _stored_or_default(stored: dict, key: str, default: float) -> float:
        """Return ``stored[key]`` unless it is missing or ``None``, else ``default``."""
        value = stored.get(key)
        return default if value is None else value

    def __deepcopy__(self, memo):
        """Deep-copy the stream and clear ``processing_task`` on the copied context.

        A fresh ``ChatStream`` is built from copies of ``user_info`` and
        ``group_info``; scalar state is then copied over and the context
        objects are deep-copied with the shared ``memo``.
        """
        clone = ChatStream(
            stream_id=self.stream_id,
            platform=self.platform,
            user_info=copy.deepcopy(self.user_info, memo),
            group_info=copy.deepcopy(self.group_info, memo),
        )
        # Register the clone before copying the context objects so that any
        # reference from them back to this stream resolves to the clone
        # instead of recursing into __deepcopy__ again.
        memo[id(self)] = clone

        # Plain scalar state.
        for field in (
            "create_time",
            "last_active_time",
            "sleep_pressure",
            "saved",
            "base_interest_energy",
            "_focus_energy",
            "no_reply_consecutive",
        ):
            setattr(clone, field, getattr(self, field))

        # Copy the stream context, but never carry over a processing task.
        clone.stream_context = copy.deepcopy(self.stream_context, memo)
        if hasattr(clone.stream_context, "processing_task"):
            clone.stream_context.processing_task = None

        # Copy the context manager with the same memo.
        clone.context_manager = copy.deepcopy(self.context_manager, memo)

        return clone

    def to_dict(self) -> dict:
        """Serialise the stream into a plain dict."""
        context = self.stream_context
        return {
            "stream_id": self.stream_id,
            "platform": self.platform,
            "user_info": self.user_info.to_dict() if self.user_info else None,
            "group_info": self.group_info.to_dict() if self.group_info else None,
            "create_time": self.create_time,
            "last_active_time": self.last_active_time,
            "sleep_pressure": self.sleep_pressure,
            "focus_energy": self.focus_energy,
            # Base interest level
            "base_interest_energy": self.base_interest_energy,
            # Basic stream_context information
            "stream_context_chat_type": context.chat_type.value,
            "stream_context_chat_mode": context.chat_mode.value,
            # Statistics
            "interruption_count": context.interruption_count,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ChatStream":
        """Build a stream from a dict such as the one produced by ``to_dict``.

        ``stream_id`` and ``platform`` are required keys. ``user_info`` and
        ``group_info`` are converted with their ``from_dict`` constructors when
        present and truthy. Stream-context chat type/mode and
        ``interruption_count`` are restored when their keys are present.
        """
        raw_user = data.get("user_info")
        raw_group = data.get("group_info")
        user_info = UserInfo.from_dict(raw_user) if raw_user else None
        group_info = GroupInfo.from_dict(raw_group) if raw_group else None

        instance = cls(
            stream_id=data["stream_id"],
            platform=data["platform"],
            user_info=user_info,  # type: ignore
            group_info=group_info,
            data=data,
        )

        # Restore stream_context information.
        if "stream_context_chat_type" in data:
            from src.plugin_system.base.component_types import ChatType

            instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
        if "stream_context_chat_mode" in data:
            from src.plugin_system.base.component_types import ChatMode

            instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])

        # Restore the interruption counter.
        if "interruption_count" in data:
            instance.stream_context.interruption_count = data["interruption_count"]

        # __init__ always creates context_manager, so nothing else to set up.
        return instance

    def update_active_time(self):
        """Set ``last_active_time`` to now and mark the stream as unsaved."""
        self.last_active_time = time.time()
        self.saved = False

    async def set_context(self, message: "MessageRecv"):
        """Convert ``message`` to ``DatabaseMessages`` and make it the current message.

        Every field is read with ``getattr`` and a default, so partially
        populated messages are accepted. The stream context's
        ``priority_mode`` and ``priority_info`` are updated as well.
        """
        import json

        from src.common.data_models.database_data_model import DatabaseMessages

        def _json_or_none(value, **dump_kwargs):
            # Falsy values (None, empty list/dict) are stored as None.
            return json.dumps(value, **dump_kwargs) if value else None

        # Read message_info and its nested objects defensively.
        message_info = getattr(message, "message_info", {})
        user_info = getattr(message_info, "user_info", {})
        group_info = getattr(message_info, "group_info", {})

        # reply_to comes from a "reply" segment inside message_segment, if any.
        reply_to = None
        if hasattr(message, "message_segment") and message.message_segment:
            reply_to = self._extract_reply_from_segment(message.message_segment)

        plain_text = getattr(message, "processed_plain_text", "")
        priority_mode = getattr(message, "priority_mode", None)
        priority_info = getattr(message, "priority_info", None)
        sender_id = str(getattr(user_info, "user_id", ""))
        sender_nickname = getattr(user_info, "user_nickname", "")
        sender_cardname = getattr(user_info, "user_cardname", None)
        sender_platform = getattr(user_info, "platform", "")

        db_message = DatabaseMessages(
            # Basic message information
            message_id=getattr(message, "message_id", ""),
            time=getattr(message, "time", time.time()),
            chat_id=self._generate_chat_id(message_info),
            reply_to=reply_to,
            # Interest
            interest_value=getattr(message, "interest_value", 0.0),
            # Keywords
            key_words=_json_or_none(getattr(message, "key_words", None), ensure_ascii=False),
            key_words_lite=_json_or_none(getattr(message, "key_words_lite", None), ensure_ascii=False),
            # Message state flags
            is_mentioned=getattr(message, "is_mentioned", None),
            is_at=getattr(message, "is_at", False),
            is_emoji=getattr(message, "is_emoji", False),
            is_picid=getattr(message, "is_picid", False),
            is_voice=getattr(message, "is_voice", False),
            is_video=getattr(message, "is_video", False),
            is_command=getattr(message, "is_command", False),
            is_notify=getattr(message, "is_notify", False),
            is_public_notice=getattr(message, "is_public_notice", False),
            notice_type=getattr(message, "notice_type", None),
            # Message content (display_message defaults to processed_plain_text)
            processed_plain_text=plain_text,
            display_message=plain_text,
            # Priority information
            priority_mode=priority_mode,
            priority_info=_json_or_none(priority_info),
            # Extra configuration
            additional_config=getattr(message_info, "additional_config", None),
            # User information
            user_id=sender_id,
            user_nickname=sender_nickname,
            user_cardname=sender_cardname,
            user_platform=sender_platform,
            # Group information
            chat_info_group_id=getattr(group_info, "group_id", None),
            chat_info_group_name=getattr(group_info, "group_name", None),
            chat_info_group_platform=getattr(group_info, "platform", None),
            # Chat stream information
            chat_info_user_id=sender_id,
            chat_info_user_nickname=sender_nickname,
            chat_info_user_cardname=sender_cardname,
            chat_info_user_platform=sender_platform,
            chat_info_stream_id=self.stream_id,
            chat_info_platform=self.platform,
            chat_info_create_time=self.create_time,
            chat_info_last_active_time=self.last_active_time,
            # Interest-system fields, read defensively
            actions=self._safe_get_actions(message),
            should_reply=getattr(message, "should_reply", False),
            should_act=getattr(message, "should_act", False),
        )

        self.stream_context.set_current_message(db_message)
        self.stream_context.priority_mode = priority_mode
        self.stream_context.priority_info = priority_info

        # Debug log describing what was transferred.
        logger.debug(
            f"消息数据转移完成 - message_id: {db_message.message_id}, "
            f"chat_id: {db_message.chat_id}, "
            f"is_mentioned: {db_message.is_mentioned}, "
            f"is_emoji: {db_message.is_emoji}, "
            f"is_picid: {db_message.is_picid}, "
            f"interest_value: {db_message.interest_value}"
        )

    def _safe_get_actions(self, message: "MessageRecv") -> list | None:
        """Return the message's ``actions`` as a list of strings, or ``None``.

        A string value is parsed as JSON first. Non-list values, unparsable
        JSON, and lists left empty after dropping non-string entries all
        yield ``None``. Any unexpected error is logged and yields ``None``.
        """
        import json

        try:
            actions = getattr(message, "actions", None)
            if actions is None:
                return None

            # A string is expected to hold a JSON document.
            if isinstance(actions, str):
                try:
                    actions = json.loads(actions)
                except json.JSONDecodeError:
                    logger.warning(f"无法解析actions JSON字符串: {actions}")
                    return None

            if not isinstance(actions, list):
                logger.warning(f"actions字段类型不支持: {type(actions)}")
                return None

            # Keep string entries only (this also drops None).
            string_actions = [action for action in actions if isinstance(action, str)]
            return string_actions or None

        except Exception as e:
            logger.warning(f"获取actions字段失败: {e}")
            return None

    async def _calculate_message_interest(self, db_message):
        """Compute the interest value for ``db_message`` and update it in place.

        On calculator failure, or on any exception (including the absence of
        a calculator), the message falls back to ``interest_value=0.3``,
        ``should_reply=False`` and ``should_act=False``.
        """
        try:
            from src.chat.interest_system.interest_manager import get_interest_manager

            interest_manager = get_interest_manager()
            if not interest_manager.has_calculator():
                # No calculator component: handled by the except block below.
                raise RuntimeError("没有可用的兴趣值计算组件")

            result = await interest_manager.calculate_interest(db_message)
            if result.success:
                # Copy the calculated values onto the message.
                db_message.interest_value = result.interest_value
                db_message.should_reply = result.should_reply
                db_message.should_act = result.should_act

                logger.debug(
                    f"消息 {db_message.message_id} 兴趣值已更新: {result.interest_value:.3f}, "
                    f"should_reply: {result.should_reply}, should_act: {result.should_act}"
                )
            else:
                logger.warning(f"消息 {db_message.message_id} 兴趣值计算失败: {result.error_message}")
                # Fall back to defaults.
                db_message.interest_value = 0.3
                db_message.should_reply = False
                db_message.should_act = False

        except Exception as e:
            logger.error(f"计算消息兴趣值失败: {e}", exc_info=True)
            # Fall back to defaults, touching only attributes that exist.
            for field, fallback in (("interest_value", 0.3), ("should_reply", False), ("should_act", False)):
                if hasattr(db_message, field):
                    setattr(db_message, field, fallback)

    def _extract_reply_from_segment(self, segment) -> str | None:
        """Return the id carried by the first ``reply`` segment, or ``None``.

        ``seglist`` segments are searched recursively through ``segment.data``.
        Errors are logged as warnings and result in ``None``.
        """
        try:
            segment_type = getattr(segment, "type", None)
            if segment_type == "seglist":
                # Recurse into the children of the list segment.
                for child in getattr(segment, "data", None) or ():
                    reply_id = self._extract_reply_from_segment(child)
                    if reply_id:
                        return reply_id
            elif segment_type == "reply":
                # A reply segment stores the referenced message id in data.
                return str(segment.data) if segment.data else None
        except Exception as e:
            logger.warning(f"提取reply_to信息失败: {e}")
        return None

    def _generate_chat_id(self, message_info) -> str:
        """Build a chat id from the group or user in ``message_info``.

        Returns ``"<platform>_<group_id>"`` for group messages,
        ``"<platform>_<user_id>_private"`` for private messages, and the
        stream id when neither is available or an error occurs.
        """
        try:
            group_info = getattr(message_info, "group_info", None)
            user_info = getattr(message_info, "user_info", None)

            if group_info and hasattr(group_info, "group_id") and group_info.group_id:
                # Group chat: keyed by group id.
                return f"{self.platform}_{group_info.group_id}"
            if user_info and hasattr(user_info, "user_id") and user_info.user_id:
                # Private chat: keyed by user id.
                return f"{self.platform}_{user_info.user_id}_private"
            # Default: fall back to the stream id.
            return self.stream_id
        except Exception as e:
            logger.warning(f"生成chat_id失败: {e}")
            return self.stream_id

    @property
    def focus_energy(self) -> float:
        """Cached focus energy; 0.5 when no value has been stored yet."""
        return getattr(self, "_focus_energy", 0.5)

    @focus_energy.setter
    def focus_energy(self, value: float):
        """Store ``value`` clamped to the range [0.0, 1.0]."""
        self._focus_energy = max(0.0, min(1.0, value))

    async def calculate_focus_energy(self) -> float:
        """Recompute focus energy through the energy manager and cache it.

        Returns the new value, or the cached/default value when the
        calculation raises.
        """
        try:
            # Messages come from the per-stream context manager.
            all_messages = self.context_manager.get_messages(limit=global_config.chat.max_context_size)

            user_id = None
            if self.user_info and hasattr(self.user_info, "user_id"):
                user_id = str(self.user_info.user_id)

            from src.chat.energy_system import energy_manager

            energy = await energy_manager.calculate_focus_energy(
                stream_id=self.stream_id, messages=all_messages, user_id=user_id
            )

            # Stored directly (not through the clamping setter).
            self._focus_energy = energy

            logger.debug(f"聊天流 {self.stream_id} 能量: {energy:.3f}")
            return energy

        except Exception as e:
            logger.error(f"获取focus_energy失败: {e}", exc_info=True)
            # Cached value, or the 0.5 default.
            return self.focus_energy

    async def _get_user_relationship_score(self) -> float:
        """Return the user's relationship score from the scoring API.

        Falls back to 0.3 when there is no usable ``user_info`` or the lookup
        raises.
        """
        try:
            from src.plugin_system.apis.scoring_api import scoring_api

            if self.user_info and hasattr(self.user_info, "user_id"):
                user_id = str(self.user_info.user_id)
                relationship_score = await scoring_api.get_user_relationship_score(user_id)
                logger.debug(f"ChatStream {self.stream_id}: 用户关系分 = {relationship_score:.3f}")
                return relationship_score

        except Exception as e:
            logger.warning(f"ChatStream {self.stream_id}: 关系分计算失败: {e}")

        # Default base score.
        return 0.3
|
||
|
||
|
||
class ChatManager:
    """Singleton chat manager that owns every chat stream.

    ``streams`` maps stream ids to ``ChatStream`` objects and
    ``last_messages`` maps stream ids to the most recently registered message.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialise the containers once; later constructions are no-ops.

        Neither ``_initialize()`` nor ``_auto_save_task()`` is scheduled from
        here.
        """
        if self._initialized:
            return
        self.streams: dict[str, ChatStream] = {}  # stream_id -> ChatStream
        self.last_messages: dict[str, "MessageRecv"] = {}  # stream_id -> last message
        self._initialized = True

    async def _initialize(self):
        """Load all streams from the database, logging instead of raising."""
        try:
            await self.load_all_streams()
            logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
        except Exception as e:
            logger.error(f"聊天管理器启动失败: {e!s}")

    async def _auto_save_task(self):
        """Save every stream every 300 seconds, forever."""
        while True:
            await asyncio.sleep(300)  # once every five minutes
            try:
                await self._save_all_streams()
                logger.info("聊天流自动保存完成")
            except Exception as e:
                logger.error(f"聊天流自动保存失败: {e!s}")

    def register_message(self, message: "MessageRecv"):
        """Remember ``message`` as the latest message of its stream."""
        info = message.message_info
        stream_id = self._generate_stream_id(
            info.platform,  # type: ignore
            info.user_info,
            info.group_info,
        )
        self.last_messages[stream_id] = message

    @staticmethod
    def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
        """Return the unique stream id for the given platform and user/group.

        Group streams are keyed by group id, private streams by user id.

        Raises:
            ValueError: If neither ``user_info`` nor ``group_info`` is given.
        """
        if not user_info and not group_info:
            raise ValueError("用户信息或群组信息必须提供")

        # Same key layout and SHA-256 hashing as get_stream_id().
        if group_info:
            return ChatManager.get_stream_id(platform, str(group_info.group_id))
        return ChatManager.get_stream_id(platform, str(user_info.user_id), is_group=False)  # type: ignore

    @staticmethod
    def get_stream_id(platform: str, id: str, is_group: bool = True) -> str:
        """Return the SHA-256 hex digest identifying a group or private stream.

        The hashed key is ``"<platform>_<id>"`` for groups and
        ``"<platform>_<id>_private"`` otherwise.
        """
        components = [platform, id] if is_group else [platform, id, "private"]
        key = "_".join(components)
        return hashlib.sha256(key.encode()).hexdigest()

    @staticmethod
    def _stream_data_from_model(model_instance) -> dict:
        """Convert a ``ChatStreams`` row into the dict ``ChatStream.from_dict`` expects."""
        user_info_data = {
            "platform": model_instance.user_platform,
            "user_id": model_instance.user_id,
            "user_nickname": model_instance.user_nickname,
            "user_cardname": model_instance.user_cardname or "",
        }
        group_info_data = None
        if getattr(model_instance, "group_id", None):
            group_info_data = {
                "platform": model_instance.group_platform,
                "group_id": model_instance.group_id,
                "group_name": model_instance.group_name,
            }
        return {
            "stream_id": model_instance.stream_id,
            "platform": model_instance.platform,
            "user_info": user_info_data,
            "group_info": group_info_data,
            "create_time": model_instance.create_time,
            "last_active_time": model_instance.last_active_time,
            "energy_value": model_instance.energy_value,
            "sleep_pressure": model_instance.sleep_pressure,
            "focus_energy": getattr(model_instance, "focus_energy", 0.5),
            # Dynamic interest-system fields; getattr supplies defaults.
            "base_interest_energy": getattr(model_instance, "base_interest_energy", 0.5),
            "message_interest_total": getattr(model_instance, "message_interest_total", 0.0),
            "message_count": getattr(model_instance, "message_count", 0),
            "action_count": getattr(model_instance, "action_count", 0),
            "reply_count": getattr(model_instance, "reply_count", 0),
            "last_interaction_time": getattr(model_instance, "last_interaction_time", time.time()),
            "relationship_score": getattr(model_instance, "relationship_score", 0.3),
            "consecutive_no_reply": getattr(model_instance, "consecutive_no_reply", 0),
            "interruption_count": getattr(model_instance, "interruption_count", 0),
        }

    async def _apply_last_message(self, stream: ChatStream, stream_id: str) -> None:
        """Set the last registered message of ``stream_id`` as the stream's context.

        Logs an error when no ``MessageRecv`` is registered for the stream.
        """
        from .message import MessageRecv  # deferred import to avoid a circular import

        last_message = self.last_messages.get(stream_id)
        if isinstance(last_message, MessageRecv):
            await stream.set_context(last_message)
        else:
            logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")

    async def get_or_create_stream(
        self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
    ) -> ChatStream:
        """Get or create the chat stream for a user/group on a platform.

        Lookup order: the stream cache manager (when running), the in-memory
        ``streams`` dict, the database, and finally a brand-new stream.
        Streams obtained from the database or newly created are stored in
        ``streams`` and saved.

        Args:
            platform: Platform identifier.
            user_info: User information.
            group_info: Group information (optional).

        Returns:
            ChatStream: The chat stream object.
        """
        try:
            stream_id = self._generate_stream_id(platform, user_info, group_info)

            # Preferred path: the stream cache manager.
            try:
                from src.chat.message_manager.stream_cache_manager import get_stream_cache_manager

                cache_manager = get_stream_cache_manager()

                if cache_manager.is_running:
                    optimized_stream = await cache_manager.get_or_create_stream(
                        stream_id=stream_id, platform=platform, user_info=user_info, group_info=group_info
                    )

                    # Attach the message context.
                    from .message import MessageRecv

                    if stream_id in self.last_messages and isinstance(self.last_messages[stream_id], MessageRecv):
                        # NOTE(review): the result is not awaited here, unlike the
                        # ChatStream.set_context calls below. Confirm whether the
                        # optimized stream's set_context is a coroutine.
                        optimized_stream.set_context(self.last_messages[stream_id])

                    # NOTE(review): this class defines no _convert_to_original_stream
                    # method (a function of that name exists only at module level),
                    # so as written this lookup raises AttributeError and the except
                    # below falls back to the original path. Confirm the intended
                    # wiring before relying on this branch: streams returned from
                    # here are never stored in self.streams.
                    return self._convert_to_original_stream(optimized_stream)

            except Exception as e:
                logger.debug(f"缓存管理器获取流失败,使用原始方法: {e}")

            # Fallback 1: the stream is already in memory.
            if stream_id in self.streams:
                cached = self.streams[stream_id]
                cached.update_active_time()
                # Hand out a copy so callers cannot mutate the cached stream.
                stream = copy.deepcopy(cached)
                if user_info and user_info.platform and user_info.user_id:
                    stream.user_info = user_info
                if group_info:
                    stream.group_info = group_info
                await self._apply_last_message(stream, stream_id)
                return stream

            # Fallback 2: the database. The row is converted while the session
            # is still open.
            async with get_db_session() as session:
                result = await session.execute(select(ChatStreams).where(ChatStreams.stream_id == stream_id))
                model_instance = result.scalars().first()
                stored_data = self._stream_data_from_model(model_instance) if model_instance else None

            if stored_data:
                stream = ChatStream.from_dict(stored_data)
                # Refresh user and group information with the caller's values.
                stream.user_info = user_info
                if group_info:
                    stream.group_info = group_info
                stream.update_active_time()
            else:
                # Fallback 3: create a new stream.
                stream = ChatStream(
                    stream_id=stream_id,
                    platform=platform,
                    user_info=user_info,
                    group_info=group_info,
                )
        except Exception as e:
            logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
            raise

        stream = copy.deepcopy(stream)
        await self._apply_last_message(stream, stream_id)

        # Keep it in memory and persist it.
        self.streams[stream_id] = stream
        await self._save_stream(stream)
        return stream

    async def get_stream(self, stream_id: str) -> ChatStream | None:
        """Return the in-memory stream for ``stream_id``, or ``None``.

        When a last message is registered for the stream it is applied as the
        stream's context first.
        """
        stream = self.streams.get(stream_id)
        if not stream:
            return None
        if stream_id in self.last_messages:
            await stream.set_context(self.last_messages[stream_id])
        return stream

    def get_stream_by_info(
        self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
    ) -> ChatStream | None:
        """Return the in-memory stream matching the given platform/user/group."""
        stream_id = self._generate_stream_id(platform, user_info, group_info)
        return self.streams.get(stream_id)

    async def get_stream_name(self, stream_id: str) -> str | None:
        """Return a display name for the stream, or ``None`` when unavailable.

        The group name is preferred; private chats use the user's nickname
        followed by "的私聊".
        """
        stream = await self.get_stream(stream_id)
        if not stream:
            return None

        if stream.group_info and stream.group_info.group_name:
            return stream.group_info.group_name
        if stream.user_info and stream.user_info.user_nickname:
            return f"{stream.user_info.user_nickname}的私聊"
        return None

    @staticmethod
    def _prepare_stream_data(stream_data_dict: dict) -> dict:
        """Flatten a ``ChatStream.to_dict()`` result into ``ChatStreams`` column values."""
        user_info_d = stream_data_dict.get("user_info")
        group_info_d = stream_data_dict.get("group_info")

        return {
            "platform": stream_data_dict["platform"],
            "create_time": stream_data_dict["create_time"],
            "last_active_time": stream_data_dict["last_active_time"],
            "user_platform": user_info_d["platform"] if user_info_d else "",
            "user_id": user_info_d["user_id"] if user_info_d else "",
            "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
            "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
            "group_platform": group_info_d["platform"] if group_info_d else "",
            "group_id": group_info_d["group_id"] if group_info_d else "",
            "group_name": group_info_d["group_name"] if group_info_d else "",
            "energy_value": stream_data_dict.get("energy_value", 5.0),
            "sleep_pressure": stream_data_dict.get("sleep_pressure", 0.0),
            "focus_energy": stream_data_dict.get("focus_energy", 0.5),
            # Dynamic interest-system fields
            "base_interest_energy": stream_data_dict.get("base_interest_energy", 0.5),
            "message_interest_total": stream_data_dict.get("message_interest_total", 0.0),
            "message_count": stream_data_dict.get("message_count", 0),
            "action_count": stream_data_dict.get("action_count", 0),
            "reply_count": stream_data_dict.get("reply_count", 0),
            "last_interaction_time": stream_data_dict.get("last_interaction_time", time.time()),
            "consecutive_no_reply": stream_data_dict.get("consecutive_no_reply", 0),
            "interruption_count": stream_data_dict.get("interruption_count", 0),
        }

    @staticmethod
    async def _save_stream(stream: ChatStream):
        """Persist ``stream`` unless it is already marked as saved.

        Three strategies are tried in order: the batch database writer, the
        database batch scheduler, and a direct upsert. ``stream.saved`` is set
        once one of them succeeds; a failing direct upsert is logged.
        """
        if stream.saved:
            return
        stream_data_dict = stream.to_dict()
        stream_id = stream_data_dict["stream_id"]

        # Strategy 1: the batch writer.
        try:
            from src.chat.message_manager.batch_database_writer import get_batch_writer

            batch_writer = get_batch_writer()
            if batch_writer.is_running:
                success = await batch_writer.schedule_stream_update(
                    stream_id=stream_id,
                    update_data=ChatManager._prepare_stream_data(stream_data_dict),
                    priority=1,  # priority used for stream updates
                )
                if success:
                    stream.saved = True
                    logger.debug(f"聊天流 {stream.stream_id} 通过批量写入器调度成功")
                    return
                logger.warning(f"批量写入器队列已满,使用原始方法: {stream.stream_id}")
            else:
                logger.debug(f"批量写入器未运行,使用原始方法: {stream.stream_id}")

        except Exception as e:
            logger.debug(f"批量写入器保存聊天流失败,使用原始方法: {e}")

        # Strategy 2: the database batch scheduler.
        try:
            from src.common.database.db_batch_scheduler import batch_update, get_batch_session

            async with get_batch_session():
                result = await batch_update(
                    model_class=ChatStreams,
                    conditions={"stream_id": stream_id},
                    data=ChatManager._prepare_stream_data(stream_data_dict),
                )
                if result and result > 0:
                    stream.saved = True
                    logger.debug(f"聊天流 {stream.stream_id} 通过批量调度器保存成功")
                    return
        except Exception as e:  # also covers ImportError
            logger.debug(f"批量调度器保存聊天流失败,使用原始方法: {e}")

        # Strategy 3: a direct upsert.
        try:
            fields_to_save = ChatManager._prepare_stream_data(stream_data_dict)
            if global_config.database.database_type == "mysql":
                stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **fields_to_save)
                stmt = stmt.on_duplicate_key_update(**fields_to_save)
            else:
                # "sqlite" and any other configured type use the SQLite upsert form.
                stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **fields_to_save)
                stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)

            async with get_db_session() as session:
                await session.execute(stmt)
                await session.commit()
            stream.saved = True
        except Exception as e:
            logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (SQLAlchemy): {e}", exc_info=True)

    async def _save_all_streams(self):
        """Save every stream currently held in memory."""
        # Iterate over a snapshot: _save_stream awaits, and another task may
        # add a stream to self.streams meanwhile, which would otherwise raise
        # "dictionary changed size during iteration".
        for stream in list(self.streams.values()):
            await self._save_stream(stream)

    async def load_all_streams(self):
        """Replace the in-memory streams with every row of the ``ChatStreams`` table.

        Loaded streams are marked as saved. Failures are logged and leave the
        current in-memory streams untouched if they happen before the reload
        starts.
        """
        logger.info("正在从数据库加载所有聊天流")

        try:
            async with get_db_session() as session:
                result = await session.execute(select(ChatStreams))
                all_streams_data_list = [
                    self._stream_data_from_model(model_instance) for model_instance in result.scalars().all()
                ]
                await session.commit()

            self.streams.clear()
            for data in all_streams_data_list:
                stream = ChatStream.from_dict(data)
                stream.saved = True
                # The message context is deliberately not applied during
                # loading, to avoid complicated dependencies.
                self.streams[stream.stream_id] = stream
        except Exception as e:
            logger.error(f"从数据库加载所有聊天流失败 (SQLAlchemy): {e}", exc_info=True)
|
||
|
||
|
||
# Module-wide ChatManager instance; stays None until get_chat_manager() creates it.
chat_manager = None
|
||
|
||
|
||
def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
    """Convert an OptimizedChatStream into a plain ChatStream for compatibility.

    A new ``ChatStream`` is built from the optimized stream's id, platform and
    effective user/group info. Its scalar state is then copied over and, when
    present and truthy, ``_stream_context`` and ``_context_manager`` are
    reused. If any step raises, the error is logged and a freshly built
    stream without the copied state is returned instead.

    NOTE(review): this is a module-level function that takes ``self`` as its
    first parameter; ``self`` is not used in the body.
    """

    def _build_plain_stream() -> "ChatStream":
        # Shared by the normal path and the error fallback.
        return ChatStream(
            stream_id=optimized_stream.stream_id,
            platform=optimized_stream.platform,
            user_info=optimized_stream._get_effective_user_info(),
            group_info=optimized_stream._get_effective_group_info(),
        )

    try:
        converted = _build_plain_stream()

        # Copy scalar state.
        for field in (
            "create_time",
            "last_active_time",
            "sleep_pressure",
            "base_interest_energy",
            "_focus_energy",
            "no_reply_consecutive",
            "saved",
        ):
            setattr(converted, field, getattr(optimized_stream, field))

        # Reuse the context objects when they exist.
        stream_context = getattr(optimized_stream, "_stream_context", None)
        if stream_context:
            converted.stream_context = stream_context

        context_manager = getattr(optimized_stream, "_context_manager", None)
        if context_manager:
            converted.context_manager = context_manager

        return converted

    except Exception as e:
        logger.error(f"转换OptimizedChatStream失败: {e}")
        # Conversion failed: fall back to a newly built plain stream.
        return _build_plain_stream()
|
||
|
||
|
||
def get_chat_manager():
    """Return the module-wide ChatManager, creating it on the first call."""
    global chat_manager
    if chat_manager is not None:
        return chat_manager
    chat_manager = ChatManager()
    return chat_manager
|