feat(cross_context): 新增以用户为中心的跨上下文检索功能

引入了一种全新的“用户中心”跨上下文检索模式,以替代并废弃了原有的固定共享组模式。

当回复特定用户时,系统现在可以自动从该用户参与的其他聊天(包括私聊和群聊)中检索其最近的发言记录,从而为大语言模型提供更丰富、更具个性化的上下文,以生成更相关的回复。

此功能可通过配置进行精细化控制,支持“全局启用”、“白名单”和“禁用”三种模式,并可设置检索的消息数量和聊天流数量上限。

此外,本次更新还包含一些健壮性修复:
- 修正了事件管理器返回结果可能为None时导致属性错误的潜在问题。
- 增强了对消息内容和用户昵称等可能为空值的处理。

BREAKING CHANGE: `cross_context` 的配置结构已完全重构。原有的 `groups` 配置项已被废弃。请用户根据新的 `bot_config_template.toml` 文件更新配置,迁移到新的 `user_centric_retrieval_mode`、`whitelist_chats` 和 `blacklist_chats` 格式。
This commit is contained in:
tt-P607
2025-10-10 19:08:30 +08:00
committed by Windpicker-owo
parent 379b3b3bac
commit cde8b25d98
4 changed files with 176 additions and 157 deletions

View File

@@ -33,14 +33,15 @@ async def get_context_groups(chat_id: str) -> list[list[str]] | None:
current_chat_raw_id = current_stream.user_info.user_id
current_type = "group" if is_group else "private"
for group in global_config.cross_context.groups:
# 检查当前聊天的ID和类型是否在组的chat_ids中
if [current_type, str(current_chat_raw_id)] in group.chat_ids:
# 排除maizone专用组
if group.name == "maizone_context_group":
continue
# 返回组内其他聊天的 [type, id] 列表
return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]]
# This feature is deprecated
# for group in global_config.cross_context.groups:
# # 检查当前聊天的ID和类型是否在组的chat_ids
# if [current_type, str(current_chat_raw_id)] in group.chat_ids:
# # 排除maizone专用组
# if group.name == "maizone_context_group":
# continue
# # 返回组内其他聊天的 [type, id] 列表
# return [chat_info for chat_info in group.chat_ids if chat_info != [current_type, str(current_chat_raw_id)]]
return None
@@ -124,140 +125,86 @@ async def build_cross_context_s4u(
return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_chat_history_by_group_name(group_name: str) -> str:
"""
根据互通组名字获取聊天记录
"""
target_group = None
for group in global_config.cross_context.groups:
if group.name == group_name:
target_group = group
break
if not target_group:
return f"找不到名为 {group_name} 的互通组。"
if not target_group.chat_ids:
return f"互通组 {group_name} 中没有配置任何聊天。"
chat_infos = target_group.chat_ids
chat_manager = get_chat_manager()
cross_context_messages = []
for chat_type, chat_raw_id in chat_infos:
is_group = chat_type == "group"
found_stream = None
for stream in chat_manager.streams.values():
if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id:
found_stream = stream
break
else: # private
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream
break
if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
continue
stream_id = found_stream.stream_id
try:
messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=5, # 可配置
)
if messages:
chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
continue
if not cross_context_messages:
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_intercom_group_context_by_name(
group_name: str, days: int = 3, limit_per_chat: int = 20, total_limit: int = 100
async def get_user_centric_context(
user_id: str, platform: str, limit: int, exclude_chat_id: str | None = None
) -> str | None:
"""
根据互通组的名称,构建该组的聊天上下文
获取以用户为中心的全域聊天记录
Args:
group_name: 互通组的名称
days: 获取过去多少天的消息
limit_per_chat: 每个聊天最多获取的消息条数
total_limit: 返回的总消息条数上限
user_id: 目标用户的ID
platform: 用户所在的平台
limit: 每个聊天获取的最大消息数量
exclude_chat_id: 需要排除的当前聊天ID
Returns:
如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。
构建好的上下文信息字符串,如果没有找到则返回None。
"""
cross_context_config = global_config.cross_context
if not (cross_context_config and cross_context_config.enable):
return None
target_group = None
for group in cross_context_config.groups:
if group.name == group_name:
target_group = group
break
if not target_group:
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
return None
chat_manager = get_chat_manager()
user_messages_map = {}
all_messages = []
end_time = time.time()
start_time = end_time - (days * 24 * 60 * 60)
# 遍历所有相关的聊天流
streams_to_search = []
private_stream = None
group_streams = []
for chat_type, chat_raw_id in target_group.chat_ids:
is_group = chat_type == "group"
found_stream = None
# 采用与 get_chat_history_by_group_name 相同的健壮的 stream 查找方式
for stream in chat_manager.streams.values():
if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id:
found_stream = stream
break
else: # private
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream
break
if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
for stream in chat_manager.streams.values():
if stream.stream_id == exclude_chat_id:
continue
stream_id = found_stream.stream_id
messages = await get_raw_msg_by_timestamp_with_chat(
chat_id=stream_id,
timestamp_start=start_time,
timestamp_end=end_time,
limit=limit_per_chat,
limit_mode="latest",
)
all_messages.extend(messages)
is_group = stream.group_info is not None
if is_group:
# 对于群聊,检查用户是否是成员之一 (通过消息记录判断)
group_streams.append(stream)
else:
# 对于私聊,检查是否是与目标用户的私聊
if stream.user_info and stream.user_info.user_id == user_id:
private_stream = stream
if not all_messages:
# 优先添加私聊流
if private_stream:
streams_to_search.append(private_stream)
# 按最近活跃时间对群聊流进行排序
group_streams.sort(key=lambda s: s.last_active_time, reverse=True)
streams_to_search.extend(group_streams)
# 应用聊天流数量限制
stream_limit = global_config.cross_context.user_centric_retrieval_stream_limit
if stream_limit > 0:
streams_to_search = streams_to_search[:stream_limit]
for stream in streams_to_search:
try:
messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream.stream_id,
timestamp=time.time(),
limit=limit * 5, # 获取更多消息以供筛选
)
user_messages = [msg for msg in messages if msg.get("user_id") == user_id][-limit:]
if user_messages:
chat_name = await chat_manager.get_stream_name(stream.stream_id) or stream.stream_id
if chat_name not in user_messages_map:
user_messages_map[chat_name] = []
user_messages_map[chat_name].extend(user_messages)
except Exception as e:
logger.error(f"获取用户 {user_id} 在聊天 {stream.stream_id} 的消息失败: {e}")
continue
if not user_messages_map:
return None
# 按时间戳对所有消息进行排序
all_messages.sort(key=lambda x: x.get("time", 0))
# 构建最终的上下文字符串
cross_context_parts = []
for chat_name, messages in user_messages_map.items():
# 按时间戳对消息进行排序
messages.sort(key=lambda x: x.get("time", 0))
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_parts.append(f'[以下是该用户在"{chat_name}"的近期发言]\n{formatted_messages}')
# 限制总消息数
if len(all_messages) > total_limit:
all_messages = all_messages[-total_limit:]
if not cross_context_parts:
return None
# build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list)
formatted_string, _ = await build_readable_messages_with_id(all_messages)
return formatted_string
return "### 该用户在其他地方的聊天记录\n" + "\n\n".join(cross_context_parts) + "\n"