refactor(cross_context): 重构S4U上下文检索逻辑并更新配置
将S4U(Search for User)上下文检索模式从依赖于共享组(ContextGroup)的配置中解耦,改为使用独立的全局配置。这使得S4U模式的管理更加清晰和灵活。

主要变更:

- **配置模型更新**: 从`ContextGroup`中移除了与S4U相关的字段(如`s4u_ignore_whitelist`),并在`CrossContextConfig`中添加了新的S4U专用配置项,包括`s4u_mode`, `s4u_limit`, `s4u_stream_limit`, `s4u_whitelist_chats`, 和 `s4u_blacklist_chats`。
- **S4U逻辑重构**: `build_cross_context_s4u`函数不再接收`context_group`参数,而是直接读取全局的S4U配置来检索用户在白名单或黑名单聊天中的消息。
- **简化调用**: `Prompt.get_cross_context_prompt`中的调用逻辑被简化,以适应新的函数签名。
- **文档与模板更新**: 更新了`bot_config_template.toml`配置文件模板,以反映新的S4U配置结构,并提供了更清晰的注释说明。

此次重构将Normal模式(群组共享)和S4U模式(用户中心)的配置和实现完全分离,提高了代码的可维护性和配置的直观性。
This commit is contained in:
@@ -1047,22 +1047,17 @@ class Prompt:
|
||||
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
context_group = await cross_context_api.get_context_group(chat_id)
|
||||
if not context_group:
|
||||
return ""
|
||||
|
||||
chat_stream = await get_chat_manager().get_stream(chat_id)
|
||||
if not chat_stream:
|
||||
return ""
|
||||
|
||||
if prompt_mode == "normal":
|
||||
current_chat_raw_id = (
|
||||
chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
|
||||
)
|
||||
current_type = "group" if chat_stream.group_info else "private"
|
||||
context_group = await cross_context_api.get_context_group(chat_id)
|
||||
if not context_group:
|
||||
return ""
|
||||
return await cross_context_api.build_cross_context_normal(chat_stream, context_group)
|
||||
elif prompt_mode == "s4u":
|
||||
return await cross_context_api.build_cross_context_s4u(chat_stream, context_group, target_user_info)
|
||||
return await cross_context_api.build_cross_context_s4u(chat_stream, target_user_info)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
@@ -643,11 +643,7 @@ class ContextGroup(ValidatedConfigBase):
|
||||
)
|
||||
default_limit: int = Field(
|
||||
default=5,
|
||||
description="在'blacklist'模式下,对于未明确指定数量的聊天,默认获取的消息条数。也用于s4u_ignore_whitelist开启时获取私聊消息的数量。",
|
||||
)
|
||||
s4u_ignore_whitelist: bool = Field(
|
||||
default=False,
|
||||
description="在s4u模式下,是否无视白名单,获取目标用户与Bot的所有私聊消息,以构建更完整的用户画像。",
|
||||
description="在'blacklist'模式下,对于未明确指定数量的聊天,默认获取的消息条数。",
|
||||
)
|
||||
chat_ids: list[list[str]] = Field(
|
||||
...,
|
||||
@@ -655,12 +651,43 @@ class ContextGroup(ValidatedConfigBase):
|
||||
)
|
||||
|
||||
|
||||
class MaizoneContextGroup(ValidatedConfigBase):
    """Context-sharing group dedicated to QQ Zone (Maizone).

    Holds only a group name and its member chats.
    """

    # Human-readable name of this QQ Zone sharing group (required).
    name: str = Field(
        ...,
        description="QQ空间互通组的名称",
    )

    # Required list of members; each entry is a ["type", "id"] pair where
    # type is "group" or "private" and id is the group number or user ID.
    chat_ids: list[list[str]] = Field(
        ..., description='定义组内成员的列表。格式为 [["type", "id"]]。type为"group"或"private",id为群号或用户ID。'
    )
