refactor(cross_context): 重构互通组上下文获取逻辑

重构并简化了互通组上下文的获取函数 `get_intercom_group_context`。旧的 `get_chat_history_by_group_name` 和 `get_intercom_group_context_by_name` 函数被合并和优化。

主要变更:
- 移除了冗余的 `get_chat_history_by_group_name` 函数。
- 将 `get_intercom_group_context_by_name` 重命名为 `get_intercom_group_context`,并简化了其参数。
- 增加了对黑名单模式的支持,现在可以正确地从所有聊天中排除指定会话。
- 优化了消息获取和格式化流程,现在按聊天会话分块返回消息,而不是将所有消息混合在一起排序,提高了上下文的可读性。
- 清理了代码格式和移除了未使用的导入。
This commit is contained in:
minecraft1024a
2025-10-11 20:56:15 +08:00
parent 3040000531
commit 1fb01ef8a5
2 changed files with 78 additions and 113 deletions

View File

@@ -9,7 +9,6 @@ from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id, build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat,
) )
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@@ -86,7 +85,11 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
# 遍历待获取列表,抓取并格式化消息 # 遍历待获取列表,抓取并格式化消息
for chat_info in chat_infos_to_fetch: for chat_info in chat_infos_to_fetch:
chat_type, chat_raw_id, limit_str = chat_info[0], chat_info[1], chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit) chat_type, chat_raw_id, limit_str = (
chat_info[0],
chat_info[1],
chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit),
)
limit = int(limit_str) limit = int(limit_str)
is_group = chat_type == "group" is_group = chat_type == "group"
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group) stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
@@ -126,9 +129,7 @@ async def build_cross_context_s4u(
return "" return ""
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
current_chat_raw_id = ( current_chat_raw_id = chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
)
current_type = "group" if chat_stream.group_info else "private" current_type = "group" if chat_stream.group_info else "private"
# 根据模式(黑名单/白名单)决定需要处理哪些聊天 # 根据模式(黑名单/白名单)决定需要处理哪些聊天
@@ -184,12 +185,8 @@ async def build_cross_context_s4u(
if user_messages: if user_messages:
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
formatted_messages, _ = await build_readable_messages_with_id( formatted_messages, _ = await build_readable_messages_with_id(user_messages, timestamp_mode="relative")
user_messages, timestamp_mode="relative" cross_context_messages.append(f'[以下是"{user_name}""{chat_name}"的近期发言]\n{formatted_messages}')
)
cross_context_messages.append(
f'[以下是"{user_name}""{chat_name}"的近期发言]\n{formatted_messages}'
)
except Exception as e: except Exception as e:
logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}") logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}")
@@ -197,9 +194,7 @@ async def build_cross_context_s4u(
if context_group.s4u_ignore_whitelist: if context_group.s4u_ignore_whitelist:
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False) private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
# 检查该私聊是否已在白名单中处理过 # 检查该私聊是否已在白名单中处理过
is_already_processed = any( is_already_processed = any(info[0] == "private" and info[1] == user_id for info in context_group.chat_ids)
info[0] == "private" and info[1] == user_id for info in context_group.chat_ids
)
if private_stream_id and not is_already_processed: if private_stream_id and not is_already_processed:
try: try:
@@ -227,91 +222,24 @@ async def build_cross_context_s4u(
return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n" return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_chat_history_by_group_name(group_name: str) -> str: async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
"""
根据互通组名字获取聊天记录
"""
target_group = None
for group in global_config.cross_context.groups:
if group.name == group_name:
target_group = group
break
if not target_group:
return f"找不到名为 {group_name} 的互通组。"
if not target_group.chat_ids:
return f"互通组 {group_name} 中没有配置任何聊天。"
chat_infos = target_group.chat_ids
chat_manager = get_chat_manager()
cross_context_messages = []
for chat_info in chat_infos:
chat_type, chat_raw_id, limit = chat_info[0], chat_info[1], int(chat_info[2]) if len(chat_info) > 2 else 5
is_group = chat_type == "group"
found_stream = None
for stream in chat_manager.streams.values():
if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id:
found_stream = stream
break
else: # private
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream
break
if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
continue
stream_id = found_stream.stream_id
try:
messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=limit,
)
if messages:
chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
continue
if not cross_context_messages:
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
async def get_intercom_group_context_by_name(
group_name: str, days: int = 3, limit_per_chat: int = 20, total_limit: int = 100
) -> str | None:
""" """
根据互通组的名称,构建该组的聊天上下文。 根据互通组的名称,构建该组的聊天上下文。
支持黑白名单模式,并以分块形式返回每个聊天的消息。
Args: Args:
group_name: 互通组的名称。 group_name: 互通组的名称。
days: 获取过去多少天的消息。
limit_per_chat: 每个聊天最多获取的消息条数。 limit_per_chat: 每个聊天最多获取的消息条数。
total_limit: 返回的总消息条数上限。 total_limit: 返回的总消息条数上限。
Returns: Returns:
如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。 如果找到匹配的组并获取到消息,则返回一个包含聊天记录的字符串;否则返回 None。
""" """
cross_context_config = global_config.cross_context cross_context_config = global_config.cross_context
if not (cross_context_config and cross_context_config.enable): if not (cross_context_config and cross_context_config.enable):
return None return None
target_group = None target_group = next((g for g in cross_context_config.groups if g.name == group_name), None)
for group in cross_context_config.groups:
if group.name == group_name:
target_group = group
break
if not target_group: if not target_group:
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。") logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
@@ -319,15 +247,34 @@ async def get_intercom_group_context_by_name(
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
all_messages = [] # 1. 根据黑白名单模式确定要处理的聊天列表
end_time = time.time() chat_infos_to_fetch = []
start_time = end_time - (days * 24 * 60 * 60) if target_group.mode == "blacklist":
blacklisted_ids = {tuple(info[:2]) for info in target_group.chat_ids}
for stream in chat_manager.streams.values():
is_group = stream.group_info is not None
chat_type = "group" if is_group else "private"
for chat_type, chat_raw_id in target_group.chat_ids: if is_group and stream.group_info:
raw_id = stream.group_info.group_id
elif not is_group and stream.user_info:
raw_id = stream.user_info.user_id
else:
continue
if (chat_type, str(raw_id)) not in blacklisted_ids:
chat_infos_to_fetch.append([chat_type, str(raw_id)])
else: # whitelist mode
chat_infos_to_fetch = target_group.chat_ids
# 2. 获取所有相关消息
all_messages = []
for chat_info in chat_infos_to_fetch:
chat_type, chat_raw_id = chat_info[0], chat_info[1]
is_group = chat_type == "group" is_group = chat_type == "group"
# 查找 stream
found_stream = None found_stream = None
# 采用与 get_chat_history_by_group_name 相同的健壮的 stream 查找方式
for stream in chat_manager.streams.values(): for stream in chat_manager.streams.values():
if is_group: if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id: if stream.group_info and stream.group_info.group_id == chat_raw_id:
@@ -337,31 +284,49 @@ async def get_intercom_group_context_by_name(
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info: if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream found_stream = stream
break break
if not found_stream: if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。") logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
continue continue
stream_id = found_stream.stream_id stream_id = found_stream.stream_id
messages = await get_raw_msg_by_timestamp_with_chat(
try:
messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id, chat_id=stream_id,
timestamp_start=start_time, timestamp=time.time(),
timestamp_end=end_time,
limit=limit_per_chat, limit=limit_per_chat,
limit_mode="latest",
) )
if messages:
# 为每条消息附加 stream_id 以便后续分组
for msg in messages:
msg["_stream_id"] = stream_id
all_messages.extend(messages) all_messages.extend(messages)
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
if not all_messages: if not all_messages:
return None return None
# 按时间戳对所有消息进行排序 # 3. 应用总数限制
all_messages.sort(key=lambda x: x.get("time", 0)) all_messages.sort(key=lambda x: x.get("time", 0))
# 限制总消息数
if len(all_messages) > total_limit: if len(all_messages) > total_limit:
all_messages = all_messages[-total_limit:] all_messages = all_messages[-total_limit:]
# build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list) # 4. 按聊天分组并格式化
formatted_string, _ = await build_readable_messages_with_id(all_messages) messages_by_stream = {}
return formatted_string for msg in all_messages:
stream_id = msg.get("_stream_id")
if stream_id not in messages_by_stream:
messages_by_stream[stream_id] = []
messages_by_stream[stream_id].append(msg)
cross_context_messages = []
for stream_id, messages in messages_by_stream.items():
if messages:
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
if not cross_context_messages:
return None
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"

View File

@@ -190,7 +190,7 @@ class QZoneService:
获取互通组的聊天上下文。 获取互通组的聊天上下文。
""" """
# 实际的逻辑已迁移到 cross_context_api # 实际的逻辑已迁移到 cross_context_api
return await cross_context_api.get_intercom_group_context_by_name("maizone_context_group") return await cross_context_api.get_intercom_group_context("maizone_context_group")
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict): async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
"""处理对自己说说的评论并进行回复""" """处理对自己说说的评论并进行回复"""