refactor(context): 简化跨上下文功能,移除通用共享组模式
移除了基于白名单/黑名单的通用上下文共享组(ContextGroup)功能。此模式实现复杂且与S4U模式功能重叠,移除后可大幅简化配置项和内部逻辑。 主要变更: - 从配置中删除了 `ContextGroup` 模型和 `cross_context.groups` 列表。 - 删除了 `build_cross_context_normal` 和 `get_context_group` 函数。 - 保留并增强了S4U(Search for User)模式,为其增加了更详细的日志和健壮性检查。 - `get_intercom_group_context` 函数被调整为专门服务于 `maizone_context_group`。 BREAKING CHANGE: 移除了 `cross_context.groups` 配置项及相关的通用上下文共享组功能。请迁移至S4U模式以实现跨上下文需求。
This commit is contained in:
@@ -702,28 +702,6 @@ class WebSearchConfig(ValidatedConfigBase):
|
|||||||
search_strategy: Literal["fallback", "single", "parallel"] = Field(default="single", description="搜索策略")
|
search_strategy: Literal["fallback", "single", "parallel"] = Field(default="single", description="搜索策略")
|
||||||
|
|
||||||
|
|
||||||
class ContextGroup(ValidatedConfigBase):
|
|
||||||
"""
|
|
||||||
上下文共享组配置
|
|
||||||
|
|
||||||
定义了一个聊天上下文的共享范围和规则。
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str = Field(..., description="共享组的名称,用于唯一标识一个共享组")
|
|
||||||
mode: Literal["whitelist", "blacklist"] = Field(
|
|
||||||
default="whitelist",
|
|
||||||
description="共享模式。'whitelist'表示仅共享chat_ids中列出的聊天;'blacklist'表示共享除chat_ids中列出的所有聊天。",
|
|
||||||
)
|
|
||||||
default_limit: int = Field(
|
|
||||||
default=5,
|
|
||||||
description="在'blacklist'模式下,对于未明确指定数量的聊天,默认获取的消息条数。",
|
|
||||||
)
|
|
||||||
chat_ids: list[list[str]] = Field(
|
|
||||||
...,
|
|
||||||
description='定义组内成员的列表。格式为 [["type", "id", "limit"(可选)]]。type为"group"或"private",id为群号或用户ID,limit为可选的消息条数。',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MaizoneContextGroup(ValidatedConfigBase):
|
class MaizoneContextGroup(ValidatedConfigBase):
|
||||||
"""QQ空间专用互通组配置"""
|
"""QQ空间专用互通组配置"""
|
||||||
|
|
||||||
@@ -739,8 +717,6 @@ class CrossContextConfig(ValidatedConfigBase):
|
|||||||
|
|
||||||
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
||||||
|
|
||||||
# --- Normal模式: 共享组配置 ---
|
|
||||||
groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
|
||||||
# --- S4U模式: 用户中心上下文检索 ---
|
# --- S4U模式: 用户中心上下文检索 ---
|
||||||
s4u_mode: Literal["whitelist", "blacklist"] = Field(
|
s4u_mode: Literal["whitelist", "blacklist"] = Field(
|
||||||
default="whitelist",
|
default="whitelist",
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import get_user_messages_from_streams
|
from src.common.message_repository import get_user_messages_from_streams
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.config.official_configs import ContextGroup
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
@@ -21,111 +20,6 @@ if TYPE_CHECKING:
|
|||||||
logger = get_logger("cross_context_api")
|
logger = get_logger("cross_context_api")
|
||||||
|
|
||||||
|
|
||||||
async def get_context_group(chat_id: str) -> ContextGroup | None:
|
|
||||||
"""
|
|
||||||
获取当前聊天所在的共享组
|
|
||||||
"""
|
|
||||||
current_stream = await get_chat_manager().get_stream(chat_id)
|
|
||||||
if not current_stream:
|
|
||||||
return None
|
|
||||||
|
|
||||||
is_group = current_stream.group_info is not None
|
|
||||||
if not is_group and not current_stream.user_info:
|
|
||||||
return None
|
|
||||||
if is_group:
|
|
||||||
assert current_stream.group_info is not None
|
|
||||||
current_chat_raw_id = current_stream.group_info.group_id
|
|
||||||
elif current_stream.user_info:
|
|
||||||
current_chat_raw_id = current_stream.user_info.user_id
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
current_type = "group" if is_group else "private"
|
|
||||||
|
|
||||||
for group in global_config.cross_context.groups:
|
|
||||||
for chat_info in group.chat_ids:
|
|
||||||
if len(chat_info) >= 2:
|
|
||||||
chat_type, chat_raw_id = chat_info[0], chat_info[1]
|
|
||||||
if chat_type == current_type and str(chat_raw_id) == str(current_chat_raw_id):
|
|
||||||
# 排除maizone专用组
|
|
||||||
if group.name == "maizone_context_group":
|
|
||||||
continue
|
|
||||||
return group
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def build_cross_context_normal(chat_stream: "ChatStream", context_group: ContextGroup) -> str:
|
|
||||||
"""
|
|
||||||
构建跨群聊/私聊上下文 (Normal模式)。
|
|
||||||
|
|
||||||
根据共享组的配置(白名单或黑名单模式),获取相关聊天的近期消息,并格式化为字符串。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 当前的聊天流对象。
|
|
||||||
context_group: 当前聊天所在的上下文共享组配置。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
一个包含格式化后的跨上下文消息的字符串,如果无消息则为空字符串。
|
|
||||||
"""
|
|
||||||
cross_context_messages = []
|
|
||||||
chat_manager = get_chat_manager()
|
|
||||||
|
|
||||||
chat_infos_to_fetch = []
|
|
||||||
if context_group.mode == "blacklist":
|
|
||||||
# 黑名单模式:获取所有聊天,并排除在 chat_ids 中定义过的聊天
|
|
||||||
blacklisted_ids = {tuple(info[:2]) for info in context_group.chat_ids}
|
|
||||||
for stream_id, stream in chat_manager.streams.items():
|
|
||||||
is_group = stream.group_info is not None
|
|
||||||
chat_type = "group" if is_group else "private"
|
|
||||||
|
|
||||||
# 安全地获取 raw_id
|
|
||||||
if is_group and stream.group_info:
|
|
||||||
raw_id = stream.group_info.group_id
|
|
||||||
elif not is_group and stream.user_info:
|
|
||||||
raw_id = stream.user_info.user_id
|
|
||||||
else:
|
|
||||||
continue # 如果缺少关键信息则跳过
|
|
||||||
|
|
||||||
# 如果当前聊天不在黑名单中,则添加到待获取列表
|
|
||||||
if (chat_type, str(raw_id)) not in blacklisted_ids:
|
|
||||||
chat_infos_to_fetch.append([chat_type, str(raw_id), str(context_group.default_limit)])
|
|
||||||
else:
|
|
||||||
# 白名单模式:直接使用配置中定义的 chat_ids
|
|
||||||
chat_infos_to_fetch = context_group.chat_ids
|
|
||||||
|
|
||||||
# 遍历待获取列表,抓取并格式化消息
|
|
||||||
for chat_info in chat_infos_to_fetch:
|
|
||||||
chat_type, chat_raw_id, limit_str = (
|
|
||||||
chat_info[0],
|
|
||||||
chat_info[1],
|
|
||||||
chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit),
|
|
||||||
)
|
|
||||||
limit = int(limit_str)
|
|
||||||
is_group = chat_type == "group"
|
|
||||||
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
|
||||||
if not stream_id or stream_id == chat_stream.stream_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=stream_id,
|
|
||||||
timestamp=time.time(),
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
if messages:
|
|
||||||
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
|
|
||||||
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
|
||||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not cross_context_messages:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
async def build_cross_context_s4u(
|
async def build_cross_context_s4u(
|
||||||
chat_stream: "ChatStream",
|
chat_stream: "ChatStream",
|
||||||
target_user_info: dict[str, Any] | None,
|
target_user_info: dict[str, Any] | None,
|
||||||
@@ -134,32 +28,55 @@ async def build_cross_context_s4u(
|
|||||||
构建跨群聊/私聊上下文 (S4U模式)。
|
构建跨群聊/私聊上下文 (S4U模式)。
|
||||||
优先展示目标用户的私聊记录(双向),其次按时间顺序展示其他群聊记录。
|
优先展示目标用户的私聊记录(双向),其次按时间顺序展示其他群聊记录。
|
||||||
"""
|
"""
|
||||||
|
# 记录S4U上下文构建开始
|
||||||
logger.debug("[S4U] Starting S4U context build.")
|
logger.debug("[S4U] Starting S4U context build.")
|
||||||
|
|
||||||
|
# 检查全局配置是否存在且包含必要部分
|
||||||
|
if not global_config or not global_config.cross_context or not global_config.bot:
|
||||||
|
logger.error("全局配置尚未初始化或缺少关键配置,无法构建S4U上下文。")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 获取跨上下文配置
|
||||||
cross_context_config = global_config.cross_context
|
cross_context_config = global_config.cross_context
|
||||||
|
|
||||||
|
# 检查目标用户信息和用户ID是否存在
|
||||||
if not target_user_info or not (user_id := target_user_info.get("user_id")):
|
if not target_user_info or not (user_id := target_user_info.get("user_id")):
|
||||||
logger.warning(f"[S4U] Failed: target_user_info ({target_user_info}) or user_id is missing.")
|
logger.warning(f"[S4U] Failed: target_user_info ({target_user_info}) or user_id is missing.")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# 记录目标用户ID
|
||||||
logger.debug(f"[S4U] Target user ID: {user_id}")
|
logger.debug(f"[S4U] Target user ID: {user_id}")
|
||||||
|
|
||||||
|
# 获取聊天管理器实例
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
private_context_block = ""
|
private_context_block = ""
|
||||||
group_context_blocks = []
|
group_context_blocks = []
|
||||||
|
|
||||||
# --- 1. 优先处理私聊上下文 ---
|
# --- 1. 优先处理私聊上下文 ---
|
||||||
|
# 获取与目标用户的私聊流ID
|
||||||
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
|
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
|
||||||
|
|
||||||
|
# 如果存在私聊流且不是当前聊天流
|
||||||
if private_stream_id and private_stream_id != chat_stream.stream_id:
|
if private_stream_id and private_stream_id != chat_stream.stream_id:
|
||||||
logger.debug(f"[S4U] Found private chat with target user: {private_stream_id}")
|
logger.debug(f"[S4U] Found private chat with target user: {private_stream_id}")
|
||||||
try:
|
try:
|
||||||
|
# 定义需要获取消息的用户ID列表(目标用户和机器人自己)
|
||||||
user_ids_to_fetch = [str(user_id), str(global_config.bot.qq_account)]
|
user_ids_to_fetch = [str(user_id), str(global_config.bot.qq_account)]
|
||||||
|
|
||||||
|
# 从指定私聊流中获取双方的消息
|
||||||
messages_by_stream = await get_user_messages_from_streams(
|
messages_by_stream = await get_user_messages_from_streams(
|
||||||
user_ids=user_ids_to_fetch,
|
user_ids=user_ids_to_fetch,
|
||||||
stream_ids=[private_stream_id],
|
stream_ids=[private_stream_id],
|
||||||
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 3天
|
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 最近3天的消息
|
||||||
limit_per_stream=cross_context_config.s4u_limit,
|
limit_per_stream=cross_context_config.s4u_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 如果获取到了私聊消息
|
||||||
if private_messages := messages_by_stream.get(private_stream_id):
|
if private_messages := messages_by_stream.get(private_stream_id):
|
||||||
chat_name = await chat_manager.get_stream_name(private_stream_id) or "私聊"
|
chat_name = await chat_manager.get_stream_name(private_stream_id) or "私聊"
|
||||||
title = f'[以下是您与"{chat_name}"的近期私聊记录]\n'
|
title = f'[以下是您与"{chat_name}"的近期私聊记录]\n'
|
||||||
|
|
||||||
|
# 格式化消息为可读字符串
|
||||||
formatted, _ = await build_readable_messages_with_id(private_messages, timestamp_mode="relative")
|
formatted, _ = await build_readable_messages_with_id(private_messages, timestamp_mode="relative")
|
||||||
private_context_block = f"{title}{formatted}"
|
private_context_block = f"{title}{formatted}"
|
||||||
logger.debug(f"[S4U] Generated private context block of length {len(private_context_block)}.")
|
logger.debug(f"[S4U] Generated private context block of length {len(private_context_block)}.")
|
||||||
@@ -168,18 +85,23 @@ async def build_cross_context_s4u(
|
|||||||
|
|
||||||
# --- 2. 处理其他群聊上下文 ---
|
# --- 2. 处理其他群聊上下文 ---
|
||||||
streams_to_scan = []
|
streams_to_scan = []
|
||||||
# 根据全局S4U配置确定要扫描的聊天范围
|
|
||||||
|
# 根据S4U配置模式(白名单/黑名单)确定要扫描的聊天范围
|
||||||
if cross_context_config.s4u_mode == "whitelist":
|
if cross_context_config.s4u_mode == "whitelist":
|
||||||
|
# 白名单模式:只扫描在白名单中的聊天
|
||||||
for chat_str in cross_context_config.s4u_whitelist_chats:
|
for chat_str in cross_context_config.s4u_whitelist_chats:
|
||||||
try:
|
try:
|
||||||
platform, chat_type, chat_raw_id = chat_str.split(":")
|
platform, chat_type, chat_raw_id = chat_str.split(":")
|
||||||
is_group = chat_type == "group"
|
is_group = chat_type == "group"
|
||||||
stream_id = chat_manager.get_stream_id(platform, chat_raw_id, is_group=is_group)
|
stream_id = chat_manager.get_stream_id(platform, chat_raw_id, is_group=is_group)
|
||||||
|
|
||||||
|
# 排除当前聊和私聊
|
||||||
if stream_id and stream_id != chat_stream.stream_id and stream_id != private_stream_id:
|
if stream_id and stream_id != chat_stream.stream_id and stream_id != private_stream_id:
|
||||||
streams_to_scan.append(stream_id)
|
streams_to_scan.append(stream_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f"无效的S4U白名单格式: {chat_str}")
|
logger.warning(f"无效的S4U白名单格式: {chat_str}")
|
||||||
else: # blacklist mode
|
else: # 黑名单模式
|
||||||
|
# 黑名单模式:扫描所有聊天,除了黑名单中的和私聊
|
||||||
blacklisted_streams = {private_stream_id}
|
blacklisted_streams = {private_stream_id}
|
||||||
for chat_str in cross_context_config.s4u_blacklist_chats:
|
for chat_str in cross_context_config.s4u_blacklist_chats:
|
||||||
try:
|
try:
|
||||||
@@ -190,6 +112,8 @@ async def build_cross_context_s4u(
|
|||||||
blacklisted_streams.add(stream_id)
|
blacklisted_streams.add(stream_id)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
|
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
|
||||||
|
|
||||||
|
# 将不在黑名单中的流添加到扫描列表
|
||||||
streams_to_scan.extend(
|
streams_to_scan.extend(
|
||||||
stream_id for stream_id in chat_manager.streams
|
stream_id for stream_id in chat_manager.streams
|
||||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams
|
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams
|
||||||
@@ -198,14 +122,16 @@ async def build_cross_context_s4u(
|
|||||||
logger.debug(f"[S4U] Found {len(streams_to_scan)} group streams to scan.")
|
logger.debug(f"[S4U] Found {len(streams_to_scan)} group streams to scan.")
|
||||||
|
|
||||||
if streams_to_scan:
|
if streams_to_scan:
|
||||||
|
# 获取目标用户在这些群聊中的消息
|
||||||
messages_by_stream = await get_user_messages_from_streams(
|
messages_by_stream = await get_user_messages_from_streams(
|
||||||
user_ids=[str(user_id)],
|
user_ids=[str(user_id)],
|
||||||
stream_ids=streams_to_scan,
|
stream_ids=streams_to_scan,
|
||||||
timestamp_after=time.time() - (3 * 24 * 60 * 60),
|
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 最近3天
|
||||||
limit_per_stream=cross_context_config.s4u_limit,
|
limit_per_stream=cross_context_config.s4u_limit,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_group_messages = []
|
all_group_messages = []
|
||||||
|
# 将所有群聊消息聚合,并附带最新时间戳
|
||||||
for stream_id, user_messages in messages_by_stream.items():
|
for stream_id, user_messages in messages_by_stream.items():
|
||||||
if user_messages:
|
if user_messages:
|
||||||
latest_timestamp = max(msg.get("time", 0) for msg in user_messages)
|
latest_timestamp = max(msg.get("time", 0) for msg in user_messages)
|
||||||
@@ -213,12 +139,14 @@ async def build_cross_context_s4u(
|
|||||||
{"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp}
|
{"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 按最新消息时间倒序排序
|
||||||
all_group_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
|
all_group_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
|
||||||
|
|
||||||
# 计算群聊的额度
|
# 计算群聊上下文的额度
|
||||||
remaining_limit = cross_context_config.s4u_stream_limit - (1 if private_context_block else 0)
|
remaining_limit = cross_context_config.s4u_stream_limit - (1 if private_context_block else 0)
|
||||||
limited_group_messages = all_group_messages[:remaining_limit]
|
limited_group_messages = all_group_messages[:remaining_limit]
|
||||||
|
|
||||||
|
# 格式化每个群聊的消息
|
||||||
for item in limited_group_messages:
|
for item in limited_group_messages:
|
||||||
try:
|
try:
|
||||||
chat_name = await chat_manager.get_stream_name(item["stream_id"]) or "未知群聊"
|
chat_name = await chat_manager.get_stream_name(item["stream_id"]) or "未知群聊"
|
||||||
@@ -230,128 +158,24 @@ async def build_cross_context_s4u(
|
|||||||
logger.error(f"S4U模式下格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
|
logger.error(f"S4U模式下格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
|
||||||
|
|
||||||
# --- 3. 组合最终上下文 ---
|
# --- 3. 组合最终上下文 ---
|
||||||
|
# 如果没有任何上下文内容,则返回空
|
||||||
if not private_context_block and not group_context_blocks:
|
if not private_context_block and not group_context_blocks:
|
||||||
logger.debug("[S4U] No context blocks were generated. Returning empty string.")
|
logger.debug("[S4U] No context blocks were generated. Returning empty string.")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
final_context_parts = []
|
final_context_parts = []
|
||||||
|
# 添加私聊部分
|
||||||
if private_context_block:
|
if private_context_block:
|
||||||
final_context_parts.append(private_context_block)
|
final_context_parts.append(private_context_block)
|
||||||
|
# 添加群聊部分
|
||||||
if group_context_blocks:
|
if group_context_blocks:
|
||||||
group_context_str = "\n\n".join(group_context_blocks)
|
group_context_str = "\n\n".join(group_context_blocks)
|
||||||
final_context_parts.append(f"### 其他群聊中的聊天记录\n{group_context_str}")
|
final_context_parts.append(f"### 其他群聊中的聊天记录\n{group_context_str}")
|
||||||
|
|
||||||
|
# 组合最终的上下文字符串
|
||||||
final_context = "\n\n".join(final_context_parts) + "\n"
|
final_context = "\n\n".join(final_context_parts) + "\n"
|
||||||
logger.debug(f"[S4U] Successfully generated S4U context. Total length: {len(final_context)}.")
|
logger.debug(f"[S4U] Successfully generated S4U context. Total length: {len(final_context)}.")
|
||||||
return final_context
|
return final_context
|
||||||
|
|
||||||
|
|
||||||
async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
|
|
||||||
"""
|
|
||||||
根据互通组的名称,构建该组的聊天上下文。
|
|
||||||
支持黑白名单模式,并以分块形式返回每个聊天的消息。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
group_name: 互通组的名称。
|
|
||||||
limit_per_chat: 每个聊天最多获取的消息条数。
|
|
||||||
total_limit: 返回的总消息条数上限。
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
如果找到匹配的组并获取到消息,则返回一个包含聊天记录的字符串;否则返回 None。
|
|
||||||
"""
|
|
||||||
cross_context_config = global_config.cross_context
|
|
||||||
if not (cross_context_config and cross_context_config.enable):
|
|
||||||
return None
|
|
||||||
|
|
||||||
target_group = next((g for g in cross_context_config.groups if g.name == group_name), None)
|
|
||||||
|
|
||||||
if not target_group:
|
|
||||||
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
|
|
||||||
return None
|
|
||||||
|
|
||||||
chat_manager = get_chat_manager()
|
|
||||||
|
|
||||||
# 1. 根据黑白名单模式确定要处理的聊天列表
|
|
||||||
chat_infos_to_fetch = []
|
|
||||||
if target_group.mode == "blacklist":
|
|
||||||
blacklisted_ids = {tuple(info[:2]) for info in target_group.chat_ids}
|
|
||||||
for stream in chat_manager.streams.values():
|
|
||||||
is_group = stream.group_info is not None
|
|
||||||
chat_type = "group" if is_group else "private"
|
|
||||||
|
|
||||||
if is_group and stream.group_info:
|
|
||||||
raw_id = stream.group_info.group_id
|
|
||||||
elif not is_group and stream.user_info:
|
|
||||||
raw_id = stream.user_info.user_id
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if (chat_type, str(raw_id)) not in blacklisted_ids:
|
|
||||||
chat_infos_to_fetch.append([chat_type, str(raw_id)])
|
|
||||||
else: # whitelist mode
|
|
||||||
chat_infos_to_fetch = target_group.chat_ids
|
|
||||||
|
|
||||||
# 2. 获取所有相关消息
|
|
||||||
all_messages = []
|
|
||||||
for chat_info in chat_infos_to_fetch:
|
|
||||||
chat_type, chat_raw_id = chat_info[0], chat_info[1]
|
|
||||||
is_group = chat_type == "group"
|
|
||||||
|
|
||||||
# 查找 stream
|
|
||||||
found_stream = None
|
|
||||||
for stream in chat_manager.streams.values():
|
|
||||||
if is_group:
|
|
||||||
if stream.group_info and stream.group_info.group_id == chat_raw_id:
|
|
||||||
found_stream = stream
|
|
||||||
break
|
|
||||||
else: # private
|
|
||||||
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
|
||||||
found_stream = stream
|
|
||||||
break
|
|
||||||
if not found_stream:
|
|
||||||
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
|
||||||
continue
|
|
||||||
stream_id = found_stream.stream_id
|
|
||||||
|
|
||||||
try:
|
|
||||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
|
||||||
chat_id=stream_id,
|
|
||||||
timestamp=time.time(),
|
|
||||||
limit=limit_per_chat,
|
|
||||||
)
|
|
||||||
if messages:
|
|
||||||
# 为每条消息附加 stream_id 以便后续分组
|
|
||||||
for msg in messages:
|
|
||||||
msg["_stream_id"] = stream_id
|
|
||||||
all_messages.extend(messages)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
|
||||||
|
|
||||||
if not all_messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 3. 应用总数限制
|
|
||||||
all_messages.sort(key=lambda x: x.get("time", 0))
|
|
||||||
if len(all_messages) > total_limit:
|
|
||||||
all_messages = all_messages[-total_limit:]
|
|
||||||
|
|
||||||
# 4. 按聊天分组并格式化
|
|
||||||
messages_by_stream = {}
|
|
||||||
for msg in all_messages:
|
|
||||||
stream_id = msg.get("_stream_id")
|
|
||||||
if stream_id not in messages_by_stream:
|
|
||||||
messages_by_stream[stream_id] = []
|
|
||||||
messages_by_stream[stream_id].append(msg)
|
|
||||||
|
|
||||||
cross_context_messages = []
|
|
||||||
for stream_id, messages in messages_by_stream.items():
|
|
||||||
if messages:
|
|
||||||
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
|
|
||||||
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
|
||||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
|
||||||
|
|
||||||
if not cross_context_messages:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ import json5
|
|||||||
import orjson
|
import orjson
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.plugin_system.apis import config_api, cross_context_api, person_api
|
from src.plugin_system.apis import config_api, person_api
|
||||||
|
|
||||||
from .content_service import ContentService
|
from .content_service import ContentService
|
||||||
from .cookie_service import CookieService
|
from .cookie_service import CookieService
|
||||||
@@ -63,10 +63,7 @@ class QZoneService:
|
|||||||
|
|
||||||
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
|
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
|
||||||
"""发送一条说说"""
|
"""发送一条说说"""
|
||||||
# --- 获取互通组上下文 ---
|
story = await self.content_service.generate_story(topic, context=None)
|
||||||
context = await self._get_intercom_context(stream_id) if stream_id else None
|
|
||||||
|
|
||||||
story = await self.content_service.generate_story(topic, context=context)
|
|
||||||
if not story:
|
if not story:
|
||||||
return {"success": False, "message": "生成说说内容失败"}
|
return {"success": False, "message": "生成说说内容失败"}
|
||||||
|
|
||||||
@@ -302,12 +299,6 @@ class QZoneService:
|
|||||||
|
|
||||||
# --- Internal Helper Methods ---
|
# --- Internal Helper Methods ---
|
||||||
|
|
||||||
async def _get_intercom_context(self, stream_id: str) -> str | None:
|
|
||||||
"""
|
|
||||||
获取互通组的聊天上下文。
|
|
||||||
"""
|
|
||||||
# 实际的逻辑已迁移到 cross_context_api
|
|
||||||
return await cross_context_api.get_intercom_group_context("maizone_context_group")
|
|
||||||
|
|
||||||
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
|
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
|
||||||
"""处理对自己说说的评论并进行回复"""
|
"""处理对自己说说的评论并进行回复"""
|
||||||
|
|||||||
Reference in New Issue
Block a user