|
||||
|
||||
|
||||
class CrossContextConfig(ValidatedConfigBase):
    """Settings for sharing conversation context across chats.

    The fields fall into three sections: the group-based "normal" mode, the
    user-centric S4U mode, and the groups dedicated to QQ Zone.
    """

    # Master switch for the cross-chat context feature; off by default.
    enable: bool = Field(
        default=False,
        description="是否启用跨群聊上下文共享功能",
    )

    # --- Normal mode: shared-group configuration ---
    groups: list[ContextGroup] = Field(
        default_factory=list,
        description="上下文共享组列表",
    )

    # --- S4U mode: user-centric context retrieval ---
    # Whether S4U operates in whitelist or blacklist mode.
    s4u_mode: Literal["whitelist", "blacklist"] = Field(default="whitelist", description="S4U模式的白名单/黑名单模式")
    # Number of messages fetched per chat.
    s4u_limit: int = Field(
        default=5,
        description="S4U模式下,每个聊天获取的消息条数",
    )
    # Upper bound on how many distinct chat streams are retrieved.
    s4u_stream_limit: int = Field(
        default=3,
        description="S4U模式下,最多检索多少个不同的聊天流",
    )
    # Whitelist entries, each formatted as "platform:type:id".
    s4u_whitelist_chats: list[str] = Field(
        default_factory=list, description='S4U模式的白名单列表。格式: ["platform:type:id", ...]'
    )
    # Blacklist entries, each formatted as "platform:type:id".
    s4u_blacklist_chats: list[str] = Field(
        default_factory=list, description='S4U模式的黑名单列表。格式: ["platform:type:id", ...]'
    )

    # --- Groups dedicated to QQ Zone ---
    maizone_context_group: list[MaizoneContextGroup] = Field(
        default_factory=list,
        description="QQ空间专用互通组列表",
    )
