diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 67c2d4b6b..54a138d0b 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -652,7 +652,9 @@ class ContextGroup(ValidatedConfigBase): """上下文共享组配置""" name: str = Field(..., description="共享组的名称") - chat_ids: List[str] = Field(..., description="属于该组的聊天ID列表") + chat_ids: List[List[str]] = Field( + ..., description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]' + ) class CrossContextConfig(ValidatedConfigBase): diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index f926742aa..5a0b896df 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -16,34 +16,45 @@ from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream logger = get_logger("cross_context_api") -def get_context_groups(chat_id: str) -> Optional[List[str]]: +def get_context_groups(chat_id: str) -> Optional[List[List[str]]]: """ - 获取当前群聊所在的共享组的其他群聊ID + 获取当前聊天所在的共享组的其他聊天ID """ current_stream = get_chat_manager().get_stream(chat_id) - if not current_stream or not current_stream.group_info: + if not current_stream: return None - try: - current_chat_raw_id = current_stream.group_info.group_id - except Exception as e: - logger.error(f"获取群聊ID失败: {e}") - return None + is_group = current_stream.group_info is not None + current_chat_raw_id = ( + current_stream.group_info.group_id if is_group else current_stream.user_info.user_id + ) + current_type = "group" if is_group else "private" for group in global_config.cross_context.groups: - if str(current_chat_raw_id) in group.chat_ids: - return [chat_id for chat_id in group.chat_ids if chat_id != str(current_chat_raw_id)] + # 检查当前聊天的ID和类型是否在组的chat_ids中 + if [current_type, str(current_chat_raw_id)] in group.chat_ids: + # 返回组内其他聊天的 [type, id] 列表 + return [ + chat_info + for chat_info in group.chat_ids + if chat_info != [current_type, str(current_chat_raw_id)] + ] return None -async def build_cross_context_normal(chat_stream: ChatStream, other_chat_raw_ids: List[str]) -> str: +async def build_cross_context_normal( + chat_stream: ChatStream, other_chat_infos: List[List[str]] +) -> str: """ - 构建跨群聊上下文 (Normal模式) + 构建跨群聊/私聊上下文 (Normal模式) """ cross_context_messages = [] - for chat_raw_id in other_chat_raw_ids: - stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=True) + for chat_type, chat_raw_id in other_chat_infos: + is_group = chat_type == "group" + stream_id = get_chat_manager().get_stream_id( + chat_stream.platform, chat_raw_id, is_group=is_group + ) if not stream_id: continue @@ -54,33 +65,38 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_raw_ids limit=5, # 可配置 ) if messages: - chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id - formatted_messages, _ = build_readable_messages_with_id(messages, timestamp_mode="relative") + chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id + formatted_messages, _ = build_readable_messages_with_id( + messages, timestamp_mode="relative" + ) cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') except Exception as e: - logger.error(f"获取群聊{chat_raw_id}的消息失败: {e}") + logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}") continue if not cross_context_messages: return "" - return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" + return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" async def build_cross_context_s4u( - chat_stream: ChatStream, other_chat_raw_ids: List[str], target_user_info: Optional[Dict[str, Any]] + chat_stream: ChatStream, + other_chat_infos: List[List[str]], + target_user_info: Optional[Dict[str, Any]], ) -> str: """ - 构建跨群聊上下文 (S4U模式) + 构建跨群聊/私聊上下文 (S4U模式) """ cross_context_messages = [] if target_user_info: user_id = target_user_info.get("user_id") if user_id: - for chat_raw_id in other_chat_raw_ids: + for chat_type, chat_raw_id in other_chat_infos: + is_group = chat_type == "group" stream_id = get_chat_manager().get_stream_id( - chat_stream.platform, chat_raw_id, is_group=True + chat_stream.platform, chat_raw_id, is_group=is_group ) if not stream_id: continue @@ -91,12 +107,10 @@ async def build_cross_context_s4u( timestamp=time.time(), limit=20, # 获取更多消息以供筛选 ) - user_messages = [msg for msg in messages if msg.get("user_id") == user_id][ - -5: - ] # 筛选并取最近5条 + user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-5:] if user_messages: - chat_name = get_chat_manager().get_stream_name(stream_id) or stream_id + chat_name = get_chat_manager().get_stream_name(stream_id) or chat_raw_id user_name = ( target_user_info.get("person_name") or target_user_info.get("user_nickname") @@ -109,10 +123,10 @@ async def build_cross_context_s4u( f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}' ) except Exception as e: - logger.error(f"获取用户{user_id}在群聊{chat_raw_id}的消息失败: {e}") + logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}") continue if not cross_context_messages: return "" - return "# 跨群上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" \ No newline at end of file + return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" \ No newline at end of file diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 196ff5ff7..e7221cf7c 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -473,19 +473,20 @@ pre_sleep_notification_groups = [] # 用于生成睡前消息的提示。AI会根据这个提示生成一句晚安问候。 pre_sleep_prompt = "我准备睡觉了,请生成一句简短自然的晚安问候。" -[cross_context] # 跨群聊上下文共享配置 +[cross_context] # 跨群聊/私聊上下文共享配置 # 这是总开关,用于一键启用或禁用此功能 -enable = false +enable = true # 在这里定义您的“共享组” -# 只有在同一个组内的群聊才会共享上下文 -# 注意:这里的chat_ids需要填写群号 +# 只有在同一个组内的聊天才会共享上下文 +# 格式:chat_ids = [["type", "id"], ["type", "id"], ...] +# type 可选 "group" 或 "private" [[cross_context.groups]] name = "项目A技术讨论组" chat_ids = [ - "111111", # 假设这是“开发群”的ID - "222222" # 假设这是“产品群”的ID + ["group", "169850076"], # 假设这是“开发群”的ID + ["group", "1025509724"], # 假设这是“产品群”的ID + ["private", "123456789"] # 假设这是某个用户的私聊 ] - [maizone_intercom] # QQ空间互通组配置 # 启用后,发布说说时会读取指定互通组的上下文 @@ -495,6 +496,6 @@ enable = false [[maizone_intercom.groups]] name = "Maizone默认互通组" chat_ids = [ - "111111", # 示例群聊1 - "222222" # 示例群聊2 + ["group", "111111"], # 示例群聊1 + ["private", "222222"] # 示例私聊2 ] \ No newline at end of file