From d26d69de60b92b020894dfe8a57aad02f4206bf0 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sat, 17 May 2025 01:03:20 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9B=E4=BF=AE=E5=A4=8D=E6=8F=90?= =?UTF-8?q?=E5=8F=96=E6=B6=88=E6=81=AF=E5=92=8C=E8=BF=90=E8=A1=8Cbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../info_processors/chattinginfo_processor.py | 3 + .../observation/chatting_observation.py | 29 ++- src/chat/utils/chat_message_builder.py | 10 + src/chat/utils/utils.py | 2 +- tests/common/test_message_repository.py | 174 ++++++++++++++++++ tests/test_build_readable_messages.py | 171 +++++++++++++++++ tests/test_extract_messages.py | 88 +++++++++ 7 files changed, 472 insertions(+), 5 deletions(-) create mode 100644 tests/common/test_message_repository.py create mode 100644 tests/test_build_readable_messages.py create mode 100644 tests/test_extract_messages.py diff --git a/src/chat/focus_chat/info_processors/chattinginfo_processor.py b/src/chat/focus_chat/info_processors/chattinginfo_processor.py index c9641b9b7..5a72bcd9e 100644 --- a/src/chat/focus_chat/info_processors/chattinginfo_processor.py +++ b/src/chat/focus_chat/info_processors/chattinginfo_processor.py @@ -63,13 +63,16 @@ class ChattingInfoProcessor(BaseProcessor): # 设置说话消息 if hasattr(obs, "talking_message_str"): + print(f"设置说话消息:obs.talking_message_str: {obs.talking_message_str}") obs_info.set_talking_message(obs.talking_message_str) # 设置截断后的说话消息 if hasattr(obs, "talking_message_str_truncate"): + print(f"设置截断后的说话消息:obs.talking_message_str_truncate: {obs.talking_message_str_truncate}") obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate) if hasattr(obs, "mid_memory_info"): + print(f"设置之前聊天信息:obs.mid_memory_info: {obs.mid_memory_info}") obs_info.set_previous_chat_info(obs.mid_memory_info) # 设置聊天类型 diff --git a/src/chat/heart_flow/observation/chatting_observation.py b/src/chat/heart_flow/observation/chatting_observation.py index 7e4872014..415e4b100 100644 --- a/src/chat/heart_flow/observation/chatting_observation.py +++ b/src/chat/heart_flow/observation/chatting_observation.py @@ -140,8 +140,23 @@ class ChattingObservation(Observation): return None # logger.debug(f"找到的锚定消息:find_msg: {find_msg}") - group_info = find_msg.get("chat_info", {}).get("group_info") - user_info = find_msg.get("chat_info", {}).get("user_info") + + # 创建所需的user_info字段 + user_info = { + "platform": find_msg.get("user_platform", ""), + "user_id": find_msg.get("user_id", ""), + "user_nickname": find_msg.get("user_nickname", ""), + "user_cardname": find_msg.get("user_cardname", "") + } + + # 创建所需的group_info字段,如果是群聊的话 + group_info = {} + if find_msg.get("chat_info_group_id"): + group_info = { + "platform": find_msg.get("chat_info_group_platform", ""), + "group_id": find_msg.get("chat_info_group_id", ""), + "group_name": find_msg.get("chat_info_group_name", "") + } content_format = "" accept_format = "" @@ -181,6 +196,8 @@ class ChattingObservation(Observation): limit=self.max_now_obs_len, limit_mode="latest", ) + + # print(f"new_messages_list: {new_messages_list}") last_obs_time_mark = self.last_observe_time if new_messages_list: @@ -193,6 +210,7 @@ class ChattingObservation(Observation): oldest_messages = self.talking_message[:messages_to_remove_count] self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的 + # print(f"压缩中:oldest_messages: {oldest_messages}") oldest_messages_str = await build_readable_messages( messages=oldest_messages, timestamp_mode="normal", read_mark=0 ) @@ -235,21 +253,24 @@ class ChattingObservation(Observation): self.oldest_messages = oldest_messages self.oldest_messages_str = oldest_messages_str + # 构建中 + # print(f"构建中:self.talking_message: {self.talking_message}") self.talking_message_str = await build_readable_messages( messages=self.talking_message, timestamp_mode="lite", read_mark=last_obs_time_mark, ) + # print(f"构建中:self.talking_message_str: {self.talking_message_str}") self.talking_message_str_truncate = await build_readable_messages( messages=self.talking_message, timestamp_mode="normal", read_mark=last_obs_time_mark, truncate=True, ) + # print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}") self.person_list = await get_person_id_list(self.talking_message) - - # print(f"self.11111person_list: {self.person_list}") + # print(f"构建中:self.person_list: {self.person_list}") logger.trace( f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}" diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index d3a062680..f81603e13 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -174,6 +174,16 @@ async def _build_readable_messages_internal( # 1 & 2: 获取发送者信息并提取消息组件 for msg in messages: + # 检查并修复缺少的user_info字段 + if 'user_info' not in msg: + # 创建user_info字段 + msg['user_info'] = { + 'platform': msg.get('user_platform', ''), + 'user_id': msg.get('user_id', ''), + 'user_nickname': msg.get('user_nickname', ''), + 'user_cardname': msg.get('user_cardname', '') + } + user_info = msg.get("user_info", {}) platform = user_info.get("platform") user_id = user_info.get("user_id") diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index c400a9948..a5b601c43 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -380,7 +380,7 @@ def process_llm_response(text: str) -> list[str]: # sentences.append(content) # 在所有句子处理完毕后,对包含占位符的列表进行恢复 - if global_config.enable_kaomoji_protection: + if global_config.response_splitter.enable_kaomoji_protection: sentences = recover_kaomoji(sentences, kaomoji_mapping) return sentences diff --git a/tests/common/test_message_repository.py b/tests/common/test_message_repository.py new file mode 100644 index 000000000..43d629761 --- /dev/null +++ b/tests/common/test_message_repository.py @@ -0,0 +1,174 @@ +import unittest +from unittest.mock import patch, MagicMock +import datetime +import sys +import os + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..'))) + +from peewee import SqliteDatabase +from src.common.database.database_model import Messages, BaseModel +from src.common.message_repository import find_messages + + +class TestMessageRepository(unittest.TestCase): + def setUp(self): + # 创建内存中的SQLite数据库用于测试 + self.test_db = SqliteDatabase(':memory:') + + # 覆盖原有数据库连接 + BaseModel._meta.database = self.test_db + Messages._meta.database = self.test_db + + # 创建表 + self.test_db.create_tables([Messages]) + + # 添加测试数据 + current_time = datetime.datetime.now().timestamp() + self.test_messages = [ + { + 'message_id': 'msg1', + 'time': current_time - 3600, # 1小时前 + 'chat_id': '5ed68437e28644da51f314f37df68d18', + 'chat_info_stream_id': 'stream1', + 'chat_info_platform': 'qq', + 'chat_info_user_platform': 'qq', + 'chat_info_user_id': 'user1', + 'chat_info_user_nickname': '用户1', + 'chat_info_user_cardname': '卡片名1', + 'chat_info_group_platform': 'qq', + 'chat_info_group_id': 'group1', + 'chat_info_group_name': '群组1', + 'chat_info_create_time': current_time - 7200, # 2小时前 + 'chat_info_last_active_time': current_time - 1800, # 30分钟前 + 'user_platform': 'qq', + 'user_id': 'user1', + 'user_nickname': '用户1', + 'user_cardname': '卡片名1', + 'processed_plain_text': '你好', + 'detailed_plain_text': '你好', + 'memorized_times': 1 + }, + { + 'message_id': 'msg2', + 'time': current_time - 1800, # 30分钟前 + 'chat_id': 'chat1', + 'chat_info_stream_id': 'stream1', + 'chat_info_platform': 'qq', + 'chat_info_user_platform': 'qq', + 'chat_info_user_id': 'user1', + 'chat_info_user_nickname': '用户1', + 'chat_info_user_cardname': '卡片名1', + 'chat_info_group_platform': 'qq', + 'chat_info_group_id': 'group1', + 'chat_info_group_name': '群组1', + 'chat_info_create_time': current_time - 7200, + 'chat_info_last_active_time': current_time - 900, # 15分钟前 + 'user_platform': 'qq', + 'user_id': 'user1', + 'user_nickname': '用户1', + 'user_cardname': '卡片名1', + 'processed_plain_text': '世界', + 'detailed_plain_text': '世界', + 'memorized_times': 2 + }, + { + 'message_id': 'msg3', + 'time': current_time - 900, # 15分钟前 + 'chat_id': 'chat2', + 'chat_info_stream_id': 'stream2', + 'chat_info_platform': 'wechat', + 'chat_info_user_platform': 'wechat', + 'chat_info_user_id': 'user2', + 'chat_info_user_nickname': '用户2', + 'chat_info_user_cardname': '卡片名2', + 'chat_info_group_platform': 'wechat', + 'chat_info_group_id': 'group2', + 'chat_info_group_name': '群组2', + 'chat_info_create_time': current_time - 3600, + 'chat_info_last_active_time': current_time - 600, # 10分钟前 + 'user_platform': 'wechat', + 'user_id': 'user2', + 'user_nickname': '用户2', + 'user_cardname': '卡片名2', + 'processed_plain_text': '测试', + 'detailed_plain_text': '测试', + 'memorized_times': 0 + } + ] + + for msg_data in self.test_messages: + Messages.create(**msg_data) + + def tearDown(self): + # 关闭测试数据库连接 + self.test_db.close() + + def test_find_messages_no_filter(self): + """测试不带过滤器的查询""" + results = find_messages({}) + self.assertEqual(len(results), 3) + # 验证结果是否按时间升序排列 + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + self.assertEqual(results[2]['message_id'], 'msg3') + + def test_find_messages_with_filter(self): + """测试带过滤器的查询""" + results = find_messages({'chat_id': 'chat1'}) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + + results = find_messages({'user_id': 'user2'}) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['message_id'], 'msg3') + + def test_find_messages_with_operators(self): + """测试带操作符的查询""" + results = find_messages({'memorized_times': {'$gt': 0}}) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + + results = find_messages({'memorized_times': {'$gte': 2}}) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['message_id'], 'msg2') + + def test_find_messages_with_sort(self): + """测试带排序的查询""" + results = find_messages({}, sort=[('memorized_times', -1)]) + self.assertEqual(len(results), 3) + # 验证结果是否按memorized_times降序排列 + self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2 + self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1 + self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0 + + def test_find_messages_with_limit(self): + """测试带限制的查询""" + # 默认limit_mode为latest,应返回最新的2条记录 + results = find_messages({}, limit=2) + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg2') + self.assertEqual(results[1]['message_id'], 'msg3') + + # 使用earliest模式,应返回最早的2条记录 + results = find_messages({}, limit=2, limit_mode='earliest') + self.assertEqual(len(results), 2) + self.assertEqual(results[0]['message_id'], 'msg1') + self.assertEqual(results[1]['message_id'], 'msg2') + + def test_find_messages_with_combined_criteria(self): + """测试组合查询条件""" + results = find_messages( + {'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}}, + sort=[('time', 1)], + limit=1 + ) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['message_id'], 'msg2') + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_build_readable_messages.py b/tests/test_build_readable_messages.py new file mode 100644 index 000000000..76caffb75 --- /dev/null +++ b/tests/test_build_readable_messages.py @@ -0,0 +1,171 @@ +import unittest +import sys +import os +import datetime +import time +import asyncio +import traceback +import json +import copy + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages +from src.common.logger import get_module_logger + +# 创建测试日志记录器 +logger = get_module_logger("test_readable_msg") + +class TestBuildReadableMessages(unittest.TestCase): + def setUp(self): + # 准备测试数据:从真实数据库获取消息 + self.chat_id = '5ed68437e28644da51f314f37df68d18' + self.current_time = time.time() + self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 + + # 获取最新的10条消息 + try: + self.messages = get_raw_msg_by_timestamp_with_chat( + chat_id=self.chat_id, + timestamp_start=self.thirty_days_ago, + timestamp_end=self.current_time, + limit=10, + limit_mode="latest" + ) + logger.info(f"已获取 {len(self.messages)} 条测试消息") + + # 打印消息样例 + if self.messages: + sample_msg = self.messages[0] + logger.info(f"消息样例: {list(sample_msg.keys())}") + logger.info(f"消息内容: {sample_msg.get('processed_plain_text', '无文本内容')[:50]}...") + except Exception as e: + logger.error(f"获取消息失败: {e}") + logger.error(traceback.format_exc()) + self.messages = [] + + def test_manual_fix_messages(self): + """创建一个手动修复版本的消息进行测试""" + if not self.messages: + self.skipTest("没有测试消息,跳过测试") + return + + logger.info("开始手动修复消息...") + + # 创建修复版本的消息列表 + fixed_messages = [] + + for msg in self.messages: + # 深拷贝以避免修改原始数据 + fixed_msg = copy.deepcopy(msg) + + # 构建 user_info 对象 + if 'user_info' not in fixed_msg: + user_info = { + 'platform': fixed_msg.get('user_platform', 'qq'), + 'user_id': fixed_msg.get('user_id', '10000'), + 'user_nickname': fixed_msg.get('user_nickname', '测试用户'), + 'user_cardname': fixed_msg.get('user_cardname', '') + } + fixed_msg['user_info'] = user_info + logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info") + + fixed_messages.append(fixed_msg) + + logger.info(f"已修复 {len(fixed_messages)} 条消息") + + try: + # 使用修复后的消息尝试格式化 + formatted_text = asyncio.run(build_readable_messages( + messages=fixed_messages, + replace_bot_name=True, + merge_messages=False, + timestamp_mode="absolute", + read_mark=0.0, + truncate=False + )) + + logger.info("使用修复后的消息格式化完成") + logger.info(f"格式化结果长度: {len(formatted_text)}") + if formatted_text: + logger.info(f"格式化结果预览: {formatted_text[:200]}...") + else: + logger.warning("格式化结果为空") + + # 断言 + self.assertNotEqual(formatted_text, "", "有消息时不应返回空字符串") + except Exception as e: + logger.error(f"使用修复后的消息格式化失败: {e}") + logger.error(traceback.format_exc()) + raise + + def test_debug_build_messages_internal(self): + """调试_build_readable_messages_internal函数""" + if not self.messages: + self.skipTest("没有测试消息,跳过测试") + return + + logger.info("开始调试内部构建函数...") + + try: + # 直接导入内部函数进行测试 + from src.chat.utils.chat_message_builder import _build_readable_messages_internal + + # 手动创建一个简单的测试消息列表 + test_msg = self.messages[0].copy() # 使用第一条消息作为模板 + + # 检查消息结构 + logger.info(f"测试消息keys: {list(test_msg.keys())}") + logger.info(f"user_info存在: {'user_info' in test_msg}") + + # 修复缺少的user_info字段 + if 'user_info' not in test_msg: + logger.warning("消息中缺少user_info字段,添加模拟数据") + test_msg['user_info'] = { + 'platform': test_msg.get('user_platform', 'qq'), + 'user_id': test_msg.get('user_id', '10000'), + 'user_nickname': test_msg.get('user_nickname', '测试用户'), + 'user_cardname': test_msg.get('user_cardname', '') + } + logger.info(f"添加的user_info: {test_msg['user_info']}") + + simple_msgs = [test_msg] + + # 运行内部函数 + result_text, result_details = asyncio.run(_build_readable_messages_internal( + simple_msgs, + replace_bot_name=True, + merge_messages=False, + timestamp_mode="absolute", + truncate=False + )) + + logger.info(f"内部函数返回结果: {result_text[:200] if result_text else '空'}") + logger.info(f"详情列表长度: {len(result_details)}") + + # 显示处理过程中的变量 + if not result_text and len(simple_msgs) > 0: + logger.warning("消息处理可能有问题,检查关键步骤") + msg = simple_msgs[0] + + # 打印关键变量的值 + user_info = msg.get("user_info", {}) + platform = user_info.get("platform") + user_id = user_info.get("user_id") + timestamp = msg.get("time") + content = msg.get("processed_plain_text", "") + + logger.warning(f"平台: {platform}, 用户ID: {user_id}, 时间戳: {timestamp}") + logger.warning(f"内容: {content[:50]}...") + + # 检查必要信息是否完整 + logger.warning(f"必要信息完整性检查: {all([platform, user_id, timestamp is not None])}") + + except Exception as e: + logger.error(f"调试内部函数失败: {e}") + logger.error(traceback.format_exc()) + raise + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/test_extract_messages.py b/tests/test_extract_messages.py new file mode 100644 index 000000000..d32e644b6 --- /dev/null +++ b/tests/test_extract_messages.py @@ -0,0 +1,88 @@ +import unittest +import sys +import os +import datetime +import time + +# 添加项目根目录到Python路径 +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from src.common.message_repository import find_messages +from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat +from peewee import SqliteDatabase +from src.common.database.database import db # 导入实际的数据库连接 + +class TestExtractMessages(unittest.TestCase): + def setUp(self): + # 这个测试使用真实的数据库,所以不需要创建测试数据 + pass + + def test_extract_latest_messages_direct(self): + """测试直接使用message_repository.find_messages函数""" + chat_id = '5ed68437e28644da51f314f37df68d18' + + # 提取最新的10条消息 + results = find_messages( + {'chat_id': chat_id}, + limit=10 + ) + + # 打印结果数量 + print(f"\n直接使用find_messages,找到 {len(results)} 条消息") + + # 如果有结果,打印一些信息 + if results: + print("\n消息时间顺序:") + for idx, msg in enumerate(results): + msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') + print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}") + print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") + + # 验证结果按时间排序 + times = [msg['time'] for msg in results] + self.assertEqual(times, sorted(times), "消息应该按时间升序排列") + else: + print(f"未找到chat_id为 {chat_id} 的消息") + + # 最基本的断言,确保测试有效 + self.assertIsInstance(results, list, "结果应该是一个列表") + + def test_extract_latest_messages_via_builder(self): + """使用chat_message_builder中的函数测试从真实数据库提取消息""" + chat_id = '5ed68437e28644da51f314f37df68d18' + + # 设置时间范围为过去30天到现在 + current_time = time.time() + thirty_days_ago = current_time - (30 * 24 * 60 * 60) # 30天前的时间戳 + + # 使用chat_message_builder中的函数 + results = get_raw_msg_by_timestamp_with_chat( + chat_id=chat_id, + timestamp_start=thirty_days_ago, + timestamp_end=current_time, + limit=10, + limit_mode="latest" + ) + + # 打印结果数量 + print(f"\n使用get_raw_msg_by_timestamp_with_chat,找到 {len(results)} 条消息") + + # 如果有结果,打印一些信息 + if results: + print("\n消息时间顺序:") + for idx, msg in enumerate(results): + msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S') + print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}") + print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...") + + # 验证结果按时间排序 + times = [msg['time'] for msg in results] + self.assertEqual(times, sorted(times), "消息应该按时间升序排列") + else: + print(f"未找到chat_id为 {chat_id} 的消息") + + # 最基本的断言,确保测试有效 + self.assertIsInstance(results, list, "结果应该是一个列表") + +if __name__ == '__main__': + unittest.main() \ No newline at end of file