ruff
This commit is contained in:
@@ -5,7 +5,7 @@ import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
|
||||
|
||||
from peewee import SqliteDatabase
|
||||
from src.common.database.database_model import Messages, BaseModel
|
||||
@@ -15,160 +15,158 @@ from src.common.message_repository import find_messages
|
||||
class TestMessageRepository(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# 创建内存中的SQLite数据库用于测试
|
||||
self.test_db = SqliteDatabase(':memory:')
|
||||
|
||||
self.test_db = SqliteDatabase(":memory:")
|
||||
|
||||
# 覆盖原有数据库连接
|
||||
BaseModel._meta.database = self.test_db
|
||||
Messages._meta.database = self.test_db
|
||||
|
||||
|
||||
# 创建表
|
||||
self.test_db.create_tables([Messages])
|
||||
|
||||
|
||||
# 添加测试数据
|
||||
current_time = datetime.datetime.now().timestamp()
|
||||
self.test_messages = [
|
||||
{
|
||||
'message_id': 'msg1',
|
||||
'time': current_time - 3600, # 1小时前
|
||||
'chat_id': '5ed68437e28644da51f314f37df68d18',
|
||||
'chat_info_stream_id': 'stream1',
|
||||
'chat_info_platform': 'qq',
|
||||
'chat_info_user_platform': 'qq',
|
||||
'chat_info_user_id': 'user1',
|
||||
'chat_info_user_nickname': '用户1',
|
||||
'chat_info_user_cardname': '卡片名1',
|
||||
'chat_info_group_platform': 'qq',
|
||||
'chat_info_group_id': 'group1',
|
||||
'chat_info_group_name': '群组1',
|
||||
'chat_info_create_time': current_time - 7200, # 2小时前
|
||||
'chat_info_last_active_time': current_time - 1800, # 30分钟前
|
||||
'user_platform': 'qq',
|
||||
'user_id': 'user1',
|
||||
'user_nickname': '用户1',
|
||||
'user_cardname': '卡片名1',
|
||||
'processed_plain_text': '你好',
|
||||
'detailed_plain_text': '你好',
|
||||
'memorized_times': 1
|
||||
"message_id": "msg1",
|
||||
"time": current_time - 3600, # 1小时前
|
||||
"chat_id": "5ed68437e28644da51f314f37df68d18",
|
||||
"chat_info_stream_id": "stream1",
|
||||
"chat_info_platform": "qq",
|
||||
"chat_info_user_platform": "qq",
|
||||
"chat_info_user_id": "user1",
|
||||
"chat_info_user_nickname": "用户1",
|
||||
"chat_info_user_cardname": "卡片名1",
|
||||
"chat_info_group_platform": "qq",
|
||||
"chat_info_group_id": "group1",
|
||||
"chat_info_group_name": "群组1",
|
||||
"chat_info_create_time": current_time - 7200, # 2小时前
|
||||
"chat_info_last_active_time": current_time - 1800, # 30分钟前
|
||||
"user_platform": "qq",
|
||||
"user_id": "user1",
|
||||
"user_nickname": "用户1",
|
||||
"user_cardname": "卡片名1",
|
||||
"processed_plain_text": "你好",
|
||||
"detailed_plain_text": "你好",
|
||||
"memorized_times": 1,
|
||||
},
|
||||
{
|
||||
'message_id': 'msg2',
|
||||
'time': current_time - 1800, # 30分钟前
|
||||
'chat_id': 'chat1',
|
||||
'chat_info_stream_id': 'stream1',
|
||||
'chat_info_platform': 'qq',
|
||||
'chat_info_user_platform': 'qq',
|
||||
'chat_info_user_id': 'user1',
|
||||
'chat_info_user_nickname': '用户1',
|
||||
'chat_info_user_cardname': '卡片名1',
|
||||
'chat_info_group_platform': 'qq',
|
||||
'chat_info_group_id': 'group1',
|
||||
'chat_info_group_name': '群组1',
|
||||
'chat_info_create_time': current_time - 7200,
|
||||
'chat_info_last_active_time': current_time - 900, # 15分钟前
|
||||
'user_platform': 'qq',
|
||||
'user_id': 'user1',
|
||||
'user_nickname': '用户1',
|
||||
'user_cardname': '卡片名1',
|
||||
'processed_plain_text': '世界',
|
||||
'detailed_plain_text': '世界',
|
||||
'memorized_times': 2
|
||||
"message_id": "msg2",
|
||||
"time": current_time - 1800, # 30分钟前
|
||||
"chat_id": "chat1",
|
||||
"chat_info_stream_id": "stream1",
|
||||
"chat_info_platform": "qq",
|
||||
"chat_info_user_platform": "qq",
|
||||
"chat_info_user_id": "user1",
|
||||
"chat_info_user_nickname": "用户1",
|
||||
"chat_info_user_cardname": "卡片名1",
|
||||
"chat_info_group_platform": "qq",
|
||||
"chat_info_group_id": "group1",
|
||||
"chat_info_group_name": "群组1",
|
||||
"chat_info_create_time": current_time - 7200,
|
||||
"chat_info_last_active_time": current_time - 900, # 15分钟前
|
||||
"user_platform": "qq",
|
||||
"user_id": "user1",
|
||||
"user_nickname": "用户1",
|
||||
"user_cardname": "卡片名1",
|
||||
"processed_plain_text": "世界",
|
||||
"detailed_plain_text": "世界",
|
||||
"memorized_times": 2,
|
||||
},
|
||||
{
|
||||
'message_id': 'msg3',
|
||||
'time': current_time - 900, # 15分钟前
|
||||
'chat_id': 'chat2',
|
||||
'chat_info_stream_id': 'stream2',
|
||||
'chat_info_platform': 'wechat',
|
||||
'chat_info_user_platform': 'wechat',
|
||||
'chat_info_user_id': 'user2',
|
||||
'chat_info_user_nickname': '用户2',
|
||||
'chat_info_user_cardname': '卡片名2',
|
||||
'chat_info_group_platform': 'wechat',
|
||||
'chat_info_group_id': 'group2',
|
||||
'chat_info_group_name': '群组2',
|
||||
'chat_info_create_time': current_time - 3600,
|
||||
'chat_info_last_active_time': current_time - 600, # 10分钟前
|
||||
'user_platform': 'wechat',
|
||||
'user_id': 'user2',
|
||||
'user_nickname': '用户2',
|
||||
'user_cardname': '卡片名2',
|
||||
'processed_plain_text': '测试',
|
||||
'detailed_plain_text': '测试',
|
||||
'memorized_times': 0
|
||||
}
|
||||
"message_id": "msg3",
|
||||
"time": current_time - 900, # 15分钟前
|
||||
"chat_id": "chat2",
|
||||
"chat_info_stream_id": "stream2",
|
||||
"chat_info_platform": "wechat",
|
||||
"chat_info_user_platform": "wechat",
|
||||
"chat_info_user_id": "user2",
|
||||
"chat_info_user_nickname": "用户2",
|
||||
"chat_info_user_cardname": "卡片名2",
|
||||
"chat_info_group_platform": "wechat",
|
||||
"chat_info_group_id": "group2",
|
||||
"chat_info_group_name": "群组2",
|
||||
"chat_info_create_time": current_time - 3600,
|
||||
"chat_info_last_active_time": current_time - 600, # 10分钟前
|
||||
"user_platform": "wechat",
|
||||
"user_id": "user2",
|
||||
"user_nickname": "用户2",
|
||||
"user_cardname": "卡片名2",
|
||||
"processed_plain_text": "测试",
|
||||
"detailed_plain_text": "测试",
|
||||
"memorized_times": 0,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
for msg_data in self.test_messages:
|
||||
Messages.create(**msg_data)
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
# 关闭测试数据库连接
|
||||
self.test_db.close()
|
||||
|
||||
|
||||
def test_find_messages_no_filter(self):
|
||||
"""测试不带过滤器的查询"""
|
||||
results = find_messages({})
|
||||
self.assertEqual(len(results), 3)
|
||||
# 验证结果是否按时间升序排列
|
||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||
self.assertEqual(results[2]['message_id'], 'msg3')
|
||||
|
||||
self.assertEqual(results[0]["message_id"], "msg1")
|
||||
self.assertEqual(results[1]["message_id"], "msg2")
|
||||
self.assertEqual(results[2]["message_id"], "msg3")
|
||||
|
||||
def test_find_messages_with_filter(self):
|
||||
"""测试带过滤器的查询"""
|
||||
results = find_messages({'chat_id': 'chat1'})
|
||||
results = find_messages({"chat_id": "chat1"})
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||
|
||||
results = find_messages({'user_id': 'user2'})
|
||||
self.assertEqual(results[0]["message_id"], "msg1")
|
||||
self.assertEqual(results[1]["message_id"], "msg2")
|
||||
|
||||
results = find_messages({"user_id": "user2"})
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]['message_id'], 'msg3')
|
||||
|
||||
self.assertEqual(results[0]["message_id"], "msg3")
|
||||
|
||||
def test_find_messages_with_operators(self):
|
||||
"""测试带操作符的查询"""
|
||||
results = find_messages({'memorized_times': {'$gt': 0}})
|
||||
results = find_messages({"memorized_times": {"$gt": 0}})
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||
|
||||
results = find_messages({'memorized_times': {'$gte': 2}})
|
||||
self.assertEqual(results[0]["message_id"], "msg1")
|
||||
self.assertEqual(results[1]["message_id"], "msg2")
|
||||
|
||||
results = find_messages({"memorized_times": {"$gte": 2}})
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]['message_id'], 'msg2')
|
||||
|
||||
self.assertEqual(results[0]["message_id"], "msg2")
|
||||
|
||||
def test_find_messages_with_sort(self):
|
||||
"""测试带排序的查询"""
|
||||
results = find_messages({}, sort=[('memorized_times', -1)])
|
||||
results = find_messages({}, sort=[("memorized_times", -1)])
|
||||
self.assertEqual(len(results), 3)
|
||||
# 验证结果是否按memorized_times降序排列
|
||||
self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2
|
||||
self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1
|
||||
self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0
|
||||
|
||||
self.assertEqual(results[0]["message_id"], "msg2") # memorized_times = 2
|
||||
self.assertEqual(results[1]["message_id"], "msg1") # memorized_times = 1
|
||||
self.assertEqual(results[2]["message_id"], "msg3") # memorized_times = 0
|
||||
|
||||
def test_find_messages_with_limit(self):
|
||||
"""测试带限制的查询"""
|
||||
# 默认limit_mode为latest,应返回最新的2条记录
|
||||
results = find_messages({}, limit=2)
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0]['message_id'], 'msg2')
|
||||
self.assertEqual(results[1]['message_id'], 'msg3')
|
||||
|
||||
self.assertEqual(results[0]["message_id"], "msg2")
|
||||
self.assertEqual(results[1]["message_id"], "msg3")
|
||||
|
||||
# 使用earliest模式,应返回最早的2条记录
|
||||
results = find_messages({}, limit=2, limit_mode='earliest')
|
||||
results = find_messages({}, limit=2, limit_mode="earliest")
|
||||
self.assertEqual(len(results), 2)
|
||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||
|
||||
self.assertEqual(results[0]["message_id"], "msg1")
|
||||
self.assertEqual(results[1]["message_id"], "msg2")
|
||||
|
||||
def test_find_messages_with_combined_criteria(self):
|
||||
"""测试组合查询条件"""
|
||||
results = find_messages(
|
||||
{'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}},
|
||||
sort=[('time', 1)],
|
||||
limit=1
|
||||
{"chat_info_platform": "qq", "memorized_times": {"$gt": 0}}, sort=[("time", 1)], limit=1
|
||||
)
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0]['message_id'], 'msg2')
|
||||
self.assertEqual(results[0]["message_id"], "msg2")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user