fix;修复提取消息和运行bug

This commit is contained in:
SengokuCola
2025-05-17 01:03:20 +08:00
parent 5e37d48735
commit d26d69de60
7 changed files with 472 additions and 5 deletions

View File

@@ -0,0 +1,174 @@
import unittest
from unittest.mock import patch, MagicMock
import datetime
import sys
import os
# 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
from peewee import SqliteDatabase
from src.common.database.database_model import Messages, BaseModel
from src.common.message_repository import find_messages
class TestMessageRepository(unittest.TestCase):
def setUp(self):
# 创建内存中的SQLite数据库用于测试
self.test_db = SqliteDatabase(':memory:')
# 覆盖原有数据库连接
BaseModel._meta.database = self.test_db
Messages._meta.database = self.test_db
# 创建表
self.test_db.create_tables([Messages])
# 添加测试数据
current_time = datetime.datetime.now().timestamp()
self.test_messages = [
{
'message_id': 'msg1',
'time': current_time - 3600, # 1小时前
'chat_id': '5ed68437e28644da51f314f37df68d18',
'chat_info_stream_id': 'stream1',
'chat_info_platform': 'qq',
'chat_info_user_platform': 'qq',
'chat_info_user_id': 'user1',
'chat_info_user_nickname': '用户1',
'chat_info_user_cardname': '卡片名1',
'chat_info_group_platform': 'qq',
'chat_info_group_id': 'group1',
'chat_info_group_name': '群组1',
'chat_info_create_time': current_time - 7200, # 2小时前
'chat_info_last_active_time': current_time - 1800, # 30分钟前
'user_platform': 'qq',
'user_id': 'user1',
'user_nickname': '用户1',
'user_cardname': '卡片名1',
'processed_plain_text': '你好',
'detailed_plain_text': '你好',
'memorized_times': 1
},
{
'message_id': 'msg2',
'time': current_time - 1800, # 30分钟前
'chat_id': 'chat1',
'chat_info_stream_id': 'stream1',
'chat_info_platform': 'qq',
'chat_info_user_platform': 'qq',
'chat_info_user_id': 'user1',
'chat_info_user_nickname': '用户1',
'chat_info_user_cardname': '卡片名1',
'chat_info_group_platform': 'qq',
'chat_info_group_id': 'group1',
'chat_info_group_name': '群组1',
'chat_info_create_time': current_time - 7200,
'chat_info_last_active_time': current_time - 900, # 15分钟前
'user_platform': 'qq',
'user_id': 'user1',
'user_nickname': '用户1',
'user_cardname': '卡片名1',
'processed_plain_text': '世界',
'detailed_plain_text': '世界',
'memorized_times': 2
},
{
'message_id': 'msg3',
'time': current_time - 900, # 15分钟前
'chat_id': 'chat2',
'chat_info_stream_id': 'stream2',
'chat_info_platform': 'wechat',
'chat_info_user_platform': 'wechat',
'chat_info_user_id': 'user2',
'chat_info_user_nickname': '用户2',
'chat_info_user_cardname': '卡片名2',
'chat_info_group_platform': 'wechat',
'chat_info_group_id': 'group2',
'chat_info_group_name': '群组2',
'chat_info_create_time': current_time - 3600,
'chat_info_last_active_time': current_time - 600, # 10分钟前
'user_platform': 'wechat',
'user_id': 'user2',
'user_nickname': '用户2',
'user_cardname': '卡片名2',
'processed_plain_text': '测试',
'detailed_plain_text': '测试',
'memorized_times': 0
}
]
for msg_data in self.test_messages:
Messages.create(**msg_data)
def tearDown(self):
# 关闭测试数据库连接
self.test_db.close()
def test_find_messages_no_filter(self):
"""测试不带过滤器的查询"""
results = find_messages({})
self.assertEqual(len(results), 3)
# 验证结果是否按时间升序排列
self.assertEqual(results[0]['message_id'], 'msg1')
self.assertEqual(results[1]['message_id'], 'msg2')
self.assertEqual(results[2]['message_id'], 'msg3')
def test_find_messages_with_filter(self):
"""测试带过滤器的查询"""
results = find_messages({'chat_id': 'chat1'})
self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg1')
self.assertEqual(results[1]['message_id'], 'msg2')
results = find_messages({'user_id': 'user2'})
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['message_id'], 'msg3')
def test_find_messages_with_operators(self):
"""测试带操作符的查询"""
results = find_messages({'memorized_times': {'$gt': 0}})
self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg1')
self.assertEqual(results[1]['message_id'], 'msg2')
results = find_messages({'memorized_times': {'$gte': 2}})
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['message_id'], 'msg2')
def test_find_messages_with_sort(self):
"""测试带排序的查询"""
results = find_messages({}, sort=[('memorized_times', -1)])
self.assertEqual(len(results), 3)
# 验证结果是否按memorized_times降序排列
self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2
self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1
self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0
def test_find_messages_with_limit(self):
"""测试带限制的查询"""
# 默认limit_mode为latest应返回最新的2条记录
results = find_messages({}, limit=2)
self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg2')
self.assertEqual(results[1]['message_id'], 'msg3')
# 使用earliest模式应返回最早的2条记录
results = find_messages({}, limit=2, limit_mode='earliest')
self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg1')
self.assertEqual(results[1]['message_id'], 'msg2')
def test_find_messages_with_combined_criteria(self):
"""测试组合查询条件"""
results = find_messages(
{'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}},
sort=[('time', 1)],
limit=1
)
self.assertEqual(len(results), 1)
self.assertEqual(results[0]['message_id'], 'msg2')
if __name__ == '__main__':
unittest.main()