feat(context): 增强s4u跨上下文模式并重构API
在跨上下文功能中为s4u模式引入`s4u_ignore_whitelist`配置项。当启用时,除了白名单中配置的聊天记录外,还会自动获取目标用户与Bot的私聊记录,以构建更全面的用户画像。 主要变更: - 在 `ContextGroup` 配置中添加 `s4u_ignore_whitelist` 字段。 - 重构 `cross_context_api`,将 `get_context_groups` 更改为 `get_context_group`,使其返回完整的 `ContextGroup` 对象而非仅ID列表,以便于访问新配置。 - 调整 `build_cross_context_s4u` 函数以处理新逻辑,包括获取私聊记录和避免重复处理。 - 更新了配置文件模板以包含新选项的说明和示例。
This commit is contained in:
@@ -1047,8 +1047,8 @@ class Prompt:
|
|||||||
|
|
||||||
from src.plugin_system.apis import cross_context_api
|
from src.plugin_system.apis import cross_context_api
|
||||||
|
|
||||||
other_chat_raw_ids = await cross_context_api.get_context_groups(chat_id)
|
context_group = await cross_context_api.get_context_group(chat_id)
|
||||||
if not other_chat_raw_ids:
|
if not context_group:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
chat_stream = await get_chat_manager().get_stream(chat_id)
|
chat_stream = await get_chat_manager().get_stream(chat_id)
|
||||||
@@ -1056,9 +1056,18 @@ class Prompt:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
if prompt_mode == "normal":
|
if prompt_mode == "normal":
|
||||||
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_raw_ids)
|
current_chat_raw_id = (
|
||||||
|
chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
|
||||||
|
)
|
||||||
|
current_type = "group" if chat_stream.group_info else "private"
|
||||||
|
other_chat_infos = [
|
||||||
|
chat_info
|
||||||
|
for chat_info in context_group.chat_ids
|
||||||
|
if chat_info[:2] != [current_type, str(current_chat_raw_id)]
|
||||||
|
]
|
||||||
|
return await cross_context_api.build_cross_context_normal(chat_stream, other_chat_infos)
|
||||||
elif prompt_mode == "s4u":
|
elif prompt_mode == "s4u":
|
||||||
return await cross_context_api.build_cross_context_s4u(chat_stream, other_chat_raw_ids, target_user_info)
|
return await cross_context_api.build_cross_context_s4u(chat_stream, context_group, target_user_info)
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|||||||
@@ -633,6 +633,9 @@ class ContextGroup(ValidatedConfigBase):
|
|||||||
"""上下文共享组配置"""
|
"""上下文共享组配置"""
|
||||||
|
|
||||||
name: str = Field(..., description="共享组的名称")
|
name: str = Field(..., description="共享组的名称")
|
||||||
|
s4u_ignore_whitelist: bool = Field(
|
||||||
|
default=False, description="在s4u模式下, 是否无视白名单, 获取用户所有私聊消息"
|
||||||
|
)
|
||||||
chat_ids: list[list[str]] = Field(
|
chat_ids: list[list[str]] = Field(
|
||||||
...,
|
...,
|
||||||
description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]',
|
description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]',
|
||||||
|
|||||||
@@ -13,13 +13,14 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
)
|
)
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from src.config.official_configs import ContextGroup
|
||||||
|
|
||||||
logger = get_logger("cross_context_api")
|
logger = get_logger("cross_context_api")
|
||||||
|
|
||||||
|
|
||||||
async def get_context_groups(chat_id: str) -> list[list[str]] | None:
|
async def get_context_group(chat_id: str) -> ContextGroup | None:
|
||||||
"""
|
"""
|
||||||
获取当前聊天所在的共享组的其他聊天ID
|
获取当前聊天所在的共享组
|
||||||
"""
|
"""
|
||||||
current_stream = await get_chat_manager().get_stream(chat_id)
|
current_stream = await get_chat_manager().get_stream(chat_id)
|
||||||
if not current_stream:
|
if not current_stream:
|
||||||
@@ -39,8 +40,7 @@ async def get_context_groups(chat_id: str) -> list[list[str]] | None:
|
|||||||
# 排除maizone专用组
|
# 排除maizone专用组
|
||||||
if group.name == "maizone_context_group":
|
if group.name == "maizone_context_group":
|
||||||
continue
|
continue
|
||||||
# 返回组内其他聊天的 [type, id] 列表
|
return group
|
||||||
return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]]
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -79,51 +79,85 @@ async def build_cross_context_normal(chat_stream: ChatStream, other_chat_infos:
|
|||||||
|
|
||||||
async def build_cross_context_s4u(
|
async def build_cross_context_s4u(
|
||||||
chat_stream: ChatStream,
|
chat_stream: ChatStream,
|
||||||
other_chat_infos: list[list[str]],
|
context_group: ContextGroup,
|
||||||
target_user_info: dict[str, Any] | None,
|
target_user_info: dict[str, Any] | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
构建跨群聊/私聊上下文 (S4U模式)
|
构建跨群聊/私聊上下文 (S4U模式)
|
||||||
"""
|
"""
|
||||||
cross_context_messages = []
|
cross_context_messages = []
|
||||||
if target_user_info:
|
if not target_user_info or not (user_id := target_user_info.get("user_id")):
|
||||||
user_id = target_user_info.get("user_id")
|
return ""
|
||||||
|
|
||||||
if user_id:
|
chat_manager = get_chat_manager()
|
||||||
for chat_info in other_chat_infos:
|
current_chat_raw_id = (
|
||||||
chat_type, chat_raw_id, limit = (
|
chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
|
||||||
chat_info[0],
|
)
|
||||||
chat_info[1],
|
current_type = "group" if chat_stream.group_info else "private"
|
||||||
int(chat_info[2]) if len(chat_info) > 2 else 5,
|
|
||||||
|
other_chat_infos = [
|
||||||
|
chat_info
|
||||||
|
for chat_info in context_group.chat_ids
|
||||||
|
if chat_info[:2] != [current_type, str(current_chat_raw_id)]
|
||||||
|
]
|
||||||
|
|
||||||
|
# 1. 处理在白名单内的聊天
|
||||||
|
for chat_info in other_chat_infos:
|
||||||
|
chat_type, chat_raw_id, limit = (
|
||||||
|
chat_info[0],
|
||||||
|
chat_info[1],
|
||||||
|
int(chat_info[2]) if len(chat_info) > 2 else 5,
|
||||||
|
)
|
||||||
|
is_group = chat_type == "group"
|
||||||
|
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||||
|
if not stream_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||||
|
chat_id=stream_id, timestamp=time.time(), limit=limit * 4
|
||||||
|
)
|
||||||
|
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:]
|
||||||
|
|
||||||
|
if user_messages:
|
||||||
|
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
|
||||||
|
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||||
|
formatted_messages, _ = await build_readable_messages_with_id(
|
||||||
|
user_messages, timestamp_mode="relative"
|
||||||
)
|
)
|
||||||
is_group = chat_type == "group"
|
cross_context_messages.append(
|
||||||
stream_id = get_chat_manager().get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}'
|
||||||
if not stream_id:
|
)
|
||||||
continue
|
except Exception as e:
|
||||||
|
logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}")
|
||||||
|
|
||||||
try:
|
# 2. 如果开启了 s4u_ignore_whitelist,则获取用户与Bot的私聊记录
|
||||||
# S4U模式下,我们获取更多消息以供筛选
|
if context_group.s4u_ignore_whitelist:
|
||||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
|
||||||
chat_id=stream_id,
|
# 检查该私聊是否已在白名单中处理过
|
||||||
timestamp=time.time(),
|
is_already_processed = any(
|
||||||
limit=limit * 4, # 获取4倍limit的消息以供筛选
|
info[0] == "private" and info[1] == user_id for info in context_group.chat_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if private_stream_id and not is_already_processed:
|
||||||
|
try:
|
||||||
|
limit = 5 # 使用默认值
|
||||||
|
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||||
|
chat_id=private_stream_id, timestamp=time.time(), limit=limit * 4
|
||||||
|
)
|
||||||
|
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:]
|
||||||
|
|
||||||
|
if user_messages:
|
||||||
|
chat_name = await chat_manager.get_stream_name(private_stream_id) or user_id
|
||||||
|
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||||
|
formatted_messages, _ = await build_readable_messages_with_id(
|
||||||
|
user_messages, timestamp_mode="relative"
|
||||||
)
|
)
|
||||||
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:]
|
cross_context_messages.append(
|
||||||
|
f'[以下是"{user_name}"在与你的私聊中的近期发言]\n{formatted_messages}'
|
||||||
if user_messages:
|
)
|
||||||
chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
except Exception as e:
|
||||||
user_name = (
|
logger.error(f"获取用户 {user_id} 的私聊消息失败: {e}")
|
||||||
target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
|
||||||
)
|
|
||||||
formatted_messages, _ = await build_readable_messages_with_id(
|
|
||||||
user_messages, timestamp_mode="relative"
|
|
||||||
)
|
|
||||||
cross_context_messages.append(
|
|
||||||
f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}'
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not cross_context_messages:
|
if not cross_context_messages:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "7.2.9"
|
version = "7.2.10"
|
||||||
|
|
||||||
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
#----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
|
||||||
#如果你想要修改配置文件,请递增version的值
|
#如果你想要修改配置文件,请递增version的值
|
||||||
@@ -528,6 +528,11 @@ enable = true
|
|||||||
# limit 是一个可选的整数(但需要以字符串形式写入),用于指定从该聊天流中获取的消息数量,如果未指定,默认为5
|
# limit 是一个可选的整数(但需要以字符串形式写入),用于指定从该聊天流中获取的消息数量,如果未指定,默认为5
|
||||||
[[cross_context.groups]]
|
[[cross_context.groups]]
|
||||||
name = "项目A技术讨论组"
|
name = "项目A技术讨论组"
|
||||||
|
# s4u_ignore_whitelist: (可选, 默认为 false)
|
||||||
|
# 如果设置为 true, 并且 prompt_mode 为 "s4u",
|
||||||
|
# Bot将获取目标用户在所有与Bot的私聊中的消息, 即使该私聊没有被明确配置在下面的 chat_ids 中。
|
||||||
|
# 这有助于构建更完整的用户画像, 但可能会增加token消耗。
|
||||||
|
s4u_ignore_whitelist = false
|
||||||
chat_ids = [
|
chat_ids = [
|
||||||
["group", "169850076", "10"], # 假设这是“开发群”的ID, 从这个群里拿10条消息
|
["group", "169850076", "10"], # 假设这是“开发群”的ID, 从这个群里拿10条消息
|
||||||
["group", "1025509724", "5"], # 假设这是“产品群”的ID,拿5条
|
["group", "1025509724", "5"], # 假设这是“产品群”的ID,拿5条
|
||||||
|
|||||||
Reference in New Issue
Block a user