diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index 4acbcf7c8..a334472df 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -12,7 +12,7 @@ from datetime import datetime from typing import Any from src.chat.express.expression_selector import expression_selector -from src.chat.message_receive.chat_stream import ChatStream +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.message import MessageRecv, MessageSending, Seg, UserInfo from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.utils.chat_message_builder import ( @@ -299,7 +299,7 @@ class DefaultReplyer: result = await event_manager.trigger_event( EventType.POST_LLM, permission_group="SYSTEM", prompt=prompt, stream_id=stream_id ) - if not result.all_continue_process(): + if result and not result.all_continue_process(): raise UserWarning(f"插件{result.get_summary().get('stopped_handlers', '')}于请求前中断了内容生成") # 4. 调用 LLM 生成回复 @@ -326,7 +326,7 @@ class DefaultReplyer: llm_response=llm_response, stream_id=stream_id, ) - if not result.all_continue_process(): + if result and not result.all_continue_process(): raise UserWarning( f"插件{result.get_summary().get('stopped_handlers', '')}于请求后取消了内容生成" ) @@ -899,7 +899,7 @@ class DefaultReplyer: # 处理消息内容中的用户引用,确保bot回复在消息内容中也正确显示 from src.chat.utils.chat_message_builder import replace_user_references_sync msg_content = replace_user_references_sync( - msg_content, + msg_content or "", platform, replace_bot_name=True ) @@ -1158,8 +1158,8 @@ class DefaultReplyer: await person_info_manager.first_knowing_some_one( platform, # type: ignore reply_message.get("user_id"), # type: ignore - reply_message.get("user_nickname"), - reply_message.get("user_cardname"), + reply_message.get("user_nickname") or "", + reply_message.get("user_cardname") or "", ) # 检查是否是bot自己的名字,如果是则替换为"(你)" @@ -1248,7 +1248,7 @@ class DefaultReplyer: ), "cross_context": asyncio.create_task( self._time_and_run_task( - Prompt.build_cross_context(chat_id, global_config.personality.prompt_mode, target_user_info), + self.build_full_cross_context(chat_id, target_user_info), "cross_context", ) ), @@ -1431,6 +1431,12 @@ class DefaultReplyer: template_name = "default_expressor_prompt" # 获取模板内容 + if not template_name: + logger.error("无法根据prompt_mode确定模板名称,请检查配置。") + return "" + if not template_name: + logger.error("无法根据prompt_mode确定模板名称,请检查配置。") + return "" template_prompt = await global_prompt_manager.get_prompt_async(template_name) prompt = Prompt(template=template_prompt.template, parameters=prompt_parameters) prompt_text = await prompt.build() @@ -1875,6 +1881,57 @@ class DefaultReplyer: except Exception as e: logger.error(f"存储聊天记忆失败: {e}") + async def build_full_cross_context(self, chat_id: str, target_user_info: dict[str, Any] | None) -> str: + """ + 构建完整的跨上下文信息,包括固定共享组和用户中心检索。 + """ + # 1. 处理固定的共享组 + from src.chat.utils.prompt import Prompt + cross_context_block = await Prompt.build_cross_context( + chat_id, global_config.personality.prompt_mode, target_user_info + ) + + # 2. 处理用户中心检索 + config = global_config.cross_context + if config.enable and config.user_centric_retrieval_mode != "disabled": + chat_manager = get_chat_manager() + current_stream = await chat_manager.get_stream(chat_id) + if not current_stream: + return cross_context_block + + # 检查黑白名单 + is_group = current_stream.group_info is not None + raw_id = None + if is_group and current_stream.group_info: + raw_id = current_stream.group_info.group_id + elif not is_group and current_stream.user_info: + raw_id = current_stream.user_info.user_id + + if not raw_id: + return cross_context_block + chat_type = "group" if is_group else "private" + + allow_retrieval = False + if config.user_centric_retrieval_mode == "all": + if [chat_type, str(raw_id)] not in config.blacklist_chats: + allow_retrieval = True + elif config.user_centric_retrieval_mode == "whitelist": + if [chat_type, str(raw_id)] in config.whitelist_chats: + allow_retrieval = True + + if allow_retrieval and target_user_info and "user_id" in target_user_info: + from src.plugin_system.apis.cross_context_api import get_user_centric_context + user_centric_context = await get_user_centric_context( + user_id=str(target_user_info["user_id"]), + platform=current_stream.platform, + limit=config.user_centric_retrieval_limit, + exclude_chat_id=chat_id, + ) + if user_centric_context: + cross_context_block = f"{cross_context_block}\n{user_centric_context}" + + return cross_context_block + def weighted_sample_no_replacement(items, weights, k) -> list: """ diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 86d116242..87f19e6ce 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -640,10 +640,19 @@ class ContextGroup(ValidatedConfigBase): class CrossContextConfig(ValidatedConfigBase): - """跨群聊上下文共享配置""" + """跨上下文共享配置""" - enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") - groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") + enable: bool = Field(default=True, description="是否启用跨上下文共享功能") + user_centric_retrieval_mode: Literal["disabled", "all", "whitelist"] = Field( + default="disabled", description="用户中心上下文检索模式" + ) + user_centric_retrieval_limit: int = Field(default=5, ge=1, le=50, description="用户中心上下文检索数量上限") + user_centric_retrieval_stream_limit: int = Field( + default=3, ge=0, description="用户中心上下文检索的聊天流数量上限,0为不限制" + ) + whitelist_chats: list[list[str]] = Field(default_factory=list, description="白名单聊天列表") + blacklist_chats: list[list[str]] = Field(default_factory=list, description="黑名单聊天列表") + # DEPRECATED: groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") class CommandConfig(ValidatedConfigBase): diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index eed13697c..8d534aab2 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -33,14 +33,15 @@ async def get_context_groups(chat_id: str) -> list[list[str]] | None: current_chat_raw_id = current_stream.user_info.user_id current_type = "group" if is_group else "private" - for group in global_config.cross_context.groups: - # 检查当前聊天的ID和类型是否在组的chat_ids中 - if [current_type, str(current_chat_raw_id)] in group.chat_ids: - # 排除maizone专用组 - if group.name == "maizone_context_group": - continue - # 返回组内其他聊天的 [type, id] 列表 - return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]] + # This feature is deprecated + # for group in global_config.cross_context.groups: + # # 检查当前聊天的ID和类型是否在组的chat_ids中 + # if [current_type, str(current_chat_raw_id)] in group.chat_ids: + # # 排除maizone专用组 + # if group.name == "maizone_context_group": + # continue + # # 返回组内其他聊天的 [type, id] 列表 + # return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]] return None @@ -124,140 +125,86 @@ async def build_cross_context_s4u( return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n" -async def get_chat_history_by_group_name(group_name: str) -> str: - """ - 根据互通组名字获取聊天记录 - """ - target_group = None - for group in global_config.cross_context.groups: - if group.name == group_name: - target_group = group - break - if not target_group: - return f"找不到名为 {group_name} 的互通组。" - - if not target_group.chat_ids: - return f"互通组 {group_name} 中没有配置任何聊天。" - - chat_infos = target_group.chat_ids - chat_manager = get_chat_manager() - - cross_context_messages = [] - for chat_type, chat_raw_id in chat_infos: - is_group = chat_type == "group" - - found_stream = None - for stream in chat_manager.streams.values(): - if is_group: - if stream.group_info and stream.group_info.group_id == chat_raw_id: - found_stream = stream - break - else: # private - if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info: - found_stream = stream - break - - if not found_stream: - logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。") - continue - - stream_id = found_stream.stream_id - - try: - messages = await get_raw_msg_before_timestamp_with_chat( - chat_id=stream_id, - timestamp=time.time(), - limit=5, # 可配置 - ) - if messages: - chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id - formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") - cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}') - except Exception as e: - logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}") - continue - - if not cross_context_messages: - return f"无法从互通组 {group_name} 中获取任何聊天记录。" - - return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n" - - -async def get_intercom_group_context_by_name( - group_name: str, days: int = 3, limit_per_chat: int = 20, total_limit: int = 100 +async def get_user_centric_context( + user_id: str, platform: str, limit: int, exclude_chat_id: str | None = None ) -> str | None: """ - 根据互通组的名称,构建该组的聊天上下文。 + 获取以用户为中心的全域聊天记录。 Args: - group_name: 互通组的名称。 - days: 获取过去多少天的消息。 - limit_per_chat: 每个聊天最多获取的消息条数。 - total_limit: 返回的总消息条数上限。 + user_id: 目标用户的ID。 + platform: 用户所在的平台。 + limit: 每个聊天中获取的最大消息数量。 + exclude_chat_id: 需要排除的当前聊天ID。 Returns: - 如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。 + 构建好的上下文信息字符串,如果没有找到则返回None。 """ - cross_context_config = global_config.cross_context - if not (cross_context_config and cross_context_config.enable): - return None - - target_group = None - for group in cross_context_config.groups: - if group.name == group_name: - target_group = group - break - - if not target_group: - logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。") - return None - chat_manager = get_chat_manager() + user_messages_map = {} - all_messages = [] - end_time = time.time() - start_time = end_time - (days * 24 * 60 * 60) + # 遍历所有相关的聊天流 + streams_to_search = [] + private_stream = None + group_streams = [] - for chat_type, chat_raw_id in target_group.chat_ids: - is_group = chat_type == "group" - - found_stream = None - # 采用与 get_chat_history_by_group_name 相同的健壮的 stream 查找方式 - for stream in chat_manager.streams.values(): - if is_group: - if stream.group_info and stream.group_info.group_id == chat_raw_id: - found_stream = stream - break - else: # private - if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info: - found_stream = stream - break - - if not found_stream: - logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。") + for stream in chat_manager.streams.values(): + if stream.stream_id == exclude_chat_id: continue - stream_id = found_stream.stream_id - messages = await get_raw_msg_by_timestamp_with_chat( - chat_id=stream_id, - timestamp_start=start_time, - timestamp_end=end_time, - limit=limit_per_chat, - limit_mode="latest", - ) - all_messages.extend(messages) + is_group = stream.group_info is not None + if is_group: + # 对于群聊,检查用户是否是成员之一 (通过消息记录判断) + group_streams.append(stream) + else: + # 对于私聊,检查是否是与目标用户的私聊 + if stream.user_info and stream.user_info.user_id == user_id: + private_stream = stream - if not all_messages: + # 优先添加私聊流 + if private_stream: + streams_to_search.append(private_stream) + + # 按最近活跃时间对群聊流进行排序 + group_streams.sort(key=lambda s: s.last_active_time, reverse=True) + streams_to_search.extend(group_streams) + + # 应用聊天流数量限制 + stream_limit = global_config.cross_context.user_centric_retrieval_stream_limit + if stream_limit > 0: + streams_to_search = streams_to_search[:stream_limit] + + for stream in streams_to_search: + try: + messages = await get_raw_msg_before_timestamp_with_chat( + chat_id=stream.stream_id, + timestamp=time.time(), + limit=limit * 5, # 获取更多消息以供筛选 + ) + user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:] + + if user_messages: + chat_name = await chat_manager.get_stream_name(stream.stream_id) or stream.stream_id + if chat_name not in user_messages_map: + user_messages_map[chat_name] = [] + user_messages_map[chat_name].extend(user_messages) + except Exception as e: + logger.error(f"获取用户 {user_id} 在聊天 {stream.stream_id} 的消息失败: {e}") + continue + + if not user_messages_map: return None - # 按时间戳对所有消息进行排序 - all_messages.sort(key=lambda x: x.get("time", 0)) + # 构建最终的上下文字符串 + cross_context_parts = [] + for chat_name, messages in user_messages_map.items(): + # 按时间戳对消息进行排序 + messages.sort(key=lambda x: x.get("time", 0)) + formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative") + cross_context_parts.append(f'[以下是该用户在"{chat_name}"的近期发言]\n{formatted_messages}') - # 限制总消息数 - if len(all_messages) > total_limit: - all_messages = all_messages[-total_limit:] + if not cross_context_parts: + return None - # build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list) - formatted_string, _ = await build_readable_messages_with_id(all_messages) - return formatted_string + return "### 该用户在其他地方的聊天记录\n" + "\n\n".join(cross_context_parts) + "\n" diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 75cd84da3..b4d79fceb 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.2.6" +version = "7.2.7" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -518,29 +518,35 @@ insomnia_duration_minutes = [30, 60] # 单次失眠状态的持续时间范围 # 入睡后,经过一段延迟后触发失眠判定的延迟时间(分钟),设置为范围以增加随机性 insomnia_trigger_delay_minutes = [15, 45] -[cross_context] # 跨群聊/私聊上下文共享配置 -# 这是总开关,用于一键启用或禁用此功能 +[cross_context] # 跨上下文共享配置 +# 总开关,用于一键启用或禁用所有跨上下文功能 enable = true -# 在这里定义您的“共享组” -# 只有在同一个组内的聊天才会共享上下文 -# 格式:chat_ids = [["type", "id"], ["type", "id"], ...] -# type 可选 "group" 或 "private" -[[cross_context.groups]] -name = "项目A技术讨论组" -chat_ids = [ - ["group", "169850076"], # 假设这是“开发群”的ID - ["group", "1025509724"], # 假设这是“产品群”的ID - ["private", "123456789"] # 假设这是某个用户的私聊 + +# --- 用户中心上下文检索 --- +# 当回复特定用户时,自动从该用户参与的其他聊天中检索最近对话,以提供更丰富的上下文。 +# mode: "disabled" - 禁用此功能 +# "all" - 对所有聊天启用(黑名单除外) +# "whitelist" - 仅对白名单内的聊天启用 +user_centric_retrieval_mode = "disabled" +user_centric_retrieval_limit = 5 # 检索附加上下文时,获取的最大历史消息数量 +user_centric_retrieval_stream_limit = 3 # 检索附加上下文时,检索的聊天流数量上限,0为不限制 + +# 白名单 (仅当 mode = "whitelist" 时生效) +# 如果列表为空,则功能不应用于任何聊天 +# 格式: [["type", "id"], ...] (type: "group" 或 "private") +whitelist_chats = [ + # ["group", "123456"], + # ["private", "789012"] ] -# 定义QQ空间互通组 -# 同一个组的chat_id会共享上下文,用于生成更相关的说说 -[[cross_context.maizone_context_group]] -name = "Maizone默认互通组" -chat_ids = [ - ["group", "111111"], # 示例群聊1 - ["private", "222222"] # 示例私聊2 + +# 黑名单 (当 mode = "all" 时生效) +# 如果列表为空,则功能应用于所有聊天 +# 格式: [["type", "id"], ...] +blacklist_chats = [ + # ["group", "987654"] ] + [affinity_flow] # 兴趣评分系统参数 reply_action_interest_threshold = 1.1 # 回复动作兴趣阈值