diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 5fdb6c24e..1745cb851 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -1064,12 +1064,7 @@ class Prompt: chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id ) current_type = "group" if chat_stream.group_info else "private" - other_chat_infos = [ - chat_info - for chat_info in context_group.chat_ids - if chat_info[:2] != [current_type, str(current_chat_raw_id)] - ] - return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_infos) + return await cross_context_api.build_cross_context_normal(chat_stream, context_group) elif prompt_mode == "s4u": return await cross_context_api.build_cross_context_s4u(chat_stream, context_group, target_user_info) diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 0d6dd2b2e..f12759824 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -630,15 +630,28 @@ class SleepSystemConfig(ValidatedConfigBase): class ContextGroup(ValidatedConfigBase): - """上下文共享组配置""" + """ + 上下文共享组配置 - name: str = Field(..., description="共享组的名称") + 定义了一个聊天上下文的共享范围和规则。 + """ + + name: str = Field(..., description="共享组的名称,用于唯一标识一个共享组") + mode: Literal["whitelist", "blacklist"] = Field( + default="whitelist", + description="共享模式。'whitelist'表示仅共享chat_ids中列出的聊天;'blacklist'表示共享除chat_ids中列出的所有聊天。", + ) + default_limit: int = Field( + default=5, + description="在'blacklist'模式下,对于未明确指定数量的聊天,默认获取的消息条数。也用于s4u_ignore_whitelist开启时获取私聊消息的数量。", + ) s4u_ignore_whitelist: bool = Field( - default=False, description="在s4u模式下, 是否无视白名单, 获取用户所有私聊消息" + default=False, + description="在s4u模式下,是否无视白名单,获取目标用户与Bot的所有私聊消息,以构建更完整的用户画像。", ) chat_ids: list[list[str]] = Field( ..., - description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]', + description='定义组内成员的列表。格式为 [["type", "id", "limit"(可选)]]。type为"group"或"private",id为群号或用户ID,limit为可选的消息条数。', ) diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index bf8e146cd..4b7f128aa 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -45,16 +45,52 @@ async def get_context_group(chat_id: str) -> ContextGroup | None: return None -async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: list[list[str]]) -> str: +async def build_cross_context_normal(chat_stream: ChatStream, context_group: ContextGroup) -> str: """ - 构建跨群聊/私聊上下文 (Normal模式) + 构建跨群聊/私聊上下文 (Normal模式)。 + + 根据共享组的配置(白名单或黑名单模式),获取相关聊天的近期消息,并格式化为字符串。 + + Args: + chat_stream: 当前的聊天流对象。 + context_group: 当前聊天所在的上下文共享组配置。 + + Returns: + 一个包含格式化后的跨上下文消息的字符串,如果无消息则为空字符串。 """ cross_context_messages = [] - for chat_info in other_chat_infos: - chat_type, chat_raw_id, limit = chat_info[0], chat_info[1], int(chat_info[2]) if len(chat_info) > 2 else 5 + chat_manager = get_chat_manager() + + chat_infos_to_fetch = [] + if context_group.mode == "blacklist": + # 黑名单模式:获取所有聊天,并排除在 chat_ids 中定义过的聊天 + blacklisted_ids = {tuple(info[:2]) for info in context_group.chat_ids} + for stream_id, stream in chat_manager.streams.items(): + is_group = stream.group_info is not None + chat_type = "group" if is_group else "private" + + # 安全地获取 raw_id + if is_group and stream.group_info: + raw_id = stream.group_info.group_id + elif not is_group and stream.user_info: + raw_id = stream.user_info.user_id + else: + continue # 如果缺少关键信息则跳过 + + # 如果当前聊天不在黑名单中,则添加到待获取列表 + if (chat_type, str(raw_id)) not in blacklisted_ids: + chat_infos_to_fetch.append([chat_type, str(raw_id), str(context_group.default_limit)]) + else: + # 白名单模式:直接使用配置中定义的 chat_ids + chat_infos_to_fetch = context_group.chat_ids + + # 遍历待获取列表,抓取并格式化消息 + for chat_info in chat_infos_to_fetch: + chat_type, chat_raw_id, limit_str = chat_info[0], chat_info[1], chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit) + limit = int(limit_str) is_group = chat_type == "group" - stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group) - if not stream_id: + stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group) + if not stream_id or stream_id == chat_stream.stream_id: continue try: @@ -64,7 +100,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: limit=limit, ) if messages: - chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id + chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') except Exception as e: @@ -95,19 +131,45 @@ async def build_cross_context_s4u( ) current_type = "group" if chat_stream.group_info else "private" - other_chat_infos = [ - chat_info - for chat_info in context_group.chat_ids - if chat_info[:2] != [current_type, str(current_chat_raw_id)] - ] + # 根据模式(黑名单/白名单)决定需要处理哪些聊天 + chat_infos_to_process = [] + if context_group.mode == "blacklist": + # 黑名单模式:获取除当前聊天和黑名单内聊天之外的所有聊天 + blacklisted_ids = {tuple(info[:2]) for info in context_group.chat_ids} + for stream_id, stream in chat_manager.streams.items(): + if stream_id == chat_stream.stream_id: + continue # 排除当前聊天 + + is_group = stream.group_info is not None + chat_type = "group" if is_group else "private" + + # 安全地获取 raw_id + if is_group and stream.group_info: + raw_id = stream.group_info.group_id + elif not is_group and stream.user_info: + raw_id = stream.user_info.user_id + else: + continue # 如果缺少关键信息则跳过 - # 1. 处理在白名单内的聊天 - for chat_info in other_chat_infos: - chat_type, chat_raw_id, limit = ( + # 如果不在黑名单中,则加入处理列表 + if (chat_type, str(raw_id)) not in blacklisted_ids: + chat_infos_to_process.append([chat_type, str(raw_id), str(context_group.default_limit)]) + else: # 白名单模式 + # 白名单模式:只获取在 chat_ids 中且非当前聊天的聊天 + chat_infos_to_process = [ + chat_info + for chat_info in context_group.chat_ids + if chat_info[:2] != [current_type, str(current_chat_raw_id)] + ] + + # 1. 处理筛选出的目标聊天 + for chat_info in chat_infos_to_process: + chat_type, chat_raw_id, limit_str = ( chat_info[0], chat_info[1], - int(chat_info[2]) if len(chat_info) > 2 else 5, + chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit), ) + limit = int(limit_str) is_group = chat_type == "group" stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group) if not stream_id: @@ -141,7 +203,7 @@ async def build_cross_context_s4u( if private_stream_id and not is_already_processed: try: - limit = 5 # 使用默认值 + limit = context_group.default_limit messages = await get_raw_msg_before_timestamp_with_chat( chat_id=private_stream_id, timestamp=time.time(), limit=limit * 4 ) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 540e2bc9f..01119db17 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.2.10" +version = "7.2.11" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -552,6 +552,14 @@ enable = true # limit 是一个可选的整数(但需要以字符串形式写入),用于指定从该聊天流中获取的消息数量,如果未指定,默认为5 [[cross_context.groups]] name = "项目A技术讨论组" +# mode: (可选, 默认为 "whitelist") +# "whitelist": 白名单模式,只有在 chat_ids 中明确列出的聊天才会共享上下文。 +# "blacklist": 黑名单模式,除了在 chat_ids 中列出的聊天外,所有其他聊天都会共享上下文。 +mode = "whitelist" +# default_limit: (可选, 默认为 5) +# 在 "blacklist" 模式下,未在 chat_ids 中指定的聊天将默认获取此数量的消息。 +# 同时,当 s4u_ignore_whitelist 设置为 true 时,获取用户私聊消息的数量也将使用此值。 +default_limit = 5 # s4u_ignore_whitelist: (可选, 默认为 false) # 如果设置为 true, 并且 prompt_mode 为 "s4u", # Bot将获取目标用户在所有与Bot的私聊中的消息, 即使该私聊没有被明确配置在下面的 chat_ids 中。