diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index 4b7f128aa..78cab2339 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -9,7 +9,6 @@ from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, get_raw_msg_before_timestamp_with_chat, - get_raw_msg_by_timestamp_with_chat, ) from src.common.logger import get_logger from src.config.config import global_config @@ -60,7 +59,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con """ cross_context_messages = [] chat_manager = get_chat_manager() - + chat_infos_to_fetch = [] if context_group.mode == "blacklist": # 黑名单模式:获取所有聊天,并排除在 chat_ids 中定义过的聊天 @@ -68,7 +67,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con for stream_id, stream in chat_manager.streams.items(): is_group = stream.group_info is not None chat_type = "group" if is_group else "private" - + # 安全地获取 raw_id if is_group and stream.group_info: raw_id = stream.group_info.group_id @@ -86,7 +85,11 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con # 遍历待获取列表,抓取并格式化消息 for chat_info in chat_infos_to_fetch: - chat_type, chat_raw_id, limit_str = chat_info[0], chat_info[1], chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit) + chat_type, chat_raw_id, limit_str = ( + chat_info[0], + chat_info[1], + chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit), + ) limit = int(limit_str) is_group = chat_type == "group" stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group) @@ -126,9 +129,7 @@ async def build_cross_context_s4u( return "" chat_manager = get_chat_manager() - current_chat_raw_id = ( - chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id - ) + current_chat_raw_id = chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id current_type = "group" if chat_stream.group_info else "private" # 根据模式(黑名单/白名单)决定需要处理哪些聊天 @@ -139,17 +140,17 @@ async def build_cross_context_s4u( for stream_id, stream in chat_manager.streams.items(): if stream_id == chat_stream.stream_id: continue # 排除当前聊天 - + is_group = stream.group_info is not None chat_type = "group" if is_group else "private" - + # 安全地获取 raw_id if is_group and stream.group_info: raw_id = stream.group_info.group_id elif not is_group and stream.user_info: raw_id = stream.user_info.user_id else: - continue # 如果缺少关键信息则跳过 + continue # 如果缺少关键信息则跳过 # 如果不在黑名单中,则加入处理列表 if (chat_type, str(raw_id)) not in blacklisted_ids: @@ -184,12 +185,8 @@ async def build_cross_context_s4u( if user_messages: chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id - formatted_messages, _ = await build_readable_messages_with_id( - user_messages, timestamp_mode="relative" - ) - cross_context_messages.append( - f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}' - ) + formatted_messages, _ = await build_readable_messages_with_id(user_messages, timestamp_mode="relative") + cross_context_messages.append(f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}') except Exception as e: logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}") @@ -197,9 +194,7 @@ async def build_cross_context_s4u( if context_group.s4u_ignore_whitelist: private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False) # 检查该私聊是否已在白名单中处理过 - is_already_processed = any( - info[0] == "private" and info[1] == user_id for info in context_group.chat_ids - ) + is_already_processed = any(info[0] == "private" and info[1] == user_id for info in context_group.chat_ids) if private_stream_id and not is_already_processed: try: @@ -227,91 +222,24 @@ async def build_cross_context_s4u( return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n" -async def get_chat_history_by_group_name(group_name: str) -> str: - """ - 根据互通组名字获取聊天记录 - """ - target_group = None - for group in global_config.cross_context.groups: - if group.name == group_name: - target_group = group - break - - if not target_group: - return f"找不到名为 {group_name} 的互通组。" - - if not target_group.chat_ids: - return f"互通组 {group_name} 中没有配置任何聊天。" - - chat_infos = target_group.chat_ids - chat_manager = get_chat_manager() - - cross_context_messages = [] - for chat_info in chat_infos: - chat_type, chat_raw_id, limit = chat_info[0], chat_info[1], int(chat_info[2]) if len(chat_info) > 2 else 5 - is_group = chat_type == "group" - - found_stream = None - for stream in chat_manager.streams.values(): - if is_group: - if stream.group_info and stream.group_info.group_id == chat_raw_id: - found_stream = stream - break - else: # private - if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info: - found_stream = stream - break - - if not found_stream: - logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。") - continue - - stream_id = found_stream.stream_id - - try: - messages = await get_raw_msg_before_timestamp_with_chat( - chat_id=stream_id, - timestamp=time.time(), - limit=limit, - ) - if messages: - chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id - formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") - cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') - except Exception as e: - logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}") - continue - - if not cross_context_messages: - return f"无法从互通组 {group_name} 中获取任何聊天记录。" - - return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" - - -async def get_intercom_group_context_by_name( - group_name: str, days: int = 3, limit_per_chat: int = 20, total_limit: int = 100 -) -> str | None: +async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None: """ 根据互通组的名称,构建该组的聊天上下文。 + 支持黑白名单模式,并以分块形式返回每个聊天的消息。 Args: group_name: 互通组的名称。 - days: 获取过去多少天的消息。 limit_per_chat: 每个聊天最多获取的消息条数。 total_limit: 返回的总消息条数上限。 Returns: - 如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。 + 如果找到匹配的组并获取到消息,则返回一个包含聊天记录的字符串;否则返回 None。 """ cross_context_config = global_config.cross_context if not (cross_context_config and cross_context_config.enable): return None - target_group = None - for group in cross_context_config.groups: - if group.name == group_name: - target_group = group - break + target_group = next((g for g in cross_context_config.groups if g.name == group_name), None) if not target_group: logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。") @@ -319,15 +247,34 @@ async def get_intercom_group_context_by_name( chat_manager = get_chat_manager() - all_messages = [] - end_time = time.time() - start_time = end_time - (days * 24 * 60 * 60) + # 1. 根据黑白名单模式确定要处理的聊天列表 + chat_infos_to_fetch = [] + if target_group.mode == "blacklist": + blacklisted_ids = {tuple(info[:2]) for info in target_group.chat_ids} + for stream in chat_manager.streams.values(): + is_group = stream.group_info is not None + chat_type = "group" if is_group else "private" - for chat_type, chat_raw_id in target_group.chat_ids: + if is_group and stream.group_info: + raw_id = stream.group_info.group_id + elif not is_group and stream.user_info: + raw_id = stream.user_info.user_id + else: + continue + + if (chat_type, str(raw_id)) not in blacklisted_ids: + chat_infos_to_fetch.append([chat_type, str(raw_id)]) + else: # whitelist mode + chat_infos_to_fetch = target_group.chat_ids + + # 2. 获取所有相关消息 + all_messages = [] + for chat_info in chat_infos_to_fetch: + chat_type, chat_raw_id = chat_info[0], chat_info[1] is_group = chat_type == "group" + # 查找 stream found_stream = None - # 采用与 get_chat_history_by_group_name 相同的健壮的 stream 查找方式 for stream in chat_manager.streams.values(): if is_group: if stream.group_info and stream.group_info.group_id == chat_raw_id: @@ -336,32 +283,50 @@ async def get_intercom_group_context_by_name( else: # private if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info: found_stream = stream - break - + break if not found_stream: logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。") continue - stream_id = found_stream.stream_id - messages = await get_raw_msg_by_timestamp_with_chat( - chat_id=stream_id, - timestamp_start=start_time, - timestamp_end=end_time, - limit=limit_per_chat, - limit_mode="latest", - ) - all_messages.extend(messages) + + try: + messages = await get_raw_msg_before_timestamp_with_chat( + chat_id=stream_id, + timestamp=time.time(), + limit=limit_per_chat, + ) + if messages: + # 为每条消息附加 stream_id 以便后续分组 + for msg in messages: + msg["_stream_id"] = stream_id + all_messages.extend(messages) + except Exception as e: + logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}") if not all_messages: return None - # 按时间戳对所有消息进行排序 + # 3. 应用总数限制 all_messages.sort(key=lambda x: x.get("time", 0)) - - # 限制总消息数 if len(all_messages) > total_limit: all_messages = all_messages[-total_limit:] - # build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list) - formatted_string, _ = await build_readable_messages_with_id(all_messages) - return formatted_string + # 4. 按聊天分组并格式化 + messages_by_stream = {} + for msg in all_messages: + stream_id = msg.get("_stream_id") + if stream_id not in messages_by_stream: + messages_by_stream[stream_id] = [] + messages_by_stream[stream_id].append(msg) + + cross_context_messages = [] + for stream_id, messages in messages_by_stream.items(): + if messages: + chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天" + formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") + cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') + + if not cross_context_messages: + return None + + return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index 02f9fa86a..af71a281e 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -190,7 +190,7 @@ class QZoneService: 获取互通组的聊天上下文。 """ # 实际的逻辑已迁移到 cross_context_api - return await cross_context_api.get_intercom_group_context_by_name("maizone_context_group") + return await cross_context_api.get_intercom_group_context("maizone_context_group") async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict): """处理对自己说说的评论并进行回复"""