fix;修复提取消息和运行bug
This commit is contained in:
@@ -63,13 +63,16 @@ class ChattingInfoProcessor(BaseProcessor):
|
|||||||
|
|
||||||
# 设置说话消息
|
# 设置说话消息
|
||||||
if hasattr(obs, "talking_message_str"):
|
if hasattr(obs, "talking_message_str"):
|
||||||
|
print(f"设置说话消息:obs.talking_message_str: {obs.talking_message_str}")
|
||||||
obs_info.set_talking_message(obs.talking_message_str)
|
obs_info.set_talking_message(obs.talking_message_str)
|
||||||
|
|
||||||
# 设置截断后的说话消息
|
# 设置截断后的说话消息
|
||||||
if hasattr(obs, "talking_message_str_truncate"):
|
if hasattr(obs, "talking_message_str_truncate"):
|
||||||
|
print(f"设置截断后的说话消息:obs.talking_message_str_truncate: {obs.talking_message_str_truncate}")
|
||||||
obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate)
|
obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate)
|
||||||
|
|
||||||
if hasattr(obs, "mid_memory_info"):
|
if hasattr(obs, "mid_memory_info"):
|
||||||
|
print(f"设置之前聊天信息:obs.mid_memory_info: {obs.mid_memory_info}")
|
||||||
obs_info.set_previous_chat_info(obs.mid_memory_info)
|
obs_info.set_previous_chat_info(obs.mid_memory_info)
|
||||||
|
|
||||||
# 设置聊天类型
|
# 设置聊天类型
|
||||||
|
|||||||
@@ -140,8 +140,23 @@ class ChattingObservation(Observation):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
|
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
|
||||||
group_info = find_msg.get("chat_info", {}).get("group_info")
|
|
||||||
user_info = find_msg.get("chat_info", {}).get("user_info")
|
# 创建所需的user_info字段
|
||||||
|
user_info = {
|
||||||
|
"platform": find_msg.get("user_platform", ""),
|
||||||
|
"user_id": find_msg.get("user_id", ""),
|
||||||
|
"user_nickname": find_msg.get("user_nickname", ""),
|
||||||
|
"user_cardname": find_msg.get("user_cardname", "")
|
||||||
|
}
|
||||||
|
|
||||||
|
# 创建所需的group_info字段,如果是群聊的话
|
||||||
|
group_info = {}
|
||||||
|
if find_msg.get("chat_info_group_id"):
|
||||||
|
group_info = {
|
||||||
|
"platform": find_msg.get("chat_info_group_platform", ""),
|
||||||
|
"group_id": find_msg.get("chat_info_group_id", ""),
|
||||||
|
"group_name": find_msg.get("chat_info_group_name", "")
|
||||||
|
}
|
||||||
|
|
||||||
content_format = ""
|
content_format = ""
|
||||||
accept_format = ""
|
accept_format = ""
|
||||||
@@ -181,6 +196,8 @@ class ChattingObservation(Observation):
|
|||||||
limit=self.max_now_obs_len,
|
limit=self.max_now_obs_len,
|
||||||
limit_mode="latest",
|
limit_mode="latest",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# print(f"new_messages_list: {new_messages_list}")
|
||||||
|
|
||||||
last_obs_time_mark = self.last_observe_time
|
last_obs_time_mark = self.last_observe_time
|
||||||
if new_messages_list:
|
if new_messages_list:
|
||||||
@@ -193,6 +210,7 @@ class ChattingObservation(Observation):
|
|||||||
oldest_messages = self.talking_message[:messages_to_remove_count]
|
oldest_messages = self.talking_message[:messages_to_remove_count]
|
||||||
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
|
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
|
||||||
|
|
||||||
|
# print(f"压缩中:oldest_messages: {oldest_messages}")
|
||||||
oldest_messages_str = await build_readable_messages(
|
oldest_messages_str = await build_readable_messages(
|
||||||
messages=oldest_messages, timestamp_mode="normal", read_mark=0
|
messages=oldest_messages, timestamp_mode="normal", read_mark=0
|
||||||
)
|
)
|
||||||
@@ -235,21 +253,24 @@ class ChattingObservation(Observation):
|
|||||||
self.oldest_messages = oldest_messages
|
self.oldest_messages = oldest_messages
|
||||||
self.oldest_messages_str = oldest_messages_str
|
self.oldest_messages_str = oldest_messages_str
|
||||||
|
|
||||||
|
# 构建中
|
||||||
|
# print(f"构建中:self.talking_message: {self.talking_message}")
|
||||||
self.talking_message_str = await build_readable_messages(
|
self.talking_message_str = await build_readable_messages(
|
||||||
messages=self.talking_message,
|
messages=self.talking_message,
|
||||||
timestamp_mode="lite",
|
timestamp_mode="lite",
|
||||||
read_mark=last_obs_time_mark,
|
read_mark=last_obs_time_mark,
|
||||||
)
|
)
|
||||||
|
# print(f"构建中:self.talking_message_str: {self.talking_message_str}")
|
||||||
self.talking_message_str_truncate = await build_readable_messages(
|
self.talking_message_str_truncate = await build_readable_messages(
|
||||||
messages=self.talking_message,
|
messages=self.talking_message,
|
||||||
timestamp_mode="normal",
|
timestamp_mode="normal",
|
||||||
read_mark=last_obs_time_mark,
|
read_mark=last_obs_time_mark,
|
||||||
truncate=True,
|
truncate=True,
|
||||||
)
|
)
|
||||||
|
# print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}")
|
||||||
|
|
||||||
self.person_list = await get_person_id_list(self.talking_message)
|
self.person_list = await get_person_id_list(self.talking_message)
|
||||||
|
# print(f"构建中:self.person_list: {self.person_list}")
|
||||||
# print(f"self.11111person_list: {self.person_list}")
|
|
||||||
|
|
||||||
logger.trace(
|
logger.trace(
|
||||||
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
|
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
|
||||||
|
|||||||
@@ -174,6 +174,16 @@ async def _build_readable_messages_internal(
|
|||||||
|
|
||||||
# 1 & 2: 获取发送者信息并提取消息组件
|
# 1 & 2: 获取发送者信息并提取消息组件
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
|
# 检查并修复缺少的user_info字段
|
||||||
|
if 'user_info' not in msg:
|
||||||
|
# 创建user_info字段
|
||||||
|
msg['user_info'] = {
|
||||||
|
'platform': msg.get('user_platform', ''),
|
||||||
|
'user_id': msg.get('user_id', ''),
|
||||||
|
'user_nickname': msg.get('user_nickname', ''),
|
||||||
|
'user_cardname': msg.get('user_cardname', '')
|
||||||
|
}
|
||||||
|
|
||||||
user_info = msg.get("user_info", {})
|
user_info = msg.get("user_info", {})
|
||||||
platform = user_info.get("platform")
|
platform = user_info.get("platform")
|
||||||
user_id = user_info.get("user_id")
|
user_id = user_info.get("user_id")
|
||||||
|
|||||||
@@ -380,7 +380,7 @@ def process_llm_response(text: str) -> list[str]:
|
|||||||
# sentences.append(content)
|
# sentences.append(content)
|
||||||
|
|
||||||
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
|
# 在所有句子处理完毕后,对包含占位符的列表进行恢复
|
||||||
if global_config.enable_kaomoji_protection:
|
if global_config.response_splitter.enable_kaomoji_protection:
|
||||||
sentences = recover_kaomoji(sentences, kaomoji_mapping)
|
sentences = recover_kaomoji(sentences, kaomoji_mapping)
|
||||||
|
|
||||||
return sentences
|
return sentences
|
||||||
|
|||||||
174
tests/common/test_message_repository.py
Normal file
174
tests/common/test_message_repository.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
import unittest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import datetime
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
||||||
|
|
||||||
|
from peewee import SqliteDatabase
|
||||||
|
from src.common.database.database_model import Messages, BaseModel
|
||||||
|
from src.common.message_repository import find_messages
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageRepository(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# 创建内存中的SQLite数据库用于测试
|
||||||
|
self.test_db = SqliteDatabase(':memory:')
|
||||||
|
|
||||||
|
# 覆盖原有数据库连接
|
||||||
|
BaseModel._meta.database = self.test_db
|
||||||
|
Messages._meta.database = self.test_db
|
||||||
|
|
||||||
|
# 创建表
|
||||||
|
self.test_db.create_tables([Messages])
|
||||||
|
|
||||||
|
# 添加测试数据
|
||||||
|
current_time = datetime.datetime.now().timestamp()
|
||||||
|
self.test_messages = [
|
||||||
|
{
|
||||||
|
'message_id': 'msg1',
|
||||||
|
'time': current_time - 3600, # 1小时前
|
||||||
|
'chat_id': '5ed68437e28644da51f314f37df68d18',
|
||||||
|
'chat_info_stream_id': 'stream1',
|
||||||
|
'chat_info_platform': 'qq',
|
||||||
|
'chat_info_user_platform': 'qq',
|
||||||
|
'chat_info_user_id': 'user1',
|
||||||
|
'chat_info_user_nickname': '用户1',
|
||||||
|
'chat_info_user_cardname': '卡片名1',
|
||||||
|
'chat_info_group_platform': 'qq',
|
||||||
|
'chat_info_group_id': 'group1',
|
||||||
|
'chat_info_group_name': '群组1',
|
||||||
|
'chat_info_create_time': current_time - 7200, # 2小时前
|
||||||
|
'chat_info_last_active_time': current_time - 1800, # 30分钟前
|
||||||
|
'user_platform': 'qq',
|
||||||
|
'user_id': 'user1',
|
||||||
|
'user_nickname': '用户1',
|
||||||
|
'user_cardname': '卡片名1',
|
||||||
|
'processed_plain_text': '你好',
|
||||||
|
'detailed_plain_text': '你好',
|
||||||
|
'memorized_times': 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'message_id': 'msg2',
|
||||||
|
'time': current_time - 1800, # 30分钟前
|
||||||
|
'chat_id': 'chat1',
|
||||||
|
'chat_info_stream_id': 'stream1',
|
||||||
|
'chat_info_platform': 'qq',
|
||||||
|
'chat_info_user_platform': 'qq',
|
||||||
|
'chat_info_user_id': 'user1',
|
||||||
|
'chat_info_user_nickname': '用户1',
|
||||||
|
'chat_info_user_cardname': '卡片名1',
|
||||||
|
'chat_info_group_platform': 'qq',
|
||||||
|
'chat_info_group_id': 'group1',
|
||||||
|
'chat_info_group_name': '群组1',
|
||||||
|
'chat_info_create_time': current_time - 7200,
|
||||||
|
'chat_info_last_active_time': current_time - 900, # 15分钟前
|
||||||
|
'user_platform': 'qq',
|
||||||
|
'user_id': 'user1',
|
||||||
|
'user_nickname': '用户1',
|
||||||
|
'user_cardname': '卡片名1',
|
||||||
|
'processed_plain_text': '世界',
|
||||||
|
'detailed_plain_text': '世界',
|
||||||
|
'memorized_times': 2
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'message_id': 'msg3',
|
||||||
|
'time': current_time - 900, # 15分钟前
|
||||||
|
'chat_id': 'chat2',
|
||||||
|
'chat_info_stream_id': 'stream2',
|
||||||
|
'chat_info_platform': 'wechat',
|
||||||
|
'chat_info_user_platform': 'wechat',
|
||||||
|
'chat_info_user_id': 'user2',
|
||||||
|
'chat_info_user_nickname': '用户2',
|
||||||
|
'chat_info_user_cardname': '卡片名2',
|
||||||
|
'chat_info_group_platform': 'wechat',
|
||||||
|
'chat_info_group_id': 'group2',
|
||||||
|
'chat_info_group_name': '群组2',
|
||||||
|
'chat_info_create_time': current_time - 3600,
|
||||||
|
'chat_info_last_active_time': current_time - 600, # 10分钟前
|
||||||
|
'user_platform': 'wechat',
|
||||||
|
'user_id': 'user2',
|
||||||
|
'user_nickname': '用户2',
|
||||||
|
'user_cardname': '卡片名2',
|
||||||
|
'processed_plain_text': '测试',
|
||||||
|
'detailed_plain_text': '测试',
|
||||||
|
'memorized_times': 0
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for msg_data in self.test_messages:
|
||||||
|
Messages.create(**msg_data)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
# 关闭测试数据库连接
|
||||||
|
self.test_db.close()
|
||||||
|
|
||||||
|
def test_find_messages_no_filter(self):
|
||||||
|
"""测试不带过滤器的查询"""
|
||||||
|
results = find_messages({})
|
||||||
|
self.assertEqual(len(results), 3)
|
||||||
|
# 验证结果是否按时间升序排列
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||||
|
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||||
|
self.assertEqual(results[2]['message_id'], 'msg3')
|
||||||
|
|
||||||
|
def test_find_messages_with_filter(self):
|
||||||
|
"""测试带过滤器的查询"""
|
||||||
|
results = find_messages({'chat_id': 'chat1'})
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||||
|
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||||
|
|
||||||
|
results = find_messages({'user_id': 'user2'})
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg3')
|
||||||
|
|
||||||
|
def test_find_messages_with_operators(self):
|
||||||
|
"""测试带操作符的查询"""
|
||||||
|
results = find_messages({'memorized_times': {'$gt': 0}})
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||||
|
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||||
|
|
||||||
|
results = find_messages({'memorized_times': {'$gte': 2}})
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg2')
|
||||||
|
|
||||||
|
def test_find_messages_with_sort(self):
|
||||||
|
"""测试带排序的查询"""
|
||||||
|
results = find_messages({}, sort=[('memorized_times', -1)])
|
||||||
|
self.assertEqual(len(results), 3)
|
||||||
|
# 验证结果是否按memorized_times降序排列
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2
|
||||||
|
self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1
|
||||||
|
self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0
|
||||||
|
|
||||||
|
def test_find_messages_with_limit(self):
|
||||||
|
"""测试带限制的查询"""
|
||||||
|
# 默认limit_mode为latest,应返回最新的2条记录
|
||||||
|
results = find_messages({}, limit=2)
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg2')
|
||||||
|
self.assertEqual(results[1]['message_id'], 'msg3')
|
||||||
|
|
||||||
|
# 使用earliest模式,应返回最早的2条记录
|
||||||
|
results = find_messages({}, limit=2, limit_mode='earliest')
|
||||||
|
self.assertEqual(len(results), 2)
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg1')
|
||||||
|
self.assertEqual(results[1]['message_id'], 'msg2')
|
||||||
|
|
||||||
|
def test_find_messages_with_combined_criteria(self):
|
||||||
|
"""测试组合查询条件"""
|
||||||
|
results = find_messages(
|
||||||
|
{'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}},
|
||||||
|
sort=[('time', 1)],
|
||||||
|
limit=1
|
||||||
|
)
|
||||||
|
self.assertEqual(len(results), 1)
|
||||||
|
self.assertEqual(results[0]['message_id'], 'msg2')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
171
tests/test_build_readable_messages.py
Normal file
171
tests/test_build_readable_messages.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
import unittest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
import asyncio
|
||||||
|
import traceback
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages
|
||||||
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
|
# 创建测试日志记录器
|
||||||
|
logger = get_module_logger("test_readable_msg")
|
||||||
|
|
||||||
|
class TestBuildReadableMessages(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# 准备测试数据:从真实数据库获取消息
|
||||||
|
self.chat_id = '5ed68437e28644da51f314f37df68d18'
|
||||||
|
self.current_time = time.time()
|
||||||
|
self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳
|
||||||
|
|
||||||
|
# 获取最新的10条消息
|
||||||
|
try:
|
||||||
|
self.messages = get_raw_msg_by_timestamp_with_chat(
|
||||||
|
chat_id=self.chat_id,
|
||||||
|
timestamp_start=self.thirty_days_ago,
|
||||||
|
timestamp_end=self.current_time,
|
||||||
|
limit=10,
|
||||||
|
limit_mode="latest"
|
||||||
|
)
|
||||||
|
logger.info(f"已获取 {len(self.messages)} 条测试消息")
|
||||||
|
|
||||||
|
# 打印消息样例
|
||||||
|
if self.messages:
|
||||||
|
sample_msg = self.messages[0]
|
||||||
|
logger.info(f"消息样例: {list(sample_msg.keys())}")
|
||||||
|
logger.info(f"消息内容: {sample_msg.get('processed_plain_text', '无文本内容')[:50]}...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取消息失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
self.messages = []
|
||||||
|
|
||||||
|
def test_manual_fix_messages(self):
|
||||||
|
"""创建一个手动修复版本的消息进行测试"""
|
||||||
|
if not self.messages:
|
||||||
|
self.skipTest("没有测试消息,跳过测试")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("开始手动修复消息...")
|
||||||
|
|
||||||
|
# 创建修复版本的消息列表
|
||||||
|
fixed_messages = []
|
||||||
|
|
||||||
|
for msg in self.messages:
|
||||||
|
# 深拷贝以避免修改原始数据
|
||||||
|
fixed_msg = copy.deepcopy(msg)
|
||||||
|
|
||||||
|
# 构建 user_info 对象
|
||||||
|
if 'user_info' not in fixed_msg:
|
||||||
|
user_info = {
|
||||||
|
'platform': fixed_msg.get('user_platform', 'qq'),
|
||||||
|
'user_id': fixed_msg.get('user_id', '10000'),
|
||||||
|
'user_nickname': fixed_msg.get('user_nickname', '测试用户'),
|
||||||
|
'user_cardname': fixed_msg.get('user_cardname', '')
|
||||||
|
}
|
||||||
|
fixed_msg['user_info'] = user_info
|
||||||
|
logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info")
|
||||||
|
|
||||||
|
fixed_messages.append(fixed_msg)
|
||||||
|
|
||||||
|
logger.info(f"已修复 {len(fixed_messages)} 条消息")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用修复后的消息尝试格式化
|
||||||
|
formatted_text = asyncio.run(build_readable_messages(
|
||||||
|
messages=fixed_messages,
|
||||||
|
replace_bot_name=True,
|
||||||
|
merge_messages=False,
|
||||||
|
timestamp_mode="absolute",
|
||||||
|
read_mark=0.0,
|
||||||
|
truncate=False
|
||||||
|
))
|
||||||
|
|
||||||
|
logger.info("使用修复后的消息格式化完成")
|
||||||
|
logger.info(f"格式化结果长度: {len(formatted_text)}")
|
||||||
|
if formatted_text:
|
||||||
|
logger.info(f"格式化结果预览: {formatted_text[:200]}...")
|
||||||
|
else:
|
||||||
|
logger.warning("格式化结果为空")
|
||||||
|
|
||||||
|
# 断言
|
||||||
|
self.assertNotEqual(formatted_text, "", "有消息时不应返回空字符串")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"使用修复后的消息格式化失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
def test_debug_build_messages_internal(self):
|
||||||
|
"""调试_build_readable_messages_internal函数"""
|
||||||
|
if not self.messages:
|
||||||
|
self.skipTest("没有测试消息,跳过测试")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info("开始调试内部构建函数...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 直接导入内部函数进行测试
|
||||||
|
from src.chat.utils.chat_message_builder import _build_readable_messages_internal
|
||||||
|
|
||||||
|
# 手动创建一个简单的测试消息列表
|
||||||
|
test_msg = self.messages[0].copy() # 使用第一条消息作为模板
|
||||||
|
|
||||||
|
# 检查消息结构
|
||||||
|
logger.info(f"测试消息keys: {list(test_msg.keys())}")
|
||||||
|
logger.info(f"user_info存在: {'user_info' in test_msg}")
|
||||||
|
|
||||||
|
# 修复缺少的user_info字段
|
||||||
|
if 'user_info' not in test_msg:
|
||||||
|
logger.warning("消息中缺少user_info字段,添加模拟数据")
|
||||||
|
test_msg['user_info'] = {
|
||||||
|
'platform': test_msg.get('user_platform', 'qq'),
|
||||||
|
'user_id': test_msg.get('user_id', '10000'),
|
||||||
|
'user_nickname': test_msg.get('user_nickname', '测试用户'),
|
||||||
|
'user_cardname': test_msg.get('user_cardname', '')
|
||||||
|
}
|
||||||
|
logger.info(f"添加的user_info: {test_msg['user_info']}")
|
||||||
|
|
||||||
|
simple_msgs = [test_msg]
|
||||||
|
|
||||||
|
# 运行内部函数
|
||||||
|
result_text, result_details = asyncio.run(_build_readable_messages_internal(
|
||||||
|
simple_msgs,
|
||||||
|
replace_bot_name=True,
|
||||||
|
merge_messages=False,
|
||||||
|
timestamp_mode="absolute",
|
||||||
|
truncate=False
|
||||||
|
))
|
||||||
|
|
||||||
|
logger.info(f"内部函数返回结果: {result_text[:200] if result_text else '空'}")
|
||||||
|
logger.info(f"详情列表长度: {len(result_details)}")
|
||||||
|
|
||||||
|
# 显示处理过程中的变量
|
||||||
|
if not result_text and len(simple_msgs) > 0:
|
||||||
|
logger.warning("消息处理可能有问题,检查关键步骤")
|
||||||
|
msg = simple_msgs[0]
|
||||||
|
|
||||||
|
# 打印关键变量的值
|
||||||
|
user_info = msg.get("user_info", {})
|
||||||
|
platform = user_info.get("platform")
|
||||||
|
user_id = user_info.get("user_id")
|
||||||
|
timestamp = msg.get("time")
|
||||||
|
content = msg.get("processed_plain_text", "")
|
||||||
|
|
||||||
|
logger.warning(f"平台: {platform}, 用户ID: {user_id}, 时间戳: {timestamp}")
|
||||||
|
logger.warning(f"内容: {content[:50]}...")
|
||||||
|
|
||||||
|
# 检查必要信息是否完整
|
||||||
|
logger.warning(f"必要信息完整性检查: {all([platform, user_id, timestamp is not None])}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"调试内部函数失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
88
tests/test_extract_messages.py
Normal file
88
tests/test_extract_messages.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import unittest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||||
|
|
||||||
|
from src.common.message_repository import find_messages
|
||||||
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||||
|
from peewee import SqliteDatabase
|
||||||
|
from src.common.database.database import db # 导入实际的数据库连接
|
||||||
|
|
||||||
|
class TestExtractMessages(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
# 这个测试使用真实的数据库,所以不需要创建测试数据
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_extract_latest_messages_direct(self):
|
||||||
|
"""测试直接使用message_repository.find_messages函数"""
|
||||||
|
chat_id = '5ed68437e28644da51f314f37df68d18'
|
||||||
|
|
||||||
|
# 提取最新的10条消息
|
||||||
|
results = find_messages(
|
||||||
|
{'chat_id': chat_id},
|
||||||
|
limit=10
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打印结果数量
|
||||||
|
print(f"\n直接使用find_messages,找到 {len(results)} 条消息")
|
||||||
|
|
||||||
|
# 如果有结果,打印一些信息
|
||||||
|
if results:
|
||||||
|
print("\n消息时间顺序:")
|
||||||
|
for idx, msg in enumerate(results):
|
||||||
|
msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}")
|
||||||
|
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
|
||||||
|
|
||||||
|
# 验证结果按时间排序
|
||||||
|
times = [msg['time'] for msg in results]
|
||||||
|
self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
|
||||||
|
else:
|
||||||
|
print(f"未找到chat_id为 {chat_id} 的消息")
|
||||||
|
|
||||||
|
# 最基本的断言,确保测试有效
|
||||||
|
self.assertIsInstance(results, list, "结果应该是一个列表")
|
||||||
|
|
||||||
|
def test_extract_latest_messages_via_builder(self):
|
||||||
|
"""使用chat_message_builder中的函数测试从真实数据库提取消息"""
|
||||||
|
chat_id = '5ed68437e28644da51f314f37df68d18'
|
||||||
|
|
||||||
|
# 设置时间范围为过去30天到现在
|
||||||
|
current_time = time.time()
|
||||||
|
thirty_days_ago = current_time - (30 * 24 * 60 * 60) # 30天前的时间戳
|
||||||
|
|
||||||
|
# 使用chat_message_builder中的函数
|
||||||
|
results = get_raw_msg_by_timestamp_with_chat(
|
||||||
|
chat_id=chat_id,
|
||||||
|
timestamp_start=thirty_days_ago,
|
||||||
|
timestamp_end=current_time,
|
||||||
|
limit=10,
|
||||||
|
limit_mode="latest"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 打印结果数量
|
||||||
|
print(f"\n使用get_raw_msg_by_timestamp_with_chat,找到 {len(results)} 条消息")
|
||||||
|
|
||||||
|
# 如果有结果,打印一些信息
|
||||||
|
if results:
|
||||||
|
print("\n消息时间顺序:")
|
||||||
|
for idx, msg in enumerate(results):
|
||||||
|
msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}")
|
||||||
|
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
|
||||||
|
|
||||||
|
# 验证结果按时间排序
|
||||||
|
times = [msg['time'] for msg in results]
|
||||||
|
self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
|
||||||
|
else:
|
||||||
|
print(f"未找到chat_id为 {chat_id} 的消息")
|
||||||
|
|
||||||
|
# 最基本的断言,确保测试有效
|
||||||
|
self.assertIsInstance(results, list, "结果应该是一个列表")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user