async def get_user_messages_from_streams(
    user_ids: list[str],
    stream_ids: list[str],
    timestamp_after: float,
    limit_per_stream: int,
) -> dict[str, list[dict[str, Any]]]:
    """
    Fetch recent messages written by the given users from several chat streams in one query.

    Messages are ranked per stream (newest first) with ``row_number()`` and only the
    first ``limit_per_stream`` rows of each stream are kept. The ranking is partitioned
    by stream only, so the cap is shared by all ``user_ids`` inside a stream rather
    than applied to each user separately.

    Args:
        user_ids: IDs of the users whose messages should be returned.
        stream_ids: IDs of the chat streams to search.
        timestamp_after: Only messages with a ``time`` strictly greater than this
            value are considered.
        limit_per_stream: Maximum number of messages returned for each stream.
            A value of zero or less yields an empty result.

    Returns:
        A dict keyed by stream ID; each value is that stream's messages as dicts,
        sorted by ``time`` in ascending order. Streams without matching messages
        are absent. An empty dict is returned when there is nothing to query or
        when the query fails (the failure is logged, not raised).
    """
    if not stream_ids or not user_ids:
        return {}
    if limit_per_stream <= 0:
        # `row_num <= limit_per_stream` can never match, so skip the database round trip.
        return {}

    try:
        async with get_db_session() as session:
            # Number each stream's matching messages from newest to oldest.
            row_number_in_stream = (
                func.row_number()
                .over(partition_by=Messages.chat_id, order_by=Messages.time.desc())
                .label("row_num")
            )
            ranked_messages_cte = (
                select(Messages, row_number_in_stream)
                .where(
                    Messages.user_id.in_(user_ids),
                    Messages.chat_id.in_(stream_ids),
                    Messages.time > timestamp_after,
                )
                .cte("ranked_messages")
            )

            # Keep only the newest `limit_per_stream` rows of every stream.
            query = select(ranked_messages_cte).where(ranked_messages_cte.c.row_num <= limit_per_stream)

            result = await session.execute(query)
            rows = result.all()

            # Group the rows by stream ID.
            messages_by_stream = defaultdict(list)
            for row in rows:
                # Rows selected from a CTE are plain Row objects, so the model
                # instance is rebuilt by hand before converting it to a dict.
                column_values = {c.name: getattr(row, c.name) for c in Messages.__table__.columns}
                msg_dict = _model_to_dict(Messages(**column_values))
                messages_by_stream[msg_dict["chat_id"]].append(msg_dict)

            # The query ranks newest-first; callers get oldest-first within each stream.
            for stream_messages in messages_by_stream.values():
                stream_messages.sort(key=lambda m: m["time"])

            return dict(messages_by_stream)

    except Exception as e:
        # Best-effort lookup: log the failure with its traceback and return nothing.
        log_message = (
            f"使用 SQLAlchemy 批量查找用户消息失败 (user_ids={user_ids}, streams={len(stream_ids)}): {e}\n"
            + traceback.format_exc()
        )
        logger.error(log_message)
        return {}
streams to scan: {streams_to_scan}") # 2. 从筛选出的聊天流中获取目标用户的消息 + # 2. 从筛选出的聊天流中获取目标用户的消息 (优化后) limit = cross_context_config.s4u_limit - for stream_id in streams_to_scan: - try: - # 获取稍多一些消息以确保能筛选出用户消息 - messages = await get_raw_msg_before_timestamp_with_chat( - chat_id=stream_id, timestamp=time.time(), limit=limit * 5 - ) - user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:] + # 约 3 天内的消息 + timestamp_after = time.time() - (3 * 24 * 60 * 60) + + # 如果是私聊,则同时获取bot自身的消息 + user_ids_to_fetch = [str(user_id)] + if chat_stream.group_info is None: + user_ids_to_fetch.append(str(global_config.bot.qq_account)) - if user_messages: - # 记录消息来源的 stream_id 和最新消息的时间戳 - latest_timestamp = max(msg.get("time", 0) for msg in user_messages) - all_user_messages.append( - {"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp} - ) - except Exception as e: - logger.error(f"S4U模式下获取聊天 {stream_id} 的消息失败: {e}") + # 一次性批量查询 + messages_by_stream = await get_user_messages_from_streams( + user_ids=user_ids_to_fetch, + stream_ids=streams_to_scan, + timestamp_after=timestamp_after, + limit_per_stream=limit, + ) + + for stream_id, user_messages in messages_by_stream.items(): + if user_messages: + logger.info(f"[S4U] Found {len(user_messages)} messages for user {user_id} in stream {stream_id}.") + latest_timestamp = max(msg.get("time", 0) for msg in user_messages) + all_user_messages.append( + {"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp} + ) # 3. 
按最新消息时间排序,并根据 stream_limit 截断 all_user_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True) @@ -193,15 +209,24 @@ async def build_cross_context_s4u( try: chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天" user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id + + # 如果是私聊,标题不显示用户名 + stream = await chat_manager.get_stream(stream_id) + is_private = stream.group_info is None if stream else False + title = f'[以下是您在"{chat_name}"的近期发言]\n' if is_private else f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n' + formatted, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") - cross_context_blocks.append(f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted}') + cross_context_blocks.append(f"{title}{formatted}") except Exception as e: logger.error(f"S4U模式下格式化消息失败 (stream: {stream_id}): {e}") if not cross_context_blocks: + logger.info("[S4U] No context blocks were generated. Returning empty string.") return "" - return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_blocks) + "\n" + final_context = "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_blocks) + "\n" + logger.info(f"[S4U] Successfully generated {len(cross_context_blocks)} context blocks. Total length: {len(final_context)}.") + return final_context async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None: