重构数据库交互以使用 Peewee ORM
- 更新数据库连接和模型定义,以便使用 Peewee for SQLite。 - 在消息存储和检索功能中,用 Peewee ORM 查询替换 MongoDB 查询。 - 为 Messages、ThinkingLog 和 OnlineTime 引入了新的模型,以方便结构化数据存储。 - 增强了数据库操作的错误处理和日志记录。 - 删除了过时的 MongoDB 集合管理代码。 - 通过利用 Peewee 内置的查询和数据操作方法来提升性能。
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Dict, Optional
|
||||
|
||||
|
||||
from ...common.database.database import db
|
||||
from ...common.database.database_model import ChatStreams # 新增导入
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
|
||||
from src.common.logger_manager import get_logger
|
||||
@@ -82,7 +83,13 @@ class ChatManager:
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self._ensure_collection()
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
# 确保 ChatStreams 表存在
|
||||
db.create_tables([ChatStreams], safe=True)
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
||||
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
# asyncio.create_task(self._initialize())
|
||||
@@ -107,15 +114,6 @@ class ChatManager:
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _ensure_collection():
|
||||
"""确保数据库集合存在并创建索引"""
|
||||
if "chat_streams" not in db.list_collection_names():
|
||||
db.create_collection("chat_streams")
|
||||
# 创建索引
|
||||
db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
@@ -151,16 +149,43 @@ class ChatManager:
|
||||
stream = self.streams[stream_id]
|
||||
# 更新用户信息和群组信息
|
||||
stream.update_active_time()
|
||||
stream = copy.deepcopy(stream)
|
||||
stream = copy.deepcopy(stream) # 返回副本以避免外部修改影响缓存
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
data = db.chat_streams.find_one({"stream_id": stream_id})
|
||||
if data:
|
||||
stream = ChatStream.from_dict(data)
|
||||
def _db_find_stream_sync(s_id: str):
|
||||
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
||||
|
||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||
|
||||
if model_instance:
|
||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
stream = ChatStream.from_dict(data_for_from_dict)
|
||||
# 更新用户信息和群组信息
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
@@ -175,7 +200,7 @@ class ChatManager:
|
||||
group_info=group_info,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建聊天流失败: {e}")
|
||||
logger.error(f"获取或创建聊天流失败: {e}", exc_info=True)
|
||||
raise e
|
||||
|
||||
# 保存到内存和数据库
|
||||
@@ -205,15 +230,38 @@ class ChatManager:
|
||||
elif stream.user_info and stream.user_info.user_nickname:
|
||||
return f"{stream.user_info.user_nickname}的私聊"
|
||||
else:
|
||||
# 如果没有群名或用户昵称,返回 None 或其他默认值
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _save_stream(stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
if not stream.saved:
|
||||
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
|
||||
stream.saved = True
|
||||
stream_data_dict = stream.to_dict()
|
||||
|
||||
def _db_save_stream_sync(s_data_dict: dict):
|
||||
user_info_d = s_data_dict.get("user_info")
|
||||
group_info_d = s_data_dict.get("group_info")
|
||||
|
||||
fields_to_save = {
|
||||
"platform": s_data_dict["platform"],
|
||||
"create_time": s_data_dict["create_time"],
|
||||
"last_active_time": s_data_dict["last_active_time"],
|
||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||
}
|
||||
|
||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||
stream.saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
||||
|
||||
async def _save_all_streams(self):
|
||||
"""保存所有聊天流"""
|
||||
@@ -222,10 +270,44 @@ class ChatManager:
|
||||
|
||||
async def load_all_streams(self):
|
||||
"""从数据库加载所有聊天流"""
|
||||
all_streams = db.chat_streams.find({})
|
||||
for data in all_streams:
|
||||
stream = ChatStream.from_dict(data)
|
||||
self.streams[stream.stream_id] = stream
|
||||
|
||||
def _db_load_all_streams_sync():
|
||||
loaded_streams_data = []
|
||||
for model_instance in ChatStreams.select():
|
||||
user_info_data = {
|
||||
"platform": model_instance.user_platform,
|
||||
"user_id": model_instance.user_id,
|
||||
"user_nickname": model_instance.user_nickname,
|
||||
"user_cardname": model_instance.user_cardname or "",
|
||||
}
|
||||
group_info_data = None
|
||||
if model_instance.group_id:
|
||||
group_info_data = {
|
||||
"platform": model_instance.group_platform,
|
||||
"group_id": model_instance.group_id,
|
||||
"group_name": model_instance.group_name,
|
||||
}
|
||||
|
||||
data_for_from_dict = {
|
||||
"stream_id": model_instance.stream_id,
|
||||
"platform": model_instance.platform,
|
||||
"user_info": user_info_data,
|
||||
"group_info": group_info_data,
|
||||
"create_time": model_instance.create_time,
|
||||
"last_active_time": model_instance.last_active_time,
|
||||
}
|
||||
loaded_streams_data.append(data_for_from_dict)
|
||||
return loaded_streams_data
|
||||
|
||||
try:
|
||||
all_streams_data_list = await asyncio.to_thread(_db_load_all_streams_sync)
|
||||
self.streams.clear()
|
||||
for data in all_streams_data_list:
|
||||
stream = ChatStream.from_dict(data)
|
||||
stream.saved = True
|
||||
self.streams[stream.stream_id] = stream
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载所有聊天流失败 (Peewee): {e}", exc_info=True)
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
|
||||
Reference in New Issue
Block a user