refactor(chat): 将消息获取相关函数重构为异步

将 `chat_message_builder` 中的多个同步消息获取函数(如 `get_raw_msg_by_timestamp`)及其调用全部修改为异步函数。这统一了数据库查询的异步模式,提高了代码一致性和可维护性。

主要改动包括:
- 将 `chat_message_builder.py` 中的数据库查询函数标记为 `async` 并使用 `await`。
- 更新了 `message_api.py`、`mood_manager.py` 和 `qzone_service.py` 中对这些函数的调用,以适应异步接口。
- 调整了 `message_api.py` 中的函数签名和返回类型提示,以反映异步特性。
This commit is contained in:
minecraft1024a
2025-10-02 17:32:02 +08:00
parent a7acd98023
commit d5627b0661
4 changed files with 65 additions and 59 deletions

View File

@@ -171,7 +171,7 @@ async def replace_user_references_async(
return content
def get_raw_msg_by_timestamp(
async def get_raw_msg_by_timestamp(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
@@ -182,10 +182,10 @@ def get_raw_msg_by_timestamp(
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_by_timestamp_with_chat(
async def get_raw_msg_by_timestamp_with_chat(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -202,7 +202,7 @@ def get_raw_msg_by_timestamp_with_chat(
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(
return await find_messages(
message_filter=filter_query,
sort=sort_order,
limit=limit,
@@ -212,7 +212,7 @@ def get_raw_msg_by_timestamp_with_chat(
)
def get_raw_msg_by_timestamp_with_chat_inclusive(
async def get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -229,12 +229,12 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
sort_order = [("time", 1)] if limit == 0 else None
# 直接将 limit_mode 传递给 find_messages
return find_messages(
return await find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode, filter_bot=filter_bot
)
def get_raw_msg_by_timestamp_with_chat_users(
async def get_raw_msg_by_timestamp_with_chat_users(
chat_id: str,
timestamp_start: float,
timestamp_end: float,
@@ -253,7 +253,7 @@ def get_raw_msg_by_timestamp_with_chat_users(
}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
async def get_actions_by_timestamp_with_chat(
@@ -420,14 +420,14 @@ async def get_actions_by_timestamp_with_chat_inclusive(
return [action.__dict__ for action in actions]
def get_raw_msg_by_timestamp_random(
async def get_raw_msg_by_timestamp_random(
timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
先在范围时间戳内随机选择一条消息取得消息的chat_id然后根据chat_id获取该聊天在指定时间戳范围内的消息
"""
# 获取所有消息只取chat_id字段
all_msgs = get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
all_msgs = await get_raw_msg_by_timestamp(timestamp_start, timestamp_end)
if not all_msgs:
return []
# 随机选一条
@@ -435,10 +435,10 @@ def get_raw_msg_by_timestamp_random(
chat_id = msg["chat_id"]
timestamp_start = msg["time"]
# 用 chat_id 获取该聊天在指定时间戳范围内的消息
return get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
return await get_raw_msg_by_timestamp_with_chat(chat_id, timestamp_start, timestamp_end, limit, "earliest")
def get_raw_msg_by_timestamp_with_users(
async def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
@@ -448,16 +448,16 @@ def get_raw_msg_by_timestamp_with_users(
filter_query = {"time": {"$gt": timestamp_start, "$lt": timestamp_end}, "user_id": {"$in": person_ids}}
# 只有当 limit 为 0 时才应用外部 sort
sort_order = [("time", 1)] if limit == 0 else None
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit, limit_mode=limit_mode)
def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float, limit: int = 0) -> List[Dict[str, Any]]:
@@ -469,16 +469,16 @@ async def get_raw_msg_before_timestamp_with_chat(chat_id: str, timestamp: float,
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
async def get_raw_msg_before_timestamp_with_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
"""
filter_query = {"time": {"$lt": timestamp}, "user_id": {"$in": person_ids}}
sort_order = [("time", 1)]
return find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
return await find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
async def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp_end: Optional[float] = None) -> int:
"""
检查特定聊天从 timestamp_start (不含) 到 timestamp_end (不含) 之间有多少新消息。
如果 timestamp_end 为 None则检查从 timestamp_start (不含) 到当前时间的消息。
@@ -492,10 +492,10 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
return 0 # 起始时间大于等于结束时间,没有新消息
filter_query = {"chat_id": chat_id, "time": {"$gt": timestamp_start, "$lt": _timestamp_end}}
return count_messages(message_filter=filter_query)
return await count_messages(message_filter=filter_query)
def num_new_messages_since_with_users(
async def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
@@ -506,7 +506,7 @@ def num_new_messages_since_with_users(
"time": {"$gt": timestamp_start, "$lt": timestamp_end},
"user_id": {"$in": person_ids},
}
return count_messages(message_filter=filter_query)
return await count_messages(message_filter=filter_query)
async def _build_readable_messages_internal(
@@ -645,7 +645,7 @@ async def _build_readable_messages_internal(
person_name = f"{person_name}({user_id})"
# 使用独立函数处理用户引用格式
content = replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name)
content = await replace_user_references_async(content, platform, replace_bot_name=replace_bot_name)
target_str = "这是QQ的一个功能用于提及某人但没那么明显"
if target_str in content and random.random() < 0.6:
@@ -1216,13 +1216,15 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# print(f"anon_name:{anon_name}")
# 使用独立函数处理用户引用格式,传入自定义的匿名名称解析器
def anon_name_resolver(platform: str, user_id: str) -> str:
async def anon_name_resolver(platform: str, user_id: str) -> str:
try:
return get_anon_name(platform, user_id)
except Exception:
return "?"
content = replace_user_references_sync(content, platform, anon_name_resolver, replace_bot_name=False)
content = await replace_user_references_async(
content, platform, anon_name_resolver, replace_bot_name=False
)
header = f"{anon_name}"
output_lines.append(header)

View File

@@ -103,7 +103,7 @@ class ChatMood:
logger.debug(
f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
)
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,
@@ -152,7 +152,7 @@ class ChatMood:
async def regress_mood(self):
message_time = time.time()
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
timestamp_end=message_time,

View File

@@ -34,9 +34,9 @@ from src.chat.utils.chat_message_builder import (
# =============================================================================
def get_messages_by_time(
async def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
) -> List[Dict[str, Any]]:
"""
获取指定时间范围内的消息
@@ -58,8 +58,8 @@ def get_messages_by_time(
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
return await filter_mai_messages(await get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode))
return await get_raw_msg_by_timestamp(start_time, end_time, limit, limit_mode)
async def get_messages_by_time_in_chat(
@@ -148,14 +148,14 @@ async def get_messages_by_time_in_chat_inclusive(
)
def get_messages_by_time_in_chat_for_users(
async def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
) -> List[Dict[str, Any]]:
"""
获取指定聊天中指定用户在指定时间范围内的消息
@@ -181,12 +181,14 @@ def get_messages_by_time_in_chat_for_users(
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return get_raw_msg_by_timestamp_with_chat_users(chat_id, start_time, end_time, person_ids, limit, limit_mode)
return await get_raw_msg_by_timestamp_with_chat_users(
chat_id, start_time, end_time, person_ids, limit, limit_mode
)
def get_random_chat_messages(
async def get_random_chat_messages(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
) -> List[Dict[str, Any]]:
"""
随机选择一个聊天,返回该聊天在指定时间范围内的消息
@@ -208,13 +210,13 @@ def get_random_chat_messages(
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
return await filter_mai_messages(await get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode))
return await get_raw_msg_by_timestamp_random(start_time, end_time, limit, limit_mode)
def get_messages_by_time_for_users(
async def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
) -> List[Dict[str, Any]]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
@@ -235,11 +237,10 @@ def get_messages_by_time_for_users(
raise ValueError("start_time 和 end_time 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
return await get_raw_msg_by_timestamp_with_users(start_time, end_time, person_ids, limit, limit_mode)
def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> Coroutine[
Any, Any, list[dict[str, Any]]]:
async def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool = False) -> List[Dict[str, Any]]:
"""
获取指定时间戳之前的消息
@@ -259,8 +260,8 @@ def get_messages_before_time(timestamp: float, limit: int = 0, filter_mai: bool
if limit < 0:
raise ValueError("limit 不能为负数")
if filter_mai:
return filter_mai_messages(get_raw_msg_before_timestamp(timestamp, limit))
return get_raw_msg_before_timestamp(timestamp, limit)
return await filter_mai_messages(await get_raw_msg_before_timestamp(timestamp, limit))
return await get_raw_msg_before_timestamp(timestamp, limit)
async def get_messages_before_time_in_chat(
@@ -294,8 +295,9 @@ async def get_messages_before_time_in_chat(
return await get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> Coroutine[
Any, Any, list[dict[str, Any]]]:
async def get_messages_before_time_for_users(
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[Dict[str, Any]]:
"""
获取指定用户在指定时间戳之前的消息
@@ -314,12 +316,12 @@ def get_messages_before_time_for_users(timestamp: float, person_ids: List[str],
raise ValueError("timestamp 必须是数字类型")
if limit < 0:
raise ValueError("limit 不能为负数")
return get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
return await get_raw_msg_before_timestamp_with_users(timestamp, person_ids, limit)
def get_recent_messages(
async def get_recent_messages(
chat_id: str, hours: float = 24.0, limit: int = 100, limit_mode: str = "latest", filter_mai: bool = False
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
) -> List[Dict[str, Any]]:
"""
获取指定聊天中最近一段时间的消息
@@ -347,8 +349,10 @@ def get_recent_messages(
now = time.time()
start_time = now - hours * 3600
if filter_mai:
return filter_mai_messages(get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode))
return get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
return await filter_mai_messages(
await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
)
return await get_raw_msg_by_timestamp_with_chat(chat_id, start_time, now, limit, limit_mode)
# =============================================================================
@@ -356,8 +360,7 @@ def get_recent_messages(
# =============================================================================
def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> Coroutine[
Any, Any, int]:
async def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional[float] = None) -> int:
"""
计算指定聊天中从开始时间到结束时间的新消息数量
@@ -378,11 +381,12 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since(chat_id, start_time, end_time)
return await num_new_messages_since(chat_id, start_time, end_time)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> Coroutine[
Any, Any, int]:
async def count_new_messages_for_users(
chat_id: str, start_time: float, end_time: float, person_ids: List[str]
) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
@@ -404,7 +408,7 @@ def count_new_messages_for_users(chat_id: str, start_time: float, end_time: floa
raise ValueError("chat_id 不能为空")
if not isinstance(chat_id, str):
raise ValueError("chat_id 必须是字符串类型")
return num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
return await num_new_messages_since_with_users(chat_id, start_time, end_time, person_ids)
# =============================================================================
@@ -420,7 +424,7 @@ async def build_readable_messages_to_str(
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> Coroutine[Any, Any, str]:
) -> str:
"""
将消息列表构建成可读的字符串

View File

@@ -221,7 +221,7 @@ class QZoneService:
for chat_id in group.chat_ids:
# 使用正确的函数获取历史消息
messages = get_raw_msg_by_timestamp_with_chat(
messages = await get_raw_msg_by_timestamp_with_chat(
chat_id=chat_id,
timestamp_start=start_time,
timestamp_end=end_time,