This commit is contained in:
UnCLAS-Prommer
2025-05-17 17:35:00 +08:00
parent 061fcefeef
commit 7973318f4c
7 changed files with 231 additions and 240 deletions

View File

@@ -146,7 +146,7 @@ class ChattingObservation(Observation):
"platform": find_msg.get("user_platform", ""), "platform": find_msg.get("user_platform", ""),
"user_id": find_msg.get("user_id", ""), "user_id": find_msg.get("user_id", ""),
"user_nickname": find_msg.get("user_nickname", ""), "user_nickname": find_msg.get("user_nickname", ""),
"user_cardname": find_msg.get("user_cardname", "") "user_cardname": find_msg.get("user_cardname", ""),
} }
# 创建所需的group_info字段如果是群聊的话 # 创建所需的group_info字段如果是群聊的话
@@ -155,7 +155,7 @@ class ChattingObservation(Observation):
group_info = { group_info = {
"platform": find_msg.get("chat_info_group_platform", ""), "platform": find_msg.get("chat_info_group_platform", ""),
"group_id": find_msg.get("chat_info_group_id", ""), "group_id": find_msg.get("chat_info_group_id", ""),
"group_name": find_msg.get("chat_info_group_name", "") "group_name": find_msg.get("chat_info_group_name", ""),
} }
content_format = "" content_format = ""

View File

@@ -861,9 +861,7 @@ class EntorhinalCortex:
# 确保在更新前获取最新的 memorized_times # 确保在更新前获取最新的 memorized_times
current_memorized_times = message.get("memorized_times", 0) current_memorized_times = message.get("memorized_times", 0)
# 使用 Peewee 更新记录 # 使用 Peewee 更新记录
Messages.update( Messages.update(memorized_times=current_memorized_times + 1).where(
memorized_times=current_memorized_times + 1
).where(
Messages.message_id == message["message_id"] Messages.message_id == message["message_id"]
).execute() ).execute()
return messages # 直接返回原始的消息列表 return messages # 直接返回原始的消息列表
@@ -983,9 +981,7 @@ class EntorhinalCortex:
if not node.last_modified: if not node.last_modified:
update_data["last_modified"] = current_time update_data["last_modified"] = current_time
GraphNodes.update( GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
**update_data
).where(GraphNodes.concept == concept).execute()
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段") logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
# 获取时间信息(如果不存在则使用当前时间) # 获取时间信息(如果不存在则使用当前时间)
@@ -1014,9 +1010,7 @@ class EntorhinalCortex:
if not edge.last_modified: if not edge.last_modified:
update_data["last_modified"] = current_time update_data["last_modified"] = current_time
GraphEdges.update( GraphEdges.update(**update_data).where(
**update_data
).where(
(GraphEdges.source == source) & (GraphEdges.target == target) (GraphEdges.source == source) & (GraphEdges.target == target)
).execute() ).execute()
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段") logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")

View File

