This commit applies automated code formatting across the project. The changes primarily involve removing unnecessary blank lines and ensuring consistent code style, improving readability and maintainability without altering functionality.
450 lines
19 KiB
Python
450 lines
19 KiB
Python
import asyncio
|
||
import hashlib
|
||
import time
|
||
import copy
|
||
from typing import Dict, Optional, TYPE_CHECKING
|
||
from rich.traceback import install
|
||
from maim_message import GroupInfo, UserInfo
|
||
|
||
from src.common.logger import get_logger
|
||
from sqlalchemy import select
|
||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||
from src.config.config import global_config # 新增导入
|
||
|
||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||
if TYPE_CHECKING:
|
||
from .message import MessageRecv
|
||
|
||
|
||
# Install the traceback handler imported from rich.traceback, passing extra_lines=3.
install(extra_lines=3)

# Module-level logger obtained from the project's get_logger helper under the name "chat_stream".
logger = get_logger("chat_stream")
|
||
|
||
|
||
class ChatMessageContext:
    """Chat message context: wraps one message and exposes read-only accessors for it."""

    def __init__(self, message: "MessageRecv"):
        # The wrapped message; every accessor below reads from it.
        self.message = message

    def get_template_name(self) -> Optional[str]:
        """Return the template name, or None when no template is set or the default one is used."""
        template = self.message.message_info.template_info
        if not template or template.template_default:
            return None
        return template.template_name  # type: ignore

    def get_last_message(self) -> "MessageRecv":
        """Return the wrapped message."""
        return self.message

    def check_types(self, types: list) -> bool:
        """Return True only if every entry of ``types`` is in the message's accepted formats.

        An empty or missing ``accept_format`` always yields False; an empty
        ``types`` list with a non-empty ``accept_format`` yields True.
        """
        accepted = self.message.message_info.format_info.accept_format  # type: ignore
        if not accepted:
            return False
        return all(wanted in accepted for wanted in types)

    def get_priority_mode(self) -> str:
        """Return the message's ``priority_mode`` attribute."""
        return self.message.priority_mode

    def get_priority_info(self) -> Optional[dict]:
        """Return the message's ``priority_info`` if present and truthy, else None."""
        info = getattr(self.message, "priority_info", None)
        return info if info else None
|
||
|
||
|
||
class ChatStream:
    """Chat stream object holding the state of one conversation.

    Args:
        stream_id: Identifier of the stream.
        platform: Platform identifier.
        user_info: User information attached to the stream.
        group_info: Group information, or None when there is none.
        data: Optional dict of previously exported values (see ``to_dict``).
            Missing keys fall back to defaults: ``create_time`` to the current
            time, ``last_active_time`` to ``create_time``, ``energy_value`` to
            5.0, ``sleep_pressure`` to 0.0 and
            ``breaking_accumulated_interest`` to 0.0.
    """

    def __init__(
        self,
        stream_id: str,
        platform: str,
        user_info: UserInfo,
        group_info: Optional[GroupInfo] = None,
        data: Optional[dict] = None,
    ):
        # A missing or empty ``data`` behaves exactly like "use every default".
        persisted = data or {}

        self.stream_id = stream_id
        self.platform = platform
        self.user_info = user_info
        self.group_info = group_info
        self.create_time = persisted.get("create_time", time.time())
        self.last_active_time = persisted.get("last_active_time", self.create_time)
        self.energy_value = persisted.get("energy_value", 5.0)
        self.sleep_pressure = persisted.get("sleep_pressure", 0.0)
        # Dirty flag: False means the stream has changes not yet written out.
        self.saved = False
        # Context of the latest message; populated later through set_context().
        self.context: ChatMessageContext = None  # type: ignore
        self.focus_energy = 1
        self.no_reply_consecutive = 0
        # to_dict() exports this value, so restore it here as well; otherwise a
        # to_dict() -> from_dict() round trip would silently reset it to 0.0.
        self.breaking_accumulated_interest = persisted.get("breaking_accumulated_interest", 0.0)

    def to_dict(self) -> dict:
        """Export the stream as a plain dict.

        ``user_info`` / ``group_info`` are exported through their own
        ``to_dict()`` or as None when unset.
        """
        return {
            "stream_id": self.stream_id,
            "platform": self.platform,
            "user_info": self.user_info.to_dict() if self.user_info else None,
            "group_info": self.group_info.to_dict() if self.group_info else None,
            "create_time": self.create_time,
            "last_active_time": self.last_active_time,
            "energy_value": self.energy_value,
            "sleep_pressure": self.sleep_pressure,
            "breaking_accumulated_interest": self.breaking_accumulated_interest,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ChatStream":
        """Build an instance from a dict shaped like the output of ``to_dict``.

        Raises:
            KeyError: if ``stream_id`` or ``platform`` is missing.
        """
        raw_user = data.get("user_info")
        raw_group = data.get("group_info")
        user_info = UserInfo.from_dict(raw_user) if raw_user else None
        group_info = GroupInfo.from_dict(raw_group) if raw_group else None

        return cls(
            stream_id=data["stream_id"],
            platform=data["platform"],
            user_info=user_info,  # type: ignore
            group_info=group_info,
            data=data,
        )

    def update_active_time(self):
        """Set the last-active time to now and mark the stream as unsaved."""
        self.last_active_time = time.time()
        self.saved = False

    def set_context(self, message: "MessageRecv"):
        """Wrap ``message`` in a ChatMessageContext and store it as the stream's context."""
        self.context = ChatMessageContext(message)
|
||
|
||
|
||
class ChatManager:
    """Chat manager: singleton registry of all chat streams.

    Keeps two in-memory maps keyed by stream id: ``streams`` (the cached
    ChatStream objects) and ``last_messages`` (the most recently registered
    message per stream). Streams are persisted to the ``ChatStreams`` table
    through ``get_db_session``; blocking database work is pushed to a worker
    thread with ``asyncio.to_thread``.
    """

    _instance = None
    _initialized = False

    def __new__(cls):
        # Every ChatManager() call returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every ChatManager() call, so set up state only once.
        if not self._initialized:
            self.streams: Dict[str, ChatStream] = {}  # stream_id -> ChatStream
            self.last_messages: Dict[str, "MessageRecv"] = {}  # stream_id -> last_message
            self._initialized = True
            # NOTE: _initialize() and _auto_save_task() are not scheduled from
            # here; nothing in this class starts them automatically.

    async def _initialize(self):
        """Load all streams from the database, logging (not raising) on failure."""
        try:
            await self.load_all_streams()
            logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
        except Exception as e:
            logger.error(f"聊天管理器启动失败: {str(e)}")

    async def _auto_save_task(self):
        """Save every stream once per 300 seconds, forever; failures are logged and the loop continues."""
        while True:
            await asyncio.sleep(300)  # save every 5 minutes
            try:
                await self._save_all_streams()
                logger.info("聊天流自动保存完成")
            except Exception as e:
                logger.error(f"聊天流自动保存失败: {str(e)}")

    def register_message(self, message: "MessageRecv"):
        """Record ``message`` as the latest message of the stream it belongs to.

        Raises:
            ValueError: if the message carries neither user nor group info.
        """
        stream_id = self._generate_stream_id(
            message.message_info.platform,  # type: ignore
            message.message_info.user_info,
            message.message_info.group_info,
        )
        self.last_messages[stream_id] = message

    @staticmethod
    def _generate_stream_id(
        platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
    ) -> str:
        """Build the stream id for a platform/user/group combination.

        Group chats hash ``platform_groupid``; private chats hash
        ``platform_userid_private``. Group info takes precedence when both
        are given.

        Raises:
            ValueError: if both ``user_info`` and ``group_info`` are falsy.
        """
        if not user_info and not group_info:
            raise ValueError("用户信息或群组信息必须提供")

        if group_info:
            components = [platform, str(group_info.group_id)]
        else:
            components = [platform, str(user_info.user_id), "private"]  # type: ignore

        # MD5 is used only to derive a fixed-length key from the components.
        key = "_".join(components)
        return hashlib.md5(key.encode()).hexdigest()

    def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
        """Compute the stream id from a raw group id (``is_group=True``) or user id.

        ``id`` is coerced with ``str()`` so numeric ids produce the same value
        as ``_generate_stream_id`` instead of failing inside ``str.join``.
        """
        components = [platform, str(id)]
        if not is_group:
            components.append("private")
        key = "_".join(components)
        return hashlib.md5(key.encode()).hexdigest()

    @staticmethod
    def _model_to_stream_dict(model_instance) -> dict:
        """Convert a ChatStreams row into the dict layout ``ChatStream.from_dict`` expects."""
        group_info_data = None
        if getattr(model_instance, "group_id", None):
            group_info_data = {
                "platform": model_instance.group_platform,
                "group_id": model_instance.group_id,
                "group_name": model_instance.group_name,
            }

        return {
            "stream_id": model_instance.stream_id,
            "platform": model_instance.platform,
            "user_info": {
                "platform": model_instance.user_platform,
                "user_id": model_instance.user_id,
                "user_nickname": model_instance.user_nickname,
                "user_cardname": model_instance.user_cardname or "",
            },
            "group_info": group_info_data,
            "create_time": model_instance.create_time,
            "last_active_time": model_instance.last_active_time,
            "energy_value": model_instance.energy_value,
            "sleep_pressure": model_instance.sleep_pressure,
        }

    def _attach_last_message(self, stream: ChatStream, stream_id: str) -> None:
        """Set the stream's context from the last registered message, or log when there is none."""
        from .message import MessageRecv  # deferred import to avoid a circular import

        last_message = self.last_messages.get(stream_id)
        if stream_id in self.last_messages and isinstance(last_message, MessageRecv):
            stream.set_context(last_message)
        else:
            logger.error(f"聊天流 {stream_id} 不在最后消息列表中,可能是新创建的")

    async def get_or_create_stream(
        self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
    ) -> ChatStream:
        """Get or create a chat stream.

        Lookup order: in-memory cache, then the database, then a brand-new
        stream. A cache hit returns a deep copy (the cached object only gets
        its active time refreshed); otherwise the stream is stored in the
        cache and saved to the database before being returned.

        Args:
            platform: Platform identifier.
            user_info: User information.
            group_info: Group information (optional).

        Returns:
            ChatStream: the chat stream object.

        Raises:
            Exception: any error raised while resolving the stream is logged
                and re-raised.
        """
        try:
            stream_id = self._generate_stream_id(platform, user_info, group_info)

            cached = self.streams.get(stream_id) if stream_id in self.streams else None
            if stream_id in self.streams:
                cached.update_active_time()
                # Hand out a copy so external mutation does not touch the cache.
                stream = copy.deepcopy(cached)
                if user_info.platform and user_info.user_id:
                    stream.user_info = user_info
                if group_info:
                    stream.group_info = group_info
                self._attach_last_message(stream, stream_id)
                return stream

            def _db_find_stream_sync(s_id: str):
                with get_db_session() as session:
                    row = session.execute(select(ChatStreams).where(ChatStreams.stream_id == s_id)).scalar()
                    # Read the columns while the session is still open rather
                    # than returning the ORM object past the end of the session.
                    return self._model_to_stream_dict(row) if row else None

            stored = await asyncio.to_thread(_db_find_stream_sync, stream_id)

            if stored is not None:
                stream = ChatStream.from_dict(stored)
                # Caller-supplied info overrides what was stored.
                stream.user_info = user_info
                if group_info:
                    stream.group_info = group_info
                stream.update_active_time()
            else:
                stream = ChatStream(
                    stream_id=stream_id,
                    platform=platform,
                    user_info=user_info,
                    group_info=group_info,
                )
        except Exception as e:
            logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
            raise

        stream = copy.deepcopy(stream)
        self._attach_last_message(stream, stream_id)
        # Store in memory and persist to the database.
        self.streams[stream_id] = stream
        await self._save_stream(stream)
        return stream

    def get_stream(self, stream_id: str) -> Optional[ChatStream]:
        """Return the cached stream for ``stream_id`` (refreshing its context), or None."""
        stream = self.streams.get(stream_id)
        if not stream:
            return None
        if stream_id in self.last_messages:
            stream.set_context(self.last_messages[stream_id])
        return stream

    def get_stream_by_info(
        self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
    ) -> Optional[ChatStream]:
        """Return the cached stream matching the given platform/user/group, or None."""
        stream_id = self._generate_stream_id(platform, user_info, group_info)
        return self.streams.get(stream_id)

    def get_stream_name(self, stream_id: str) -> Optional[str]:
        """Return a display name for the stream.

        The group name is preferred; otherwise the user's nickname with a
        private-chat suffix; None when neither is available or the stream is
        not cached.
        """
        stream = self.get_stream(stream_id)
        if not stream:
            return None

        if stream.group_info and stream.group_info.group_name:
            return stream.group_info.group_name
        if stream.user_info and stream.user_info.user_nickname:
            return f"{stream.user_info.user_nickname}的私聊"
        return None

    @staticmethod
    async def _save_stream(stream: ChatStream):
        """Upsert ``stream`` into the database unless it is already marked saved.

        Database errors are logged, not raised; ``stream.saved`` is set to
        True only after a successful write.
        """
        if stream.saved:
            return
        stream_data_dict = stream.to_dict()

        def _db_save_stream_sync(s_data_dict: dict):
            with get_db_session() as session:
                user_info_d = s_data_dict.get("user_info")
                group_info_d = s_data_dict.get("group_info")

                fields_to_save = {
                    "platform": s_data_dict["platform"],
                    "create_time": s_data_dict["create_time"],
                    "last_active_time": s_data_dict["last_active_time"],
                    "user_platform": user_info_d["platform"] if user_info_d else "",
                    "user_id": user_info_d["user_id"] if user_info_d else "",
                    "user_nickname": user_info_d["user_nickname"] if user_info_d else "",
                    "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
                    "group_platform": group_info_d["platform"] if group_info_d else "",
                    "group_id": group_info_d["group_id"] if group_info_d else "",
                    "group_name": group_info_d["group_name"] if group_info_d else "",
                    "energy_value": s_data_dict.get("energy_value", 5.0),
                    "sleep_pressure": s_data_dict.get("sleep_pressure", 0.0),
                }

                # Pick the upsert dialect from the configured database type.
                if global_config.database.database_type == "mysql":
                    stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                    stmt = stmt.on_duplicate_key_update(**fields_to_save)
                else:
                    # "sqlite" and any unrecognised type both use the SQLite syntax.
                    stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                    stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)

                session.execute(stmt)
                session.commit()

        try:
            await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
            stream.saved = True
        except Exception as e:
            logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)

    async def _save_all_streams(self):
        """Save every cached stream."""
        # Iterate over a snapshot: each save awaits, and another coroutine may
        # add a stream to self.streams in the meantime, which would otherwise
        # raise "dictionary changed size during iteration".
        for stream in list(self.streams.values()):
            await self._save_stream(stream)

    async def load_all_streams(self):
        """Replace the in-memory cache with all streams stored in the database.

        Loaded streams are marked as saved and get their context from
        ``last_messages`` when one is registered. Errors are logged, not
        raised.
        """
        logger.info("正在从数据库加载所有聊天流")

        def _db_load_all_streams_sync():
            with get_db_session() as session:
                loaded_streams_data = [
                    self._model_to_stream_dict(model_instance)
                    for model_instance in session.execute(select(ChatStreams)).scalars()
                ]
                session.commit()
                return loaded_streams_data

        try:
            all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
            self.streams.clear()
            for data in all_streams_data_list:
                stream = ChatStream.from_dict(data)
                stream.saved = True
                self.streams[stream.stream_id] = stream
                if stream.stream_id in self.last_messages:
                    stream.set_context(self.last_messages[stream.stream_id])
        except Exception as e:
            logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||
|
||
|
||
# Lazily created module-level ChatManager; filled in by get_chat_manager().
chat_manager = None


def get_chat_manager():
    """Return the module-level ChatManager, creating it on the first call."""
    global chat_manager
    if chat_manager is not None:
        return chat_manager
    chat_manager = ChatManager()
    return chat_manager
|