refactor(cross_context): 重构互通组上下文获取逻辑

重构并简化了互通组上下文的获取函数 `get_intercom_group_context`。旧的 `get_chat_history_by_group_name` 和 `get_intercom_group_context_by_name` 函数被合并和优化。

主要变更:
- 移除了冗余的 `get_chat_history_by_group_name` 函数。
- 将 `get_intercom_group_context_by_name` 重命名为 `get_intercom_group_context`,并简化了其参数:移除了 `days` 时间窗口参数,改为直接获取每个聊天最近的 `limit_per_chat` 条消息。
- 增加了对黑名单模式的支持,现在可以正确地从所有聊天中排除指定会话。
- 优化了消息获取和格式化流程:仍先按时间排序并应用 `total_limit` 总数限制,随后按聊天会话分块格式化返回,而不是将所有消息混合成一整段输出,提高了上下文的可读性。
- 清理了代码格式,并移除了未使用的导入(`get_raw_msg_by_timestamp_with_chat`)。
This commit is contained in:
minecraft1024a
2025-10-11 20:56:15 +08:00
parent 3040000531
commit 1fb01ef8a5
2 changed files with 78 additions and 113 deletions

View File

@@ -9,7 +9,6 @@ from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat,
)
from src.common.logger import get_logger
from src.config.config import global_config
@@ -60,7 +59,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
"""
cross_context_messages = []
chat_manager = get_chat_manager()
chat_infos_to_fetch = []
if context_group.mode == "blacklist":
# 黑名单模式:获取所有聊天,并排除在 chat_ids 中定义过的聊天
@@ -68,7 +67,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
for stream_id, stream in chat_manager.streams.items():
is_group = stream.group_info is not None
chat_type = "group" if is_group else "private"
# 安全地获取 raw_id
if is_group and stream.group_info:
raw_id = stream.group_info.group_id
@@ -86,7 +85,11 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
# 遍历待获取列表,抓取并格式化消息
for chat_info in chat_infos_to_fetch:
chat_type, chat_raw_id, limit_str = chat_info[0], chat_info[1], chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit)
chat_type, chat_raw_id, limit_str = (
chat_info[0],
chat_info[1],
chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit),
)
limit = int(limit_str)
is_group = chat_type == "group"
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
@@ -126,9 +129,7 @@ async def build_cross_context_s4u(
return ""
chat_manager = get_chat_manager()
current_chat_raw_id = (
chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
)
current_chat_raw_id = chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
current_type = "group" if chat_stream.group_info else "private"
# 根据模式(黑名单/白名单)决定需要处理哪些聊天
@@ -139,17 +140,17 @@ async def build_cross_context_s4u(
for stream_id, stream in chat_manager.streams.items():
if stream_id == chat_stream.stream_id:
continue # 排除当前聊天
is_group = stream.group_info is not None
chat_type = "group" if is_group else "private"
# 安全地获取 raw_id
if is_group and stream.group_info:
raw_id = stream.group_info.group_id
elif not is_group and stream.user_info:
raw_id = stream.user_info.user_id
else:
continue # 如果缺少关键信息则跳过
continue # 如果缺少关键信息则跳过
# 如果不在黑名单中,则加入处理列表
if (chat_type, str(raw_id)) not in blacklisted_ids:
@@ -184,12 +185,8 @@ async def build_cross_context_s4u(
if user_messages:
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
formatted_messages, _ = await build_readable_messages_with_id(
user_messages, timestamp_mode="relative"
)
cross_context_messages.append(
f'[以下是"{user_name}""{chat_name}"的近期发言]\n{formatted_messages}'
)
formatted_messages, _ = await build_readable_messages_with_id(user_messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是"{user_name}""{chat_name}"的近期发言]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}")
@@ -197,9 +194,7 @@ async def build_cross_context_s4u(
if context_group.s4u_ignore_whitelist:
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
# 检查该私聊是否已在白名单中处理过
is_already_processed = any(
info[0] == "private" and info[1] == user_id for info in context_group.chat_ids
)
is_already_processed = any(info[0] == "private" and info[1] == user_id for info in context_group.chat_ids)
if private_stream_id and not is_already_processed:
try:
@@ -227,91 +222,24 @@ async def build_cross_context_s4u(
return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_chat_history_by_group_name(group_name: str) -> str:
    """
    Build a readable chat-history digest for the intercom group named ``group_name``.

    Each entry in the group's ``chat_ids`` is read as
    ``[chat_type, chat_raw_id, optional_limit]``; when the third element is
    absent, 5 messages are fetched for that chat. For every configured chat
    that matches a loaded stream, the most recent messages before "now" are
    fetched and formatted as one section.

    Args:
        group_name: Name of the intercom group to look up in
            ``global_config.cross_context.groups``.

    Returns:
        The formatted history, or a human-readable message explaining why no
        history could be produced (unknown group, no chats configured, or no
        messages retrieved).
    """
    target_group = next(
        (group for group in global_config.cross_context.groups if group.name == group_name),
        None,
    )
    if not target_group:
        return f"找不到名为 {group_name} 的互通组。"
    if not target_group.chat_ids:
        return f"互通组 {group_name} 中没有配置任何聊天。"

    chat_manager = get_chat_manager()
    cross_context_messages = []

    for chat_info in target_group.chat_ids:
        chat_type, chat_raw_id = chat_info[0], chat_info[1]
        limit = int(chat_info[2]) if len(chat_info) > 2 else 5
        is_group = chat_type == "group"
        # Compare ids as strings: elsewhere in this file raw ids are wrapped in
        # str() before being matched against configured ids, so a raw `==`
        # here could silently miss a stream whose id has a different type.
        wanted_id = str(chat_raw_id)

        found_stream = None
        for stream in chat_manager.streams.values():
            if is_group:
                if stream.group_info and str(stream.group_info.group_id) == wanted_id:
                    found_stream = stream
                    break
            elif stream.user_info and not stream.group_info and str(stream.user_info.user_id) == wanted_id:
                # A private chat is a stream with user info and no group info.
                found_stream = stream
                break

        if not found_stream:
            logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
            continue

        stream_id = found_stream.stream_id
        try:
            messages = await get_raw_msg_before_timestamp_with_chat(
                chat_id=stream_id,
                timestamp=time.time(),
                limit=limit,
            )
            if messages:
                chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
                formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
                cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
        except Exception as e:
            # Best effort: one failing chat must not prevent the others from being returned.
            logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
            continue

    if not cross_context_messages:
        return f"无法从互通组 {group_name} 中获取任何聊天记录。"

    return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_intercom_group_context_by_name(
group_name: str, days: int = 3, limit_per_chat: int = 20, total_limit: int = 100
) -> str | None:
async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
"""
根据互通组的名称,构建该组的聊天上下文。
支持黑白名单模式,并以分块形式返回每个聊天的消息。
Args:
group_name: 互通组的名称。
days: 获取过去多少天的消息。
limit_per_chat: 每个聊天最多获取的消息条数。
total_limit: 返回的总消息条数上限。
Returns:
如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。
如果找到匹配的组并获取到消息,则返回一个包含聊天记录的字符串;否则返回 None。
"""
cross_context_config = global_config.cross_context
if not (cross_context_config and cross_context_config.enable):
return None
target_group = None
for group in cross_context_config.groups:
if group.name == group_name:
target_group = group
break
target_group = next((g for g in cross_context_config.groups if g.name == group_name), None)
if not target_group:
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
@@ -319,15 +247,34 @@ async def get_intercom_group_context_by_name(
chat_manager = get_chat_manager()
all_messages = []
end_time = time.time()
start_time = end_time - (days * 24 * 60 * 60)
# 1. 根据黑白名单模式确定要处理的聊天列表
chat_infos_to_fetch = []
if target_group.mode == "blacklist":
blacklisted_ids = {tuple(info[:2]) for info in target_group.chat_ids}
for stream in chat_manager.streams.values():
is_group = stream.group_info is not None
chat_type = "group" if is_group else "private"
for chat_type, chat_raw_id in target_group.chat_ids:
if is_group and stream.group_info:
raw_id = stream.group_info.group_id
elif not is_group and stream.user_info:
raw_id = stream.user_info.user_id
else:
continue
if (chat_type, str(raw_id)) not in blacklisted_ids:
chat_infos_to_fetch.append([chat_type, str(raw_id)])
else: # whitelist mode
chat_infos_to_fetch = target_group.chat_ids
# 2. 获取所有相关消息
all_messages = []
for chat_info in chat_infos_to_fetch:
chat_type, chat_raw_id = chat_info[0], chat_info[1]
is_group = chat_type == "group"
# 查找 stream
found_stream = None
# 采用与 get_chat_history_by_group_name 相同的健壮的 stream 查找方式
for stream in chat_manager.streams.values():
if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id:
@@ -336,32 +283,50 @@ async def get_intercom_group_context_by_name(
else: # private
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream
break
break
if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
continue
stream_id = found_stream.stream_id
messages = await get_raw_msg_by_timestamp_with_chat(
chat_id=stream_id,
timestamp_start=start_time,
timestamp_end=end_time,
limit=limit_per_chat,
limit_mode="latest",
)
all_messages.extend(messages)
try:
messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=limit_per_chat,
)
if messages:
# 为每条消息附加 stream_id 以便后续分组
for msg in messages:
msg["_stream_id"] = stream_id
all_messages.extend(messages)
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
if not all_messages:
return None
# 按时间戳对所有消息进行排序
# 3. 应用总数限制
all_messages.sort(key=lambda x: x.get("time", 0))
# 限制总消息数
if len(all_messages) > total_limit:
all_messages = all_messages[-total_limit:]
# build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list)
formatted_string, _ = await build_readable_messages_with_id(all_messages)
return formatted_string
# 4. 按聊天分组并格式化
messages_by_stream = {}
for msg in all_messages:
stream_id = msg.get("_stream_id")
if stream_id not in messages_by_stream:
messages_by_stream[stream_id] = []
messages_by_stream[stream_id].append(msg)
cross_context_messages = []
for stream_id, messages in messages_by_stream.items():
if messages:
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
if not cross_context_messages:
return None
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"

View File

@@ -190,7 +190,7 @@ class QZoneService:
获取互通组的聊天上下文。
"""
# 实际的逻辑已迁移到 cross_context_api
return await cross_context_api.get_intercom_group_context_by_name("maizone_context_group")
return await cross_context_api.get_intercom_group_context("maizone_context_group")
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
"""处理对自己说说的评论并进行回复"""