diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 291d891e0..33a137df2 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -1047,8 +1047,8 @@ class Prompt: from src.plugin_system.apis import cross_context_api - other_chat_raw_ids = await cross_context_api.get_context_groups(chat_id) - if not other_chat_raw_ids: + context_group = await cross_context_api.get_context_group(chat_id) + if not context_group: return "" chat_stream = await get_chat_manager().get_stream(chat_id) @@ -1056,9 +1056,18 @@ class Prompt: return "" if prompt_mode == "normal": - return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids) + current_chat_raw_id = ( + chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id + ) + current_type = "group" if chat_stream.group_info else "private" + other_chat_infos = [ + chat_info + for chat_info in context_group.chat_ids + if chat_info[:2] != [current_type, str(current_chat_raw_id)] + ] + return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_infos) elif prompt_mode == "s4u": - return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info) + return await cross_context_api.build_cross_context_s4u(chat_stream, context_group, target_user_info) return "" diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 86d116242..c63631aa7 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -633,6 +633,9 @@ class ContextGroup(ValidatedConfigBase): """上下文共享组配置""" name: str = Field(..., description="共享组的名称") + s4u_ignore_whitelist: bool = Field( + default=False, description="在s4u模式下, 是否无视白名单, 获取用户所有私聊消息" + ) chat_ids: list[list[str]] = Field( ..., description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]', diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index 43286a5b1..bf8e146cd 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -13,13 +13,14 @@ from src.chat.utils.chat_message_builder import ( ) from src.common.logger import get_logger from src.config.config import global_config +from src.config.official_configs import ContextGroup logger = get_logger("cross_context_api") -async def get_context_groups(chat_id: str) -> list[list[str]] | None: +async def get_context_group(chat_id: str) -> ContextGroup | None: """ - 获取当前聊天所在的共享组的其他聊天ID + 获取当前聊天所在的共享组 """ current_stream = await get_chat_manager().get_stream(chat_id) if not current_stream: @@ -39,8 +40,7 @@ async def get_context_groups(chat_id: str) -> list[list[str]] | None: # 排除maizone专用组 if group.name == "maizone_context_group": continue - # 返回组内其他聊天的 [type, id] 列表 - return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]] + return group return None @@ -79,51 +79,85 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos: async def build_cross_context_s4u( chat_stream: ChatStream, - other_chat_infos: list[list[str]], + context_group: ContextGroup, target_user_info: dict[str, Any] | None, ) -> str: """ 构建跨群聊/私聊上下文 (S4U模式) """ cross_context_messages = [] - if target_user_info: - user_id = target_user_info.get("user_id") + if not target_user_info or not (user_id := target_user_info.get("user_id")): + return "" - if user_id: - for chat_info in other_chat_infos: - chat_type, chat_raw_id, limit = ( - chat_info[0], - chat_info[1], - int(chat_info[2]) if len(chat_info) > 2 else 5, + chat_manager = get_chat_manager() + current_chat_raw_id = ( + chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id + ) + current_type = "group" if chat_stream.group_info else "private" + + other_chat_infos = [ + chat_info + for chat_info in context_group.chat_ids + if chat_info[:2] != [current_type, str(current_chat_raw_id)] + ] + + # 1. 处理在白名单内的聊天 + for chat_info in other_chat_infos: + chat_type, chat_raw_id, limit = ( + chat_info[0], + chat_info[1], + int(chat_info[2]) if len(chat_info) > 2 else 5, + ) + is_group = chat_type == "group" + stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group) + if not stream_id: + continue + + try: + messages = await get_raw_msg_before_timestamp_with_chat( + chat_id=stream_id, timestamp=time.time(), limit=limit * 4 + ) + user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:] + + if user_messages: + chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id + user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id + formatted_messages, _ = await build_readable_messages_with_id( + user_messages, timestamp_mode="relative" ) - is_group = chat_type == "group" - stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group) - if not stream_id: - continue + cross_context_messages.append( + f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}' + ) + except Exception as e: + logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}") - try: - # S4U模式下,我们获取更多消息以供筛选 - messages = await get_raw_msg_before_timestamp_with_chat( - chat_id=stream_id, - timestamp=time.time(), - limit=limit * 4, # 获取4倍limit的消息以供筛选 + # 2. 如果开启了 s4u_ignore_whitelist,则获取用户与Bot的私聊记录 + if context_group.s4u_ignore_whitelist: + private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False) + # 检查该私聊是否已在白名单中处理过 + is_already_processed = any( + info[0] == "private" and info[1] == user_id for info in context_group.chat_ids + ) + + if private_stream_id and not is_already_processed: + try: + limit = 5 # 使用默认值 + messages = await get_raw_msg_before_timestamp_with_chat( + chat_id=private_stream_id, timestamp=time.time(), limit=limit * 4 + ) + user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:] + + if user_messages: + chat_name = await chat_manager.get_stream_name(private_stream_id) or user_id + user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id + formatted_messages, _ = await build_readable_messages_with_id( + user_messages, timestamp_mode="relative" ) - user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:] - - if user_messages: - chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id - user_name = ( - target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id - ) - formatted_messages, _ = await build_readable_messages_with_id( - user_messages, timestamp_mode="relative" - ) - cross_context_messages.append( - f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}' - ) - except Exception as e: - logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}") - continue + cross_context_messages.append( + f'[以下是"{user_name}"在与你的私聊中的近期发言]\n{formatted_messages}' + ) + except Exception as e: + logger.error(f"获取用户 {user_id} 的私聊消息失败: {e}") if not cross_context_messages: return "" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 8b3bd2fb0..23794e7a5 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.2.9" +version = "7.2.10" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -528,6 +528,11 @@ enable = true # limit 是一个可选的整数(但需要以字符串形式写入),用于指定从该聊天流中获取的消息数量,如果未指定,默认为5 [[cross_context.groups]] name = "项目A技术讨论组" +# s4u_ignore_whitelist: (可选, 默认为 false) +# 如果设置为 true, 并且 prompt_mode 为 "s4u", +# Bot将获取目标用户在所有与Bot的私聊中的消息, 即使该私聊没有被明确配置在下面的 chat_ids 中。 +# 这有助于构建更完整的用户画像, 但可能会增加token消耗。 +s4u_ignore_whitelist = false chat_ids = [ ["group", "169850076", "10"], # 假设这是“开发群”的ID, 从这个群里拿10条消息 ["group", "1025509724", "5"], # 假设这是“产品群”的ID,拿5条