From d5627b066159f34a8eaa8bacb54d320acef91cd6 Mon Sep 17 00:00:00 2001 From: minecraft1024a Date: Thu, 2 Oct 2025 17:32:02 +0800 Subject: [PATCH] =?UTF-8?q?refactor(chat):=20=E5=B0=86=E6=B6=88=E6=81=AF?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=E7=9B=B8=E5=85=B3=E5=87=BD=E6=95=B0=E9=87=8D?= =?UTF-8?q?=E6=9E=84=E4=B8=BA=E5=BC=82=E6=AD=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 将 `chat_message_builder` 中的多个同步消息获取函数(如 `get_raw_msg_by_timestamp`)及其调用全部修改为异步函数。这统一了数据库查询的异步模式,提高了代码一致性和可维护性。 主要改动包括: - 将 `chat_message_builder.py` 中的数据库查询函数标记为 `async` 并使用 `await`。 - 更新了 `message_api.py`、`mood_manager.py` 和 `qzone_service.py` 中对这些函数的调用,以适应异步接口。 - 调整了 `message_api.py` 中的函数签名和返回类型提示,以反映异步特性。 --- src/chat/utils/chat_message_builder.py | 50 +++++++------- src/mood/mood_manager.py | 4 +- src/plugin_system/apis/message_api.py | 68 ++++++++++--------- .../services/qzone_service.py | 2 +- 4 files changed, 65 insertions(+), 59 deletions(-) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 2119a3d59..277ba3d23 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -171,7 +171,7 @@ async def replace_user_references_async( return content -def get_raw_msg_by_timestamp( +async def get_raw_msg_by_timestamp( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ @@ -182,10 +182,10 @@ def get_raw_msg_by_timestamp( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_raw_msg_by_timestamp_with_chat( +async def get_raw_msg_by_timestamp_with_chat( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -202,7 +202,7 @@ def get_raw_msg_by_timestamp_with_chat( # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return find_messages( + return await find_messages( message_filter=filter_query, sort=sort_order, limit=limit, @@ -212,7 +212,7 @@ def get_raw_msg_by_timestamp_with_chat( ) -def get_raw_msg_by_timestamp_with_chat_inclusive( +async def get_raw_msg_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -229,12 +229,12 @@ def get_raw_msg_by_timestamp_with_chat_inclusive( sort_order = [("time", 1)] if limit == 0 else None # 直接将 limit_mode 传递给 find_messages - return find_messages( + return await find_messages( message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot ) -def get_raw_msg_by_timestamp_with_chat_users( +async def get_raw_msg_by_timestamp_with_chat_users( chat_id: str, timestamp_start: float, timestamp_end: float, @@ -253,7 +253,7 @@ def get_raw_msg_by_timestamp_with_chat_users( } # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) async def get_actions_by_timestamp_with_chat( @@ -420,14 +420,14 @@ async def get_actions_by_timestamp_with_chat_inclusive( return [action.__dict__ for action in actions] -def get_raw_msg_by_timestamp_random( +async def get_raw_msg_by_timestamp_random( timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ 先在范围时间戳内随机选择一条消息,取得消息的chat_id,然后根据chat_id获取该聊天在指定时间戳范围内的消息 """ # 获取所有消息,只取chat_id字段 - all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end) + all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end) if not all_msgs: return [] # 随机选一条 @@ -435,10 +435,10 @@ def get_raw_msg_by_timestamp_random( chat_id = msg["chat_id"] timestamp_start = msg["time"] # 用 chat_id 获取该聊天在指定时间戳范围内的消息 - return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") + return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest") -def get_raw_msg_by_timestamp_with_users( +async def get_raw_msg_by_timestamp_with_users( timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 @@ -448,16 +448,16 @@ def get_raw_msg_by_timestamp_with_users( filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}} # 只有当 limit 为 0 时才应用外部 sort sort_order = [("time", 1)] if limit == 0 else None - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode) -def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]: @@ -469,16 +469,16 @@ async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: +async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 """ filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}} sort_order = [("time", 1)] - return find_messages(message_filter=filter_query, sort=sort_order, limit=limit) + return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit) -def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: +async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int: """ 检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。 如果 timestamp_end 为 None,则检查从 timestamp_start (不含) 到当前时间的消息。 @@ -492,10 +492,10 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp return 0 # 起始时间大于等于结束时间,没有新消息 filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}} - return count_messages(message_filter=filter_query) + return await count_messages(message_filter=filter_query) -def num_new_messages_since_with_users( +async def num_new_messages_since_with_users( chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list ) -> int: """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息""" @@ -506,7 +506,7 @@ def num_new_messages_since_with_users( "time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}, } - return count_messages(message_filter=filter_query) + return await count_messages(message_filter=filter_query) async def _build_readable_messages_internal( @@ -645,7 +645,7 @@ async def _build_readable_messages_internal( person_name = f"{person_name}({user_id})" # 使用独立函数处理用户引用格式 - content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name) + content = await replace_user_references_async(content, platform, replace_bot_name=replace_bot_name) target_str = "这是QQ的一个功能,用于提及某人,但没那么明显" if target_str in content and random.random() < 0.6: @@ -1216,13 +1216,15 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str: # print(f"anon_name:{anon_name}") # 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器 - def anon_name_resolver(platform: str, user_id: str) -> str: + async def anon_name_resolver(platform: str, user_id: str) -> str: try: return get_anon_name(platform, user_id) except Exception: return "?" - content = replace_user_references_sync(content, platform, anon_name_resolver, replace_bot_name=False) + content = await replace_user_references_async( + content, platform, anon_name_resolver, replace_bot_name=False + ) header = f"{anon_name}说 " output_lines.append(header) diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index 307c47d4e..dc7f0f24b 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -103,7 +103,7 @@ class ChatMood: logger.debug( f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}" ) - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, @@ -152,7 +152,7 @@ class ChatMood: async def regress_mood(self): message_time = time.time() - message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( + message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, timestamp_end=message_time, diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 0ddb4254d..612c243a3 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -34,9 +34,9 @@ from src.chat.utils.chat_message_builder import ( # ============================================================================= -def get_messages_by_time( +async def get_messages_by_time( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> Coroutine[Any, Any, list[dict[str, Any]]]: +) -> List[Dict[str, Any]]: """ 获取指定时间范围内的消息 @@ -58,8 +58,8 @@ def get_messages_by_time( if limit < 0: raise ValueError("limit 不能为负数") if filter_mai: - return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)) - return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) + return await filter_mai_messages(await get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)) + return await get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode) async def get_messages_by_time_in_chat( @@ -148,14 +148,14 @@ async def get_messages_by_time_in_chat_inclusive( ) -def get_messages_by_time_in_chat_for_users( +async def get_messages_by_time_in_chat_for_users( chat_id: str, start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest", -) -> Coroutine[Any, Any, list[dict[str, Any]]]: +) -> List[Dict[str, Any]]: """ 获取指定聊天中指定用户在指定时间范围内的消息 @@ -181,12 +181,14 @@ def get_messages_by_time_in_chat_for_users( raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode) + return await get_raw_msg_by_timestamp_with_chat_users( + chat_id, start_time, end_time, person_ids, limit, limit_mode + ) -def get_random_chat_messages( +async def get_random_chat_messages( start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False -) -> Coroutine[Any, Any, list[dict[str, Any]]]: +) -> List[Dict[str, Any]]: """ 随机选择一个聊天,返回该聊天在指定时间范围内的消息 @@ -208,13 +210,13 @@ def get_random_chat_messages( if limit < 0: raise ValueError("limit 不能为负数") if filter_mai: - return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)) - return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode) + return await filter_mai_messages(await get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)) + return await get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode) -def get_messages_by_time_for_users( +async def get_messages_by_time_for_users( start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" -) -> Coroutine[Any, Any, list[dict[str, Any]]]: +) -> List[Dict[str, Any]]: """ 获取指定用户在所有聊天中指定时间范围内的消息 @@ -235,11 +237,10 @@ def get_messages_by_time_for_users( raise ValueError("start_time 和 end_time 必须是数字类型") if limit < 0: raise ValueError("limit 不能为负数") - return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) + return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode) -def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> Coroutine[ - Any, Any, list[dict[str, Any]]]: +async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]: """ 获取指定时间戳之前的消息 @@ -259,8 +260,8 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool if limit < 0: raise ValueError("limit 不能为负数") if filter_mai: - return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit)) - return get_raw_msg_before_timestamp(timestamp, limit) + return await filter_mai_messages(await get_raw_msg_before_timestamp(timestamp, limit)) + return await get_raw_msg_before_timestamp(timestamp, limit) async def get_messages_before_time_in_chat( @@ -294,8 +295,9 @@ async def get_messages_before_time_in_chat( return await get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) -def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> Coroutine[ - Any, Any, list[dict[str, Any]]]: +async def get_messages_before_time_for_users( + timestamp: float, person_ids: List[str], limit: int = 0 +) -> List[Dict[str, Any]]: """ 获取指定用户在指定时间戳之前的消息 @@ -314,12 +316,12 @@ def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], raise ValueError("timestamp 必须是数字类型") if limit < 0: raise ValueError("limit 不能为负数") - return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit) + return await get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit) -def get_recent_messages( +async def get_recent_messages( chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False -) -> Coroutine[Any, Any, list[dict[str, Any]]]: +) -> List[Dict[str, Any]]: """ 获取指定聊天中最近一段时间的消息 @@ -347,8 +349,10 @@ def get_recent_messages( now = time.time() start_time = now - hours * 3600 if filter_mai: - return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)) - return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode) + return await filter_mai_messages( + await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode) + ) + return await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode) # ============================================================================= @@ -356,8 +360,7 @@ def get_recent_messages( # ============================================================================= -def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> Coroutine[ - Any, Any, int]: +async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int: """ 计算指定聊天中从开始时间到结束时间的新消息数量 @@ -378,11 +381,12 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return num_new_messages_since(chat_id, start_time, end_time) + return await num_new_messages_since(chat_id, start_time, end_time) -def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> Coroutine[ - Any, Any, int]: +async def count_new_messages_for_users( + chat_id: str, start_time: float, end_time: float, person_ids: List[str] +) -> int: """ 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 @@ -404,7 +408,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa raise ValueError("chat_id 不能为空") if not isinstance(chat_id, str): raise ValueError("chat_id 必须是字符串类型") - return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids) + return await num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids) # ============================================================================= @@ -420,7 +424,7 @@ async def build_readable_messages_to_str( read_mark: float = 0.0, truncate: bool = False, show_actions: bool = False, -) -> Coroutine[Any, Any, str]: +) -> str: """ 将消息列表构建成可读的字符串 diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index cfee3787c..186079965 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -221,7 +221,7 @@ class QZoneService: for chat_id in group.chat_ids: # 使用正确的函数获取历史消息 - messages = get_raw_msg_by_timestamp_with_chat( + messages = await get_raw_msg_by_timestamp_with_chat( chat_id=chat_id, timestamp_start=start_time, timestamp_end=end_time,