Files
Mofox-Core/src/plugins/PFC/message_storage.py

124 lines
3.8 KiB
Python

import asyncio
from abc import ABC, abstractmethod
from typing import List, Dict, Any

from src.common.database import db
class MessageStorage(ABC):
    """Abstract interface for querying stored chat messages.

    Concrete backends implement the three coroutine methods below. Messages
    are represented as plain dictionaries.
    """

    @abstractmethod
    async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
        """Return every message of a chat that is later than a timestamp.

        The parameter is a timestamp (not a message dict or message ID) so
        that this declaration matches the concrete MongoDB implementation
        in this module.

        Args:
            chat_id: Chat identifier.
            message_time: Timestamp; only messages after it are returned.

        Returns:
            List[Dict[str, Any]]: The matching messages.
        """

    @abstractmethod
    async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
        """Return messages of a chat that are earlier than a timestamp.

        Args:
            chat_id: Chat identifier.
            time_point: Timestamp; only messages before it are returned.
            limit: Maximum number of messages to return.

        Returns:
            List[Dict[str, Any]]: The matching messages.
        """

    @abstractmethod
    async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
        """Check whether a chat has any message later than a timestamp.

        Args:
            chat_id: Chat identifier.
            after_time: Timestamp to compare against.

        Returns:
            bool: True if at least one newer message exists.
        """
class MongoDBMessageStorage(MessageStorage):
    """MongoDB-backed implementation of the message storage interface.

    All queries go to the ``messages`` collection of the shared database
    handle. The driver calls used here are synchronous (their results are
    consumed without awaiting), so each one is dispatched to the event loop's
    default executor rather than run inline, which would stall every other
    coroutine for the duration of the database round-trip.

    NOTE(review): this assumes the database handle may be used from worker
    threads (true for a PyMongo client) — confirm for the project's wrapper.
    """

    def __init__(self):
        self.db = db

    @staticmethod
    async def _run_blocking(func):
        """Run a zero-argument blocking callable in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func)

    async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
        """Return the chat's messages with ``time`` greater than ``message_time``.

        Args:
            chat_id: Chat identifier to filter on.
            message_time: Timestamp lower bound (exclusive).

        Returns:
            List[Dict[str, Any]]: Matching messages sorted by ``time`` ascending.
        """
        query = {"chat_id": chat_id, "time": {"$gt": message_time}}

        def fetch() -> List[Dict[str, Any]]:
            return list(self.db.messages.find(query).sort("time", 1))

        return await self._run_blocking(fetch)

    async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``limit`` messages with ``time`` less than ``time_point``.

        Args:
            chat_id: Chat identifier to filter on.
            time_point: Timestamp upper bound (exclusive).
            limit: Maximum number of messages to return.

        Returns:
            List[Dict[str, Any]]: The latest matching messages, oldest first.
        """
        query = {"chat_id": chat_id, "time": {"$lt": time_point}}

        def fetch() -> List[Dict[str, Any]]:
            # Sort descending so the limit keeps the messages closest to time_point.
            return list(self.db.messages.find(query).sort("time", -1).limit(limit))

        messages = await self._run_blocking(fetch)
        # Put the selected messages back into chronological order.
        messages.reverse()
        return messages

    async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
        """Check whether the chat has a message with ``time`` greater than ``after_time``.

        Args:
            chat_id: Chat identifier to filter on.
            after_time: Timestamp lower bound (exclusive).

        Returns:
            bool: True if at least one matching message exists.
        """
        query = {"chat_id": chat_id, "time": {"$gt": after_time}}

        def probe() -> bool:
            return self.db.messages.find_one(query) is not None

        return await self._run_blocking(probe)
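# # In-memory message storage implementation, intended for testing (currently disabled)
# class InMemoryMessageStorage(MessageStorage):
# """In-memory message storage implementation, mainly for testing"""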
# def __init__(self):
# self.messages: Dict[str, List[Dict[str, Any]]] = {}
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = self.messages[chat_id]
# if not message_id:
# return messages
# # Find the index of message_id
# try:
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
# return messages[index + 1:]
# except StopIteration:
# return []
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = [
# m for m in self.messages[chat_id]
# if m["time"] < time_point
# ]
# return messages[-limit:]
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
# if chat_id not in self.messages:
# return False
# return any(m["time"] > after_time for m in self.messages[chat_id])
# # Test helper method
# def add_message(self, chat_id: str, message: Dict[str, Any]):
# """Add a test message"""
# if chat_id not in self.messages:
# self.messages[chat_id] = []
# self.messages[chat_id].append(message)
# self.messages[chat_id].sort(key=lambda m: m["time"])