diff --git a/src/chat/utils/prompt_utils.py b/src/chat/utils/prompt_utils.py index f9985be53..4f9e36777 100644 --- a/src/chat/utils/prompt_utils.py +++ b/src/chat/utils/prompt_utils.py @@ -9,13 +9,9 @@ from typing import Dict, Any, Optional, Tuple from src.common.logger import get_logger from src.config.config import global_config -from src.chat.utils.chat_message_builder import ( - get_raw_msg_before_timestamp_with_chat, - build_readable_messages_with_id, -) from src.chat.message_receive.chat_stream import get_chat_manager from src.person_info.person_info import get_person_info_manager - +from src.plugin_system.apis import cross_context_api logger = get_logger("prompt_utils") @@ -84,113 +80,29 @@ class PromptUtils: @staticmethod async def build_cross_context( - chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str - ) -> str: - """ - 构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能 + chat_id: str, target_user_info: Optional[Dict[str, Any]], current_prompt_mode: str + ) -> str: + """ + 构建跨群聊上下文 - 统一实现,完全继承DefaultReplyer功能 + """ + if not global_config.cross_context.enable: + return "" - Args: - chat_id: 当前聊天ID - target_user_info: 目标用户信息 - current_prompt_mode: 当前提示模式 + other_chat_raw_ids = cross_context_api.get_context_groups(chat_id) + if not other_chat_raw_ids: + return "" + + chat_stream = get_chat_manager().get_stream(chat_id) + if not chat_stream: + return "" + + if current_prompt_mode == "normal": + return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids) + elif current_prompt_mode == "s4u": + return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info) - Returns: - str: 跨群上下文块 - """ - if not global_config.cross_context.enable: return "" - # 找到当前群聊所在的共享组 - target_group = None - current_stream = get_chat_manager().get_stream(chat_id) - if not current_stream or not current_stream.group_info: - return "" - - try: - current_chat_raw_id = current_stream.group_info.group_id - except Exception as e: - logger.error(f"获取群聊ID失败: {e}") - return "" - - for group in global_config.cross_context.groups: - if str(current_chat_raw_id) in group.chat_ids: - target_group = group - break - - if not target_group: - return "" - - # 根据prompt_mode选择策略 - other_chat_raw_ids = [chat_id for chat_id in target_group.chat_ids if chat_id != str(current_chat_raw_id)] - - cross_context_messages = [] - - if current_prompt_mode == "normal": - # normal模式:获取其他群聊的最近N条消息 - for chat_raw_id in other_chat_raw_ids: - stream_id = get_chat_manager().get_stream_id(current_stream.platform, chat_raw_id, is_group=True) - if not stream_id: - continue - - try: - messages = get_raw_msg_before_timestamp_with_chat( - chat_id=stream_id, - timestamp=time.time(), - limit=5, # 可配置 - ) - if messages: - chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id - formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative") - cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') - except Exception as e: - logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}") - continue - - elif current_prompt_mode == "s4u": - # s4u模式:获取当前发言用户在其他群聊的消息 - if target_user_info: - user_id = target_user_info.get("user_id") - - if user_id: - for chat_raw_id in other_chat_raw_ids: - stream_id = get_chat_manager().get_stream_id( - current_stream.platform, chat_raw_id, is_group=True - ) - if not stream_id: - continue - - try: - messages = get_raw_msg_before_timestamp_with_chat( - chat_id=stream_id, - timestamp=time.time(), - limit=20, # 获取更多消息以供筛选 - ) - user_messages = [msg for msg in messages if msg.get("user_id") == user_id][ - -5: - ] # 筛选并取最近5条 - - if user_messages: - chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id - user_name = ( - target_user_info.get("person_name") - or target_user_info.get("user_nickname") - or user_id - ) - formatted_messages, _ = build_readable_messages_with_id( - user_messages, timestamp_mode="relative" - ) - cross_context_messages.append( - f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}' - ) - except Exception as e: - logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}") - continue - - if not cross_context_messages: - return "" - - return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" - @staticmethod def parse_reply_target_id(reply_to: str) -> str: """ diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py new file mode 100644 index 000000000..f926742aa --- /dev/null +++ b/src/plugin_system/apis/cross_context_api.py @@ -0,0 +1,118 @@ +""" +跨群聊上下文API +""" + +import time +from typing import Dict, Any, Optional, List + +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.utils.chat_message_builder import ( + get_raw_msg_before_timestamp_with_chat, + build_readable_messages_with_id, +) +from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream + +logger = get_logger("cross_context_api") + + +def get_context_groups(chat_id: str) -> Optional[List[str]]: + """ + 获取当前群聊所在的共享组的其他群聊ID + """ + current_stream = get_chat_manager().get_stream(chat_id) + if not current_stream or not current_stream.group_info: + return None + + try: + current_chat_raw_id = current_stream.group_info.group_id + except Exception as e: + logger.error(f"获取群聊ID失败: {e}") + return None + + for group in global_config.cross_context.groups: + if str(current_chat_raw_id) in group.chat_ids: + return [chat_id for chat_id in group.chat_ids if chat_id != str(current_chat_raw_id)] + + return None + + +async def build_cross_context_normal(chat_stream: ChatStream, other_chat_raw_ids: List[str]) -> str: + """ + 构建跨群聊上下文 (Normal模式) + """ + cross_context_messages = [] + for chat_raw_id in other_chat_raw_ids: + stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=True) + if not stream_id: + continue + + try: + messages = get_raw_msg_before_timestamp_with_chat( + chat_id=stream_id, + timestamp=time.time(), + limit=5, # 可配置 + ) + if messages: + chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id + formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative") + cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') + except Exception as e: + logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}") + continue + + if not cross_context_messages: + return "" + + return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" + + +async def build_cross_context_s4u( + chat_stream: ChatStream, other_chat_raw_ids: List[str], target_user_info: Optional[Dict[str, Any]] +) -> str: + """ + 构建跨群聊上下文 (S4U模式) + """ + cross_context_messages = [] + if target_user_info: + user_id = target_user_info.get("user_id") + + if user_id: + for chat_raw_id in other_chat_raw_ids: + stream_id = get_chat_manager().get_stream_id( + chat_stream.platform, chat_raw_id, is_group=True + ) + if not stream_id: + continue + + try: + messages = get_raw_msg_before_timestamp_with_chat( + chat_id=stream_id, + timestamp=time.time(), + limit=20, # 获取更多消息以供筛选 + ) + user_messages = [msg for msg in messages if msg.get("user_id") == user_id][ + -5: + ] # 筛选并取最近5条 + + if user_messages: + chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id + user_name = ( + target_user_info.get("person_name") + or target_user_info.get("user_nickname") + or user_id + ) + formatted_messages, _ = build_readable_messages_with_id( + user_messages, timestamp_mode="relative" + ) + cross_context_messages.append( + f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}' + ) + except Exception as e: + logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}") + continue + + if not cross_context_messages: + return "" + + return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" \ No newline at end of file