refactor(cross_context): 重构互通组上下文获取逻辑
重构并简化了互通组上下文的获取函数 `get_intercom_group_context`。旧的 `get_chat_history_by_group_name` 和 `get_intercom_group_context_by_name` 函数被合并和优化。 主要变更: - 移除了冗余的 `get_chat_history_by_group_name` 函数。 - 将 `get_intercom_group_context_by_name` 重命名为 `get_intercom_group_context`,并简化了其参数。 - 增加了对黑名单模式的支持,现在可以正确地从所有聊天中排除指定会话。 - 优化了消息获取和格式化流程,现在按聊天会话分块返回消息,而不是将所有消息混合在一起排序,提高了上下文的可读性。 - 清理了代码格式和移除了未使用的导入。
This commit is contained in:
@@ -9,7 +9,6 @@ from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -60,7 +59,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
|
||||
"""
|
||||
cross_context_messages = []
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
|
||||
chat_infos_to_fetch = []
|
||||
if context_group.mode == "blacklist":
|
||||
# 黑名单模式:获取所有聊天,并排除在 chat_ids 中定义过的聊天
|
||||
@@ -68,7 +67,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
is_group = stream.group_info is not None
|
||||
chat_type = "group" if is_group else "private"
|
||||
|
||||
|
||||
# 安全地获取 raw_id
|
||||
if is_group and stream.group_info:
|
||||
raw_id = stream.group_info.group_id
|
||||
@@ -86,7 +85,11 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
|
||||
|
||||
# 遍历待获取列表,抓取并格式化消息
|
||||
for chat_info in chat_infos_to_fetch:
|
||||
chat_type, chat_raw_id, limit_str = chat_info[0], chat_info[1], chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit)
|
||||
chat_type, chat_raw_id, limit_str = (
|
||||
chat_info[0],
|
||||
chat_info[1],
|
||||
chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit),
|
||||
)
|
||||
limit = int(limit_str)
|
||||
is_group = chat_type == "group"
|
||||
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||
@@ -126,9 +129,7 @@ async def build_cross_context_s4u(
|
||||
return ""
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
current_chat_raw_id = (
|
||||
chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
|
||||
)
|
||||
current_chat_raw_id = chat_stream.group_info.group_id if chat_stream.group_info else chat_stream.user_info.user_id
|
||||
current_type = "group" if chat_stream.group_info else "private"
|
||||
|
||||
# 根据模式(黑名单/白名单)决定需要处理哪些聊天
|
||||
@@ -139,17 +140,17 @@ async def build_cross_context_s4u(
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
if stream_id == chat_stream.stream_id:
|
||||
continue # 排除当前聊天
|
||||
|
||||
|
||||
is_group = stream.group_info is not None
|
||||
chat_type = "group" if is_group else "private"
|
||||
|
||||
|
||||
# 安全地获取 raw_id
|
||||
if is_group and stream.group_info:
|
||||
raw_id = stream.group_info.group_id
|
||||
elif not is_group and stream.user_info:
|
||||
raw_id = stream.user_info.user_id
|
||||
else:
|
||||
continue # 如果缺少关键信息则跳过
|
||||
continue # 如果缺少关键信息则跳过
|
||||
|
||||
# 如果不在黑名单中,则加入处理列表
|
||||
if (chat_type, str(raw_id)) not in blacklisted_ids:
|
||||
@@ -184,12 +185,8 @@ async def build_cross_context_s4u(
|
||||
if user_messages:
|
||||
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
|
||||
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
formatted_messages, _ = await build_readable_messages_with_id(
|
||||
user_messages, timestamp_mode="relative"
|
||||
)
|
||||
cross_context_messages.append(
|
||||
f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}'
|
||||
)
|
||||
formatted_messages, _ = await build_readable_messages_with_id(user_messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户 {user_id} 在聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
|
||||
@@ -197,9 +194,7 @@ async def build_cross_context_s4u(
|
||||
if context_group.s4u_ignore_whitelist:
|
||||
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
|
||||
# 检查该私聊是否已在白名单中处理过
|
||||
is_already_processed = any(
|
||||
info[0] == "private" and info[1] == user_id for info in context_group.chat_ids
|
||||
)
|
||||
is_already_processed = any(info[0] == "private" and info[1] == user_id for info in context_group.chat_ids)
|
||||
|
||||
if private_stream_id and not is_already_processed:
|
||||
try:
|
||||
@@ -227,91 +222,24 @@ async def build_cross_context_s4u(
|
||||
return "### 其他群聊中的聊天记录\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
|
||||
async def get_chat_history_by_group_name(group_name: str) -> str:
|
||||
"""
|
||||
根据互通组名字获取聊天记录
|
||||
"""
|
||||
target_group = None
|
||||
for group in global_config.cross_context.groups:
|
||||
if group.name == group_name:
|
||||
target_group = group
|
||||
break
|
||||
|
||||
if not target_group:
|
||||
return f"找不到名为 {group_name} 的互通组。"
|
||||
|
||||
if not target_group.chat_ids:
|
||||
return f"互通组 {group_name} 中没有配置任何聊天。"
|
||||
|
||||
chat_infos = target_group.chat_ids
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
cross_context_messages = []
|
||||
for chat_info in chat_infos:
|
||||
chat_type, chat_raw_id, limit = chat_info[0], chat_info[1], int(chat_info[2]) if len(chat_info) > 2 else 5
|
||||
is_group = chat_type == "group"
|
||||
|
||||
found_stream = None
|
||||
for stream in chat_manager.streams.values():
|
||||
if is_group:
|
||||
if stream.group_info and stream.group_info.group_id == chat_raw_id:
|
||||
found_stream = stream
|
||||
break
|
||||
else: # private
|
||||
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
||||
found_stream = stream
|
||||
break
|
||||
|
||||
if not found_stream:
|
||||
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
||||
continue
|
||||
|
||||
stream_id = found_stream.stream_id
|
||||
|
||||
try:
|
||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=limit,
|
||||
)
|
||||
if messages:
|
||||
chat_name = await get_chat_manager().get_stream_name(stream_id) or chat_raw_id
|
||||
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
continue
|
||||
|
||||
if not cross_context_messages:
|
||||
return f"无法从互通组 {group_name} 中获取任何聊天记录。"
|
||||
|
||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
|
||||
async def get_intercom_group_context_by_name(
|
||||
group_name: str, days: int = 3, limit_per_chat: int = 20, total_limit: int = 100
|
||||
) -> str | None:
|
||||
async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
|
||||
"""
|
||||
根据互通组的名称,构建该组的聊天上下文。
|
||||
支持黑白名单模式,并以分块形式返回每个聊天的消息。
|
||||
|
||||
Args:
|
||||
group_name: 互通组的名称。
|
||||
days: 获取过去多少天的消息。
|
||||
limit_per_chat: 每个聊天最多获取的消息条数。
|
||||
total_limit: 返回的总消息条数上限。
|
||||
|
||||
Returns:
|
||||
如果找到匹配的组,则返回一个包含聊天记录的字符串;否则返回 None。
|
||||
如果找到匹配的组并获取到消息,则返回一个包含聊天记录的字符串;否则返回 None。
|
||||
"""
|
||||
cross_context_config = global_config.cross_context
|
||||
if not (cross_context_config and cross_context_config.enable):
|
||||
return None
|
||||
|
||||
target_group = None
|
||||
for group in cross_context_config.groups:
|
||||
if group.name == group_name:
|
||||
target_group = group
|
||||
break
|
||||
target_group = next((g for g in cross_context_config.groups if g.name == group_name), None)
|
||||
|
||||
if not target_group:
|
||||
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
|
||||
@@ -319,15 +247,34 @@ async def get_intercom_group_context_by_name(
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
all_messages = []
|
||||
end_time = time.time()
|
||||
start_time = end_time - (days * 24 * 60 * 60)
|
||||
# 1. 根据黑白名单模式确定要处理的聊天列表
|
||||
chat_infos_to_fetch = []
|
||||
if target_group.mode == "blacklist":
|
||||
blacklisted_ids = {tuple(info[:2]) for info in target_group.chat_ids}
|
||||
for stream in chat_manager.streams.values():
|
||||
is_group = stream.group_info is not None
|
||||
chat_type = "group" if is_group else "private"
|
||||
|
||||
for chat_type, chat_raw_id in target_group.chat_ids:
|
||||
if is_group and stream.group_info:
|
||||
raw_id = stream.group_info.group_id
|
||||
elif not is_group and stream.user_info:
|
||||
raw_id = stream.user_info.user_id
|
||||
else:
|
||||
continue
|
||||
|
||||
if (chat_type, str(raw_id)) not in blacklisted_ids:
|
||||
chat_infos_to_fetch.append([chat_type, str(raw_id)])
|
||||
else: # whitelist mode
|
||||
chat_infos_to_fetch = target_group.chat_ids
|
||||
|
||||
# 2. 获取所有相关消息
|
||||
all_messages = []
|
||||
for chat_info in chat_infos_to_fetch:
|
||||
chat_type, chat_raw_id = chat_info[0], chat_info[1]
|
||||
is_group = chat_type == "group"
|
||||
|
||||
# 查找 stream
|
||||
found_stream = None
|
||||
# 采用与 get_chat_history_by_group_name 相同的健壮的 stream 查找方式
|
||||
for stream in chat_manager.streams.values():
|
||||
if is_group:
|
||||
if stream.group_info and stream.group_info.group_id == chat_raw_id:
|
||||
@@ -336,32 +283,50 @@ async def get_intercom_group_context_by_name(
|
||||
else: # private
|
||||
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
||||
found_stream = stream
|
||||
break
|
||||
|
||||
break
|
||||
if not found_stream:
|
||||
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
||||
continue
|
||||
|
||||
stream_id = found_stream.stream_id
|
||||
messages = await get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp_start=start_time,
|
||||
timestamp_end=end_time,
|
||||
limit=limit_per_chat,
|
||||
limit_mode="latest",
|
||||
)
|
||||
all_messages.extend(messages)
|
||||
|
||||
try:
|
||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=limit_per_chat,
|
||||
)
|
||||
if messages:
|
||||
# 为每条消息附加 stream_id 以便后续分组
|
||||
for msg in messages:
|
||||
msg["_stream_id"] = stream_id
|
||||
all_messages.extend(messages)
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
|
||||
if not all_messages:
|
||||
return None
|
||||
|
||||
# 按时间戳对所有消息进行排序
|
||||
# 3. 应用总数限制
|
||||
all_messages.sort(key=lambda x: x.get("time", 0))
|
||||
|
||||
# 限制总消息数
|
||||
if len(all_messages) > total_limit:
|
||||
all_messages = all_messages[-total_limit:]
|
||||
|
||||
# build_readable_messages_with_id 返回一个元组 (formatted_string, message_id_list)
|
||||
formatted_string, _ = await build_readable_messages_with_id(all_messages)
|
||||
return formatted_string
|
||||
# 4. 按聊天分组并格式化
|
||||
messages_by_stream = {}
|
||||
for msg in all_messages:
|
||||
stream_id = msg.get("_stream_id")
|
||||
if stream_id not in messages_by_stream:
|
||||
messages_by_stream[stream_id] = []
|
||||
messages_by_stream[stream_id].append(msg)
|
||||
|
||||
cross_context_messages = []
|
||||
for stream_id, messages in messages_by_stream.items():
|
||||
if messages:
|
||||
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
|
||||
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
|
||||
if not cross_context_messages:
|
||||
return None
|
||||
|
||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
@@ -190,7 +190,7 @@ class QZoneService:
|
||||
获取互通组的聊天上下文。
|
||||
"""
|
||||
# 实际的逻辑已迁移到 cross_context_api
|
||||
return await cross_context_api.get_intercom_group_context_by_name("maizone_context_group")
|
||||
return await cross_context_api.get_intercom_group_context("maizone_context_group")
|
||||
|
||||
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
|
||||
"""处理对自己说说的评论并进行回复"""
|
||||
|
||||
Reference in New Issue
Block a user