|
||||
|
||||
|
||||
class CommandConfig(ValidatedConfigBase):
|
||||
"""命令系统配置类"""
|
||||
|
||||
@@ -34,12 +34,14 @@ async def get_context_group(chat_id: str) -> ContextGroup | None:
|
||||
current_type = "group" if is_group else "private"
|
||||
|
||||
for group in global_config.cross_context.groups:
|
||||
# 检查当前聊天的ID和类型是否在组的chat_ids中
|
||||
if [current_type, str(current_chat_raw_id)] in group.chat_ids:
|
||||
# 排除maizone专用组
|
||||
if group.name == "maizone_context_group":
|
||||
continue
|
||||
return group
|
||||
for chat_info in group.chat_ids:
|
||||
if len(chat_info) >= 2:
|
||||
chat_type, chat_raw_id = chat_info[0], chat_info[1]
|
||||
if chat_type == current_type and str(chat_raw_id) == str(current_chat_raw_id):
|
||||
# 排除maizone专用组
|
||||
if group.name == "maizone_context_group":
|
||||
continue
|
||||
return group
|
||||
|
||||
return None
|
||||
|
||||
@@ -118,108 +120,88 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
|
||||
|
||||
async def build_cross_context_s4u(
|
||||
chat_stream: ChatStream,
|
||||
context_group: ContextGroup,
|
||||
target_user_info: dict[str, Any] | None,
|
||||
) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (S4U模式)
|
||||
构建跨群聊/私聊上下文 (S4U模式)。
|
||||
|
||||
基于全局S4U配置,检索目标用户在其他聊天中的发言。
|
||||
"""
|
||||
cross_context_messages = []
|
||||
cross_context_config = global_config.cross_context
|
||||
if not target_user_info or not (user_id := target_user_info.get("user_id")):
|
||||
return ""
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
current_chat_raw_id = chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
|
||||
current_type = "group" if chat_stream.group_info else "private"
|
||||
all_user_messages = []
|
||||
|
||||
# 根据模式(黑名单/白名单)决定需要处理哪些聊天
|
||||
chat_infos_to_process = []
|
||||
if context_group.mode == "blacklist":
|
||||
# 黑名单模式:获取除当前聊天和黑名单内聊天之外的所有聊天
|
||||
blacklisted_ids = {tuple(info[:2]) for info in context_group.chat_ids}
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if stream_id == chat_stream.stream_id:
|
||||
continue # 排除当前聊天
|
||||
# 1. 根据全局S4U配置确定要扫描的聊天范围
|
||||
streams_to_scan = []
|
||||
if cross_context_config.s4u_mode == "whitelist":
|
||||
for chat_str in cross_context_config.s4u_whitelist_chats:
|
||||
try:
|
||||
platform, chat_type, chat_raw_id = chat_str.split(":")
|
||||
is_group = chat_type == "group"
|
||||
stream_id = chat_manager.get_stream_id(platform, chat_raw_id, is_group=is_group)
|
||||
if stream_id and stream_id != chat_stream.stream_id:
|
||||
streams_to_scan.append(stream_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的S4U白名单格式: {chat_str}")
|
||||
else: # blacklist mode
|
||||
blacklisted_streams = set()
|
||||
for chat_str in cross_context_config.s4u_blacklist_chats:
|
||||
try:
|
||||
platform, chat_type, chat_raw_id = chat_str.split(":")
|
||||
is_group = chat_type == "group"
|
||||
stream_id = chat_manager.get_stream_id(platform, chat_raw_id, is_group=is_group)
|
||||
if stream_id:
|
||||
blacklisted_streams.add(stream_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
|
||||
|
||||
is_group = stream.group_info is not None
|
||||
chat_type = "group" if is_group else "private"
|
||||
|
||||
# 安全地获取 raw_id
|
||||
if is_group and stream.group_info:
|
||||
raw_id = stream.group_info.group_id
|
||||
elif not is_group and stream.user_info:
|
||||
raw_id = stream.user_info.user_id
|
||||
else:
|
||||
continue # 如果缺少关键信息则跳过
|
||||
|
||||
# 如果不在黑名单中,则加入处理列表
|
||||
if (chat_type, str(raw_id)) not in blacklisted_ids:
|
||||
chat_infos_to_process.append([chat_type, str(raw_id), str(context_group.default_limit)])
|
||||
else: # 白名单模式
|
||||
# 白名单模式:只获取在 chat_ids 中且非当前聊天的聊天
|
||||
chat_infos_to_process = [
|
||||
chat_info
|
||||
for chat_info in context_group.chat_ids
|
||||
if chat_info[:2] != [current_type, str(current_chat_raw_id)]
|
||||
]
|
||||
|
||||
# 1. 处理筛选出的目标聊天
|
||||
for chat_info in chat_infos_to_process:
|
||||
chat_type, chat_raw_id, limit_str = (
|
||||
chat_info[0],
|
||||
chat_info[1],
|
||||
chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit),
|
||||
)
|
||||
limit = int(limit_str)
|
||||
is_group = chat_type == "group"
|
||||
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||
if not stream_id:
|
||||
continue
|
||||
for stream_id in chat_manager.streams:
|
||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams:
|
||||
streams_to_scan.append(stream_id)
|
||||
|
||||
# 2. 从筛选出的聊天流中获取目标用户的消息
|
||||
limit = cross_context_config.s4u_limit
|
||||
for stream_id in streams_to_scan:
|
||||
try:
|
||||
# 获取稍多一些消息以确保能筛选出用户消息
|
||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id, timestamp=time.time(), limit=limit * 4
|
||||
chat_id=stream_id, timestamp=time.time(), limit=limit * 5
|
||||
)
|
||||
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:]
|
||||
|
||||
if user_messages:
|
||||
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
|
||||
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
formatted_messages, _ = await build_readable_messages_with_id(user_messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
|
||||
# 2. 如果开启了 s4u_ignore_whitelist,则获取用户与Bot的私聊记录
|
||||
if context_group.s4u_ignore_whitelist:
|
||||
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
|
||||
# 检查该私聊是否已在白名单中处理过
|
||||
is_already_processed = any(info[0] == "private" and info[1] == user_id for info in context_group.chat_ids)
|
||||
|
||||
if private_stream_id and not is_already_processed:
|
||||
try:
|
||||
limit = context_group.default_limit
|
||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=private_stream_id, timestamp=time.time(), limit=limit * 4
|
||||
# 记录消息来源的 stream_id 和最新消息的时间戳
|
||||
latest_timestamp = max(msg.get("time", 0) for msg in user_messages)
|
||||
all_user_messages.append(
|
||||
{"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp}
|
||||
)
|
||||
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:]
|
||||
except Exception as e:
|
||||
logger.error(f"S4U模式下获取聊天 {stream_id} 的消息失败: {e}")
|
||||
|
||||
if user_messages:
|
||||
chat_name = await chat_manager.get_stream_name(private_stream_id) or user_id
|
||||
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
formatted_messages, _ = await build_readable_messages_with_id(
|
||||
user_messages, timestamp_mode="relative"
|
||||
)
|
||||
cross_context_messages.append(
|
||||
f'[以下是"{user_name}"在与你的私聊中的近期发言]\n{formatted_messages}'
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户 {user_id} 的私聊消息失败: {e}")
|
||||
# 3. 按最新消息时间排序,并根据 stream_limit 截断
|
||||
all_user_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
|
||||
limited_messages = all_user_messages[: cross_context_config.s4u_stream_limit]
|
||||
|
||||
if not cross_context_messages:
|
||||
# 4. 格式化最终的消息文本
|
||||
cross_context_blocks = []
|
||||
for item in limited_messages:
|
||||
stream_id = item["stream_id"]
|
||||
messages = item["messages"]
|
||||
try:
|
||||
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
|
||||
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
formatted, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_blocks.append(f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted}')
|
||||
except Exception as e:
|
||||
logger.error(f"S4U模式下格式化消息失败 (stream: {stream_id}): {e}")
|
||||
|
||||
if not cross_context_blocks:
|
||||
return ""
|
||||
|
||||
return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_blocks) + "\n"
|
||||
|
||||
|
||||
async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
|
||||
|
||||
Reference in New Issue
Block a user