@@ -175,13 +175,13 @@ async def _build_readable_messages_internal(
# 1 & 2: 获取发送者信息并提取消息组件 # 1 & 2: 获取发送者信息并提取消息组件
for msg in messages: for msg in messages:
# 检查并修复缺少的user_info字段 # 检查并修复缺少的user_info字段
if 'user_info' not in msg: if "user_info" not in msg:
# 创建user_info字段 # 创建user_info字段
msg['user_info'] = { msg["user_info"] = {
'platform': msg.get('user_platform', ''), "platform": msg.get("user_platform", ""),
'user_id': msg.get('user_id', ''), "user_id": msg.get("user_id", ""),
'user_nickname': msg.get('user_nickname', ''), "user_nickname": msg.get("user_nickname", ""),
'user_cardname': msg.get('user_cardname', '') "user_cardname": msg.get("user_cardname", ""),
} }
user_info = msg.get("user_info", {}) user_info = msg.get("user_info", {})

View File

@@ -279,6 +279,7 @@ class GraphNodes(BaseModel):
""" """
用于存储记忆图节点的模型 用于存储记忆图节点的模型
""" """
concept = TextField(unique=True, index=True) # 节点概念 concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表 memory_items = TextField() # JSON格式存储的记忆列表
hash = TextField() # 节点哈希值 hash = TextField() # 节点哈希值
@@ -293,6 +294,7 @@ class GraphEdges(BaseModel):
""" """
用于存储记忆图边的模型 用于存储记忆图边的模型
""" """
source = TextField(index=True) # 源节点 source = TextField(index=True) # 源节点
target = TextField(index=True) # 目标节点 target = TextField(index=True) # 目标节点
strength = IntegerField() # 连接强度 strength = IntegerField() # 连接强度

View File

@@ -5,7 +5,7 @@ import sys
import os import os
# 添加项目根目录到Python路径 # 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
from peewee import SqliteDatabase from peewee import SqliteDatabase
from src.common.database.database_model import Messages, BaseModel from src.common.database.database_model import Messages, BaseModel
@@ -15,7 +15,7 @@ from src.common.message_repository import find_messages
class TestMessageRepository(unittest.TestCase): class TestMessageRepository(unittest.TestCase):
def setUp(self): def setUp(self):
# 创建内存中的SQLite数据库用于测试 # 创建内存中的SQLite数据库用于测试
self.test_db = SqliteDatabase(':memory:') self.test_db = SqliteDatabase(":memory:")
# 覆盖原有数据库连接 # 覆盖原有数据库连接
BaseModel._meta.database = self.test_db BaseModel._meta.database = self.test_db
@@ -28,74 +28,74 @@ class TestMessageRepository(unittest.TestCase):
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
self.test_messages = [ self.test_messages = [
{ {
'message_id': 'msg1', "message_id": "msg1",
'time': current_time - 3600, # 1小时前 "time": current_time - 3600, # 1小时前
'chat_id': '5ed68437e28644da51f314f37df68d18', "chat_id": "5ed68437e28644da51f314f37df68d18",
'chat_info_stream_id': 'stream1', "chat_info_stream_id": "stream1",
'chat_info_platform': 'qq', "chat_info_platform": "qq",
'chat_info_user_platform': 'qq', "chat_info_user_platform": "qq",
'chat_info_user_id': 'user1', "chat_info_user_id": "user1",
'chat_info_user_nickname': '用户1', "chat_info_user_nickname": "用户1",
'chat_info_user_cardname': '卡片名1', "chat_info_user_cardname": "卡片名1",
'chat_info_group_platform': 'qq', "chat_info_group_platform": "qq",
'chat_info_group_id': 'group1', "chat_info_group_id": "group1",
'chat_info_group_name': '群组1', "chat_info_group_name": "群组1",
'chat_info_create_time': current_time - 7200, # 2小时前 "chat_info_create_time": current_time - 7200, # 2小时前
'chat_info_last_active_time': current_time - 1800, # 30分钟前 "chat_info_last_active_time": current_time - 1800, # 30分钟前
'user_platform': 'qq', "user_platform": "qq",
'user_id': 'user1', "user_id": "user1",
'user_nickname': '用户1', "user_nickname": "用户1",
'user_cardname': '卡片名1', "user_cardname": "卡片名1",
'processed_plain_text': '你好', "processed_plain_text": "你好",
'detailed_plain_text': '你好', "detailed_plain_text": "你好",
'memorized_times': 1 "memorized_times": 1,
}, },
{ {
'message_id': 'msg2', "message_id": "msg2",
'time': current_time - 1800, # 30分钟前 "time": current_time - 1800, # 30分钟前
'chat_id': 'chat1', "chat_id": "chat1",
'chat_info_stream_id': 'stream1', "chat_info_stream_id": "stream1",
'chat_info_platform': 'qq', "chat_info_platform": "qq",
'chat_info_user_platform': 'qq', "chat_info_user_platform": "qq",
'chat_info_user_id': 'user1', "chat_info_user_id": "user1",
'chat_info_user_nickname': '用户1', "chat_info_user_nickname": "用户1",
'chat_info_user_cardname': '卡片名1', "chat_info_user_cardname": "卡片名1",
'chat_info_group_platform': 'qq', "chat_info_group_platform": "qq",
'chat_info_group_id': 'group1', "chat_info_group_id": "group1",
'chat_info_group_name': '群组1', "chat_info_group_name": "群组1",
'chat_info_create_time': current_time - 7200, "chat_info_create_time": current_time - 7200,
'chat_info_last_active_time': current_time - 900, # 15分钟前 "chat_info_last_active_time": current_time - 900, # 15分钟前
'user_platform': 'qq', "user_platform": "qq",
'user_id': 'user1', "user_id": "user1",
'user_nickname': '用户1', "user_nickname": "用户1",
'user_cardname': '卡片名1', "user_cardname": "卡片名1",
'processed_plain_text': '世界', "processed_plain_text": "世界",
'detailed_plain_text': '世界', "detailed_plain_text": "世界",
'memorized_times': 2 "memorized_times": 2,
}, },
{ {
'message_id': 'msg3', "message_id": "msg3",
'time': current_time - 900, # 15分钟前 "time": current_time - 900, # 15分钟前
'chat_id': 'chat2', "chat_id": "chat2",
'chat_info_stream_id': 'stream2', "chat_info_stream_id": "stream2",
'chat_info_platform': 'wechat', "chat_info_platform": "wechat",
'chat_info_user_platform': 'wechat', "chat_info_user_platform": "wechat",
'chat_info_user_id': 'user2', "chat_info_user_id": "user2",
'chat_info_user_nickname': '用户2', "chat_info_user_nickname": "用户2",
'chat_info_user_cardname': '卡片名2', "chat_info_user_cardname": "卡片名2",
'chat_info_group_platform': 'wechat', "chat_info_group_platform": "wechat",
'chat_info_group_id': 'group2', "chat_info_group_id": "group2",
'chat_info_group_name': '群组2', "chat_info_group_name": "群组2",
'chat_info_create_time': current_time - 3600, "chat_info_create_time": current_time - 3600,
'chat_info_last_active_time': current_time - 600, # 10分钟前 "chat_info_last_active_time": current_time - 600, # 10分钟前
'user_platform': 'wechat', "user_platform": "wechat",
'user_id': 'user2', "user_id": "user2",
'user_nickname': '用户2', "user_nickname": "用户2",
'user_cardname': '卡片名2', "user_cardname": "卡片名2",
'processed_plain_text': '测试', "processed_plain_text": "测试",
'detailed_plain_text': '测试', "detailed_plain_text": "测试",
'memorized_times': 0 "memorized_times": 0,
} },
] ]
for msg_data in self.test_messages: for msg_data in self.test_messages:
@@ -110,65 +110,63 @@ class TestMessageRepository(unittest.TestCase):
results = find_messages({}) results = find_messages({})
self.assertEqual(len(results), 3) self.assertEqual(len(results), 3)
# 验证结果是否按时间升序排列 # 验证结果是否按时间升序排列
self.assertEqual(results[0]['message_id'], 'msg1') self.assertEqual(results[0]["message_id"], "msg1")
self.assertEqual(results[1]['message_id'], 'msg2') self.assertEqual(results[1]["message_id"], "msg2")
self.assertEqual(results[2]['message_id'], 'msg3') self.assertEqual(results[2]["message_id"], "msg3")
def test_find_messages_with_filter(self): def test_find_messages_with_filter(self):
"""测试带过滤器的查询""" """测试带过滤器的查询"""
results = find_messages({'chat_id': 'chat1'}) results = find_messages({"chat_id": "chat1"})
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg1') self.assertEqual(results[0]["message_id"], "msg1")
self.assertEqual(results[1]['message_id'], 'msg2') self.assertEqual(results[1]["message_id"], "msg2")
results = find_messages({'user_id': 'user2'}) results = find_messages({"user_id": "user2"})
self.assertEqual(len(results), 1) self.assertEqual(len(results), 1)
self.assertEqual(results[0]['message_id'], 'msg3') self.assertEqual(results[0]["message_id"], "msg3")
def test_find_messages_with_operators(self): def test_find_messages_with_operators(self):
"""测试带操作符的查询""" """测试带操作符的查询"""
results = find_messages({'memorized_times': {'$gt': 0}}) results = find_messages({"memorized_times": {"$gt": 0}})
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg1') self.assertEqual(results[0]["message_id"], "msg1")
self.assertEqual(results[1]['message_id'], 'msg2') self.assertEqual(results[1]["message_id"], "msg2")
results = find_messages({'memorized_times': {'$gte': 2}}) results = find_messages({"memorized_times": {"$gte": 2}})
self.assertEqual(len(results), 1) self.assertEqual(len(results), 1)
self.assertEqual(results[0]['message_id'], 'msg2') self.assertEqual(results[0]["message_id"], "msg2")
def test_find_messages_with_sort(self): def test_find_messages_with_sort(self):
"""测试带排序的查询""" """测试带排序的查询"""
results = find_messages({}, sort=[('memorized_times', -1)]) results = find_messages({}, sort=[("memorized_times", -1)])
self.assertEqual(len(results), 3) self.assertEqual(len(results), 3)
# 验证结果是否按memorized_times降序排列 # 验证结果是否按memorized_times降序排列
self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2 self.assertEqual(results[0]["message_id"], "msg2") # memorized_times = 2
self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1 self.assertEqual(results[1]["message_id"], "msg1") # memorized_times = 1
self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0 self.assertEqual(results[2]["message_id"], "msg3") # memorized_times = 0
def test_find_messages_with_limit(self): def test_find_messages_with_limit(self):
"""测试带限制的查询""" """测试带限制的查询"""
# 默认limit_mode为latest应返回最新的2条记录 # 默认limit_mode为latest应返回最新的2条记录
results = find_messages({}, limit=2) results = find_messages({}, limit=2)
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg2') self.assertEqual(results[0]["message_id"], "msg2")
self.assertEqual(results[1]['message_id'], 'msg3') self.assertEqual(results[1]["message_id"], "msg3")
# 使用earliest模式应返回最早的2条记录 # 使用earliest模式应返回最早的2条记录
results = find_messages({}, limit=2, limit_mode='earliest') results = find_messages({}, limit=2, limit_mode="earliest")
self.assertEqual(len(results), 2) self.assertEqual(len(results), 2)
self.assertEqual(results[0]['message_id'], 'msg1') self.assertEqual(results[0]["message_id"], "msg1")
self.assertEqual(results[1]['message_id'], 'msg2') self.assertEqual(results[1]["message_id"], "msg2")
def test_find_messages_with_combined_criteria(self): def test_find_messages_with_combined_criteria(self):
"""测试组合查询条件""" """测试组合查询条件"""
results = find_messages( results = find_messages(
{'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}}, {"chat_info_platform": "qq", "memorized_times": {"$gt": 0}}, sort=[("time", 1)], limit=1
sort=[('time', 1)],
limit=1
) )
self.assertEqual(len(results), 1) self.assertEqual(len(results), 1)
self.assertEqual(results[0]['message_id'], 'msg2') self.assertEqual(results[0]["message_id"], "msg2")
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -9,7 +9,7 @@ import json
import copy import copy
# 添加项目根目录到Python路径 # 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
@@ -17,10 +17,11 @@ from src.common.logger import get_module_logger
# 创建测试日志记录器 # 创建测试日志记录器
logger = get_module_logger("test_readable_msg") logger = get_module_logger("test_readable_msg")
class TestBuildReadableMessages(unittest.TestCase): class TestBuildReadableMessages(unittest.TestCase):
def setUp(self): def setUp(self):
# 准备测试数据:从真实数据库获取消息 # 准备测试数据:从真实数据库获取消息
self.chat_id = '5ed68437e28644da51f314f37df68d18' self.chat_id = "5ed68437e28644da51f314f37df68d18"
self.current_time = time.time() self.current_time = time.time()
self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳
@@ -31,7 +32,7 @@ class TestBuildReadableMessages(unittest.TestCase):
timestamp_start=self.thirty_days_ago, timestamp_start=self.thirty_days_ago,
timestamp_end=self.current_time, timestamp_end=self.current_time,
limit=10, limit=10,
limit_mode="latest" limit_mode="latest",
) )
logger.info(f"已获取 {len(self.messages)} 条测试消息") logger.info(f"已获取 {len(self.messages)} 条测试消息")
@@ -61,14 +62,14 @@ class TestBuildReadableMessages(unittest.TestCase):
fixed_msg = copy.deepcopy(msg) fixed_msg = copy.deepcopy(msg)
# 构建 user_info 对象 # 构建 user_info 对象
if 'user_info' not in fixed_msg: if "user_info" not in fixed_msg:
user_info = { user_info = {
'platform': fixed_msg.get('user_platform', 'qq'), "platform": fixed_msg.get("user_platform", "qq"),
'user_id': fixed_msg.get('user_id', '10000'), "user_id": fixed_msg.get("user_id", "10000"),
'user_nickname': fixed_msg.get('user_nickname', '测试用户'), "user_nickname": fixed_msg.get("user_nickname", "测试用户"),
'user_cardname': fixed_msg.get('user_cardname', '') "user_cardname": fixed_msg.get("user_cardname", ""),
} }
fixed_msg['user_info'] = user_info fixed_msg["user_info"] = user_info
logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info") logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info")
fixed_messages.append(fixed_msg) fixed_messages.append(fixed_msg)
@@ -77,14 +78,16 @@ class TestBuildReadableMessages(unittest.TestCase):
try: try:
# 使用修复后的消息尝试格式化 # 使用修复后的消息尝试格式化
formatted_text = asyncio.run(build_readable_messages( formatted_text = asyncio.run(
build_readable_messages(
messages=fixed_messages, messages=fixed_messages,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,
timestamp_mode="absolute", timestamp_mode="absolute",
read_mark=0.0, read_mark=0.0,
truncate=False truncate=False,
)) )
)
logger.info("使用修复后的消息格式化完成") logger.info("使用修复后的消息格式化完成")
logger.info(f"格式化结果长度: {len(formatted_text)}") logger.info(f"格式化结果长度: {len(formatted_text)}")
@@ -120,26 +123,24 @@ class TestBuildReadableMessages(unittest.TestCase):
logger.info(f"user_info存在: {'user_info' in test_msg}") logger.info(f"user_info存在: {'user_info' in test_msg}")
# 修复缺少的user_info字段 # 修复缺少的user_info字段
if 'user_info' not in test_msg: if "user_info" not in test_msg:
logger.warning("消息中缺少user_info字段添加模拟数据") logger.warning("消息中缺少user_info字段添加模拟数据")
test_msg['user_info'] = { test_msg["user_info"] = {
'platform': test_msg.get('user_platform', 'qq'), "platform": test_msg.get("user_platform", "qq"),
'user_id': test_msg.get('user_id', '10000'), "user_id": test_msg.get("user_id", "10000"),
'user_nickname': test_msg.get('user_nickname', '测试用户'), "user_nickname": test_msg.get("user_nickname", "测试用户"),
'user_cardname': test_msg.get('user_cardname', '') "user_cardname": test_msg.get("user_cardname", ""),
} }
logger.info(f"添加的user_info: {test_msg['user_info']}") logger.info(f"添加的user_info: {test_msg['user_info']}")
simple_msgs = [test_msg] simple_msgs = [test_msg]
# 运行内部函数 # 运行内部函数
result_text, result_details = asyncio.run(_build_readable_messages_internal( result_text, result_details = asyncio.run(
simple_msgs, _build_readable_messages_internal(
replace_bot_name=True, simple_msgs, replace_bot_name=True, merge_messages=False, timestamp_mode="absolute", truncate=False
merge_messages=False, )
timestamp_mode="absolute", )
truncate=False
))
logger.info(f"内部函数返回结果: {result_text[:200] if result_text else ''}") logger.info(f"内部函数返回结果: {result_text[:200] if result_text else ''}")
logger.info(f"详情列表长度: {len(result_details)}") logger.info(f"详情列表长度: {len(result_details)}")
@@ -167,5 +168,6 @@ class TestBuildReadableMessages(unittest.TestCase):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
raise raise
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()

View File

@@ -5,13 +5,14 @@ import datetime
import time import time
# 添加项目根目录到Python路径 # 添加项目根目录到Python路径
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.message_repository import find_messages from src.common.message_repository import find_messages
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
from peewee import SqliteDatabase from peewee import SqliteDatabase
from src.common.database.database import db # 导入实际的数据库连接 from src.common.database.database import db # 导入实际的数据库连接
class TestExtractMessages(unittest.TestCase): class TestExtractMessages(unittest.TestCase):
def setUp(self): def setUp(self):
# 这个测试使用真实的数据库,所以不需要创建测试数据 # 这个测试使用真实的数据库,所以不需要创建测试数据
@@ -19,13 +20,10 @@ class TestExtractMessages(unittest.TestCase):
def test_extract_latest_messages_direct(self): def test_extract_latest_messages_direct(self):
"""测试直接使用message_repository.find_messages函数""" """测试直接使用message_repository.find_messages函数"""
chat_id = '5ed68437e28644da51f314f37df68d18' chat_id = "5ed68437e28644da51f314f37df68d18"
# 提取最新的10条消息 # 提取最新的10条消息
results = find_messages( results = find_messages({"chat_id": chat_id}, limit=10)
{'chat_id': chat_id},
limit=10
)
# 打印结果数量 # 打印结果数量
print(f"\n直接使用find_messages找到 {len(results)} 条消息") print(f"\n直接使用find_messages找到 {len(results)} 条消息")
@@ -34,12 +32,12 @@ class TestExtractMessages(unittest.TestCase):
if results: if results:
print("\n消息时间顺序:") print("\n消息时间顺序:")
for idx, msg in enumerate(results): for idx, msg in enumerate(results):
msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S")
print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}") print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}")
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
# 验证结果按时间排序 # 验证结果按时间排序
times = [msg['time'] for msg in results] times = [msg["time"] for msg in results]
self.assertEqual(times, sorted(times), "消息应该按时间升序排列") self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
else: else:
print(f"未找到chat_id为 {chat_id} 的消息") print(f"未找到chat_id为 {chat_id} 的消息")
@@ -49,7 +47,7 @@ class TestExtractMessages(unittest.TestCase):
def test_extract_latest_messages_via_builder(self): def test_extract_latest_messages_via_builder(self):
"""使用chat_message_builder中的函数测试从真实数据库提取消息""" """使用chat_message_builder中的函数测试从真实数据库提取消息"""
chat_id = '5ed68437e28644da51f314f37df68d18' chat_id = "5ed68437e28644da51f314f37df68d18"
# 设置时间范围为过去30天到现在 # 设置时间范围为过去30天到现在
current_time = time.time() current_time = time.time()
@@ -57,11 +55,7 @@ class TestExtractMessages(unittest.TestCase):
# 使用chat_message_builder中的函数 # 使用chat_message_builder中的函数
results = get_raw_msg_by_timestamp_with_chat( results = get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id, chat_id=chat_id, timestamp_start=thirty_days_ago, timestamp_end=current_time, limit=10, limit_mode="latest"
timestamp_start=thirty_days_ago,
timestamp_end=current_time,
limit=10,
limit_mode="latest"
) )
# 打印结果数量 # 打印结果数量
@@ -71,12 +65,12 @@ class TestExtractMessages(unittest.TestCase):
if results: if results:
print("\n消息时间顺序:") print("\n消息时间顺序:")
for idx, msg in enumerate(results): for idx, msg in enumerate(results):
msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S")
print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}") print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}")
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
# 验证结果按时间排序 # 验证结果按时间排序
times = [msg['time'] for msg in results] times = [msg["time"] for msg in results]
self.assertEqual(times, sorted(times), "消息应该按时间升序排列") self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
else: else:
print(f"未找到chat_id为 {chat_id} 的消息") print(f"未找到chat_id为 {chat_id} 的消息")
@@ -84,5 +78,6 @@ class TestExtractMessages(unittest.TestCase):
# 最基本的断言,确保测试有效 # 最基本的断言,确保测试有效
self.assertIsInstance(results, list, "结果应该是一个列表") self.assertIsInstance(results, list, "结果应该是一个列表")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main() unittest.main()