feat: 重构完成开始测试debug

This commit is contained in:
tcmofashi
2025-03-11 01:15:32 +08:00
parent 20b8778e2b
commit 7899e67cb2
13 changed files with 486 additions and 572 deletions

View File

@@ -1,55 +1,67 @@
import time
import asyncio
from typing import Optional, Dict, Tuple
import hashlib
import time
from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from .message_base import UserInfo, GroupInfo
from .message_base import GroupInfo, UserInfo
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None):
def __init__(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = data.get('create_time', int(time.time())) if data else int(time.time())
self.last_active_time = data.get('last_active_time', self.create_time) if data else self.create_time
self.create_time = (
data.get("create_time", int(time.time())) if data else int(time.time())
)
self.last_active_time = (
data.get("last_active_time", self.create_time) if data else self.create_time
)
self.saved = False
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
'stream_id': self.stream_id,
'platform': self.platform,
'user_info': self.user_info.to_dict() if self.user_info else None,
'group_info': self.group_info.to_dict() if self.group_info else None,
'create_time': self.create_time,
'last_active_time': self.last_active_time
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
"group_info": self.group_info.to_dict() if self.group_info else None,
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
return result
@classmethod
def from_dict(cls, data: dict) -> 'ChatStream':
def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例"""
user_info = UserInfo(**data.get('user_info', {})) if data.get('user_info') else None
group_info = GroupInfo(**data.get('group_info', {})) if data.get('group_info') else None
user_info = (
UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
)
group_info = (
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
)
return cls(
stream_id=data['stream_id'],
platform=data['platform'],
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
group_info=group_info,
data=data
data=data,
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
@@ -58,14 +70,15 @@ class ChatStream:
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
@@ -76,7 +89,7 @@ class ChatManager:
asyncio.create_task(self._initialize())
# 启动自动保存任务
asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
@@ -84,7 +97,7 @@ class ChatManager:
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
@@ -94,49 +107,48 @@ class ChatManager:
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if 'chat_streams' not in self.db.db.list_collection_names():
self.db.db.create_collection('chat_streams')
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([('stream_id', 1)], unique=True)
self.db.db.chat_streams.create_index([
('platform', 1),
('user_info.user_id', 1),
('group_info.group_id', 1)
])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
def _generate_stream_id(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID"""
# 组合关键信息
components = [
platform,
str(user_info.user_id),
str(group_info.group_id) if group_info else 'private'
str(group_info.group_id) if group_info else "private",
]
# 使用MD5生成唯一ID
key = '_'.join(components)
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> ChatStream:
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
@@ -146,9 +158,9 @@ class ChatManager:
stream.group_info = group_info
stream.update_active_time()
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({'stream_id': stream_id})
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
@@ -162,41 +174,38 @@ class ChatManager:
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info
group_info=group_info,
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
def get_stream_by_info(self,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None) -> Optional[ChatStream]:
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
{'stream_id': stream.stream_id},
{'$set': stream.to_dict()},
upsert=True
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})