refractor: 几乎写完了,进入测试阶段

This commit is contained in:
tcmofashi
2025-03-09 22:12:10 +08:00
parent fe3684736a
commit 6e2ea8261b
10 changed files with 654 additions and 338 deletions

View File

@@ -0,0 +1,209 @@
import time
import asyncio
from typing import Optional, Dict, Tuple
import hashlib
from loguru import logger
from ...common.database import Database
from .message_base import UserInfo, GroupInfo
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time())
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time
self.saved = False
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
'stream_id': self.stream_id,
'platform': self.platform,
'user_info': self.user_info.to_dict() if self.user_info else None,
'group_info': self.group_info.to_dict() if self.group_info else None,
'create_time': self.create_time,
'last_active_time': self.last_active_time
}
return result
@classmethod
def from_dict(cls, data: dict) -> 'ChatStream':
"""从字典创建实例"""
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None
return cls(
stream_id=data['stream_id'],
platform=data['platform'],
user_info=user_info,
group_info=group_info,
data=data
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
self.saved = False
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
asyncio.create_task(self._initialize())
# 启动自动保存任务
asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
await self.load_all_streams()
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
await asyncio.sleep(300) # 每5分钟保存一次
try:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if 'chat_streams' not in self.db.db.list_collection_names():
self.db.db.create_collection('chat_streams')
# 创建索引
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True)
self.db.db.chat_streams.create_index([
('platform', 1),
('user_info.user_id', 1),
('group_info.group_id', 1)
])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
# 组合关键信息
components = [
platform,
str(user_info.user_id),
str(group_info.group_id) if group_info else 'private'
]
# 使用MD5生成唯一ID
key = '_'.join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({'stream_id': stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
# 创建新的聊天流
stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
def get_stream_by_info(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
{'stream_id': stream.stream_id},
{'$set': stream.to_dict()},
upsert=True
)
stream.saved = True
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream
# 创建全局单例
chat_manager = ChatManager()