Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev
This commit is contained in:
@@ -702,28 +702,6 @@ class WebSearchConfig(ValidatedConfigBase):
|
||||
search_strategy: Literal["fallback", "single", "parallel"] = Field(default="single", description="搜索策略")
|
||||
|
||||
|
||||
class ContextGroup(ValidatedConfigBase):
|
||||
"""
|
||||
上下文共享组配置
|
||||
|
||||
定义了一个聊天上下文的共享范围和规则。
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="共享组的名称,用于唯一标识一个共享组")
|
||||
mode: Literal["whitelist", "blacklist"] = Field(
|
||||
default="whitelist",
|
||||
description="共享模式。'whitelist'表示仅共享chat_ids中列出的聊天;'blacklist'表示共享除chat_ids中列出的所有聊天。",
|
||||
)
|
||||
default_limit: int = Field(
|
||||
default=5,
|
||||
description="在'blacklist'模式下,对于未明确指定数量的聊天,默认获取的消息条数。",
|
||||
)
|
||||
chat_ids: list[list[str]] = Field(
|
||||
...,
|
||||
description='定义组内成员的列表。格式为 [["type", "id", "limit"(可选)]]。type为"group"或"private",id为群号或用户ID,limit为可选的消息条数。',
|
||||
)
|
||||
|
||||
|
||||
class MaizoneContextGroup(ValidatedConfigBase):
|
||||
"""QQ空间专用互通组配置"""
|
||||
|
||||
@@ -739,8 +717,6 @@ class CrossContextConfig(ValidatedConfigBase):
|
||||
|
||||
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
||||
|
||||
# --- Normal模式: 共享组配置 ---
|
||||
groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
||||
# --- S4U模式: 用户中心上下文检索 ---
|
||||
s4u_mode: Literal["whitelist", "blacklist"] = Field(
|
||||
default="whitelist",
|
||||
|
||||
@@ -392,7 +392,7 @@ MoFox_Bot(第三方修改版)
|
||||
全部组件已成功启动!
|
||||
=========================================================
|
||||
🌐 项目地址: https://github.com/MoFox-Studio/MoFox-Core
|
||||
🏠 官方项目: https://github.com/MaiM-with-u/MaiBot
|
||||
🏠 官方项目: https://github.com/Mai-with-u/MaiBot
|
||||
=========================================================
|
||||
这是基于原版MMC的社区改版,包含增强功能和优化(同时也有更多的'特性')
|
||||
=========================================================
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import time
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from src.common.message_repository import find_messages
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
@@ -13,7 +14,6 @@ from src.chat.utils.chat_message_builder import (
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message_repository import get_user_messages_from_streams
|
||||
from src.config.config import global_config
|
||||
from src.config.official_configs import ContextGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -21,111 +21,6 @@ if TYPE_CHECKING:
|
||||
logger = get_logger("cross_context_api")
|
||||
|
||||
|
||||
async def get_context_group(chat_id: str) -> ContextGroup | None:
|
||||
"""
|
||||
获取当前聊天所在的共享组
|
||||
"""
|
||||
current_stream = await get_chat_manager().get_stream(chat_id)
|
||||
if not current_stream:
|
||||
return None
|
||||
|
||||
is_group = current_stream.group_info is not None
|
||||
if not is_group and not current_stream.user_info:
|
||||
return None
|
||||
if is_group:
|
||||
assert current_stream.group_info is not None
|
||||
current_chat_raw_id = current_stream.group_info.group_id
|
||||
elif current_stream.user_info:
|
||||
current_chat_raw_id = current_stream.user_info.user_id
|
||||
else:
|
||||
return None
|
||||
current_type = "group" if is_group else "private"
|
||||
|
||||
for group in global_config.cross_context.groups:
|
||||
for chat_info in group.chat_ids:
|
||||
if len(chat_info) >= 2:
|
||||
chat_type, chat_raw_id = chat_info[0], chat_info[1]
|
||||
if chat_type == current_type and str(chat_raw_id) == str(current_chat_raw_id):
|
||||
# 排除maizone专用组
|
||||
if group.name == "maizone_context_group":
|
||||
continue
|
||||
return group
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def build_cross_context_normal(chat_stream: "ChatStream", context_group: ContextGroup) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (Normal模式)。
|
||||
|
||||
根据共享组的配置(白名单或黑名单模式),获取相关聊天的近期消息,并格式化为字符串。
|
||||
|
||||
Args:
|
||||
chat_stream: 当前的聊天流对象。
|
||||
context_group: 当前聊天所在的上下文共享组配置。
|
||||
|
||||
Returns:
|
||||
一个包含格式化后的跨上下文消息的字符串,如果无消息则为空字符串。
|
||||
"""
|
||||
cross_context_messages = []
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
chat_infos_to_fetch = []
|
||||
if context_group.mode == "blacklist":
|
||||
# 黑名单模式:获取所有聊天,并排除在 chat_ids 中定义过的聊天
|
||||
blacklisted_ids = {tuple(info[:2]) for info in context_group.chat_ids}
|
||||
for stream_id, stream in chat_manager.streams.items():
|
||||
is_group = stream.group_info is not None
|
||||
chat_type = "group" if is_group else "private"
|
||||
|
||||
# 安全地获取 raw_id
|
||||
if is_group and stream.group_info:
|
||||
raw_id = stream.group_info.group_id
|
||||
elif not is_group and stream.user_info:
|
||||
raw_id = stream.user_info.user_id
|
||||
else:
|
||||
continue # 如果缺少关键信息则跳过
|
||||
|
||||
# 如果当前聊天不在黑名单中,则添加到待获取列表
|
||||
if (chat_type, str(raw_id)) not in blacklisted_ids:
|
||||
chat_infos_to_fetch.append([chat_type, str(raw_id), str(context_group.default_limit)])
|
||||
else:
|
||||
# 白名单模式:直接使用配置中定义的 chat_ids
|
||||
chat_infos_to_fetch = context_group.chat_ids
|
||||
|
||||
# 遍历待获取列表,抓取并格式化消息
|
||||
for chat_info in chat_infos_to_fetch:
|
||||
chat_type, chat_raw_id, limit_str = (
|
||||
chat_info[0],
|
||||
chat_info[1],
|
||||
chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit),
|
||||
)
|
||||
limit = int(limit_str)
|
||||
is_group = chat_type == "group"
|
||||
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
|
||||
if not stream_id or stream_id == chat_stream.stream_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=limit,
|
||||
)
|
||||
if messages:
|
||||
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
|
||||
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
continue
|
||||
|
||||
if not cross_context_messages:
|
||||
return ""
|
||||
|
||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
|
||||
async def build_cross_context_s4u(
|
||||
chat_stream: "ChatStream",
|
||||
target_user_info: dict[str, Any] | None,
|
||||
@@ -134,32 +29,55 @@ async def build_cross_context_s4u(
|
||||
构建跨群聊/私聊上下文 (S4U模式)。
|
||||
优先展示目标用户的私聊记录(双向),其次按时间顺序展示其他群聊记录。
|
||||
"""
|
||||
# 记录S4U上下文构建开始
|
||||
logger.debug("[S4U] Starting S4U context build.")
|
||||
|
||||
# 检查全局配置是否存在且包含必要部分
|
||||
if not global_config or not global_config.cross_context or not global_config.bot:
|
||||
logger.error("全局配置尚未初始化或缺少关键配置,无法构建S4U上下文。")
|
||||
return ""
|
||||
|
||||
# 获取跨上下文配置
|
||||
cross_context_config = global_config.cross_context
|
||||
|
||||
# 检查目标用户信息和用户ID是否存在
|
||||
if not target_user_info or not (user_id := target_user_info.get("user_id")):
|
||||
logger.warning(f"[S4U] Failed: target_user_info ({target_user_info}) or user_id is missing.")
|
||||
return ""
|
||||
|
||||
# 记录目标用户ID
|
||||
logger.debug(f"[S4U] Target user ID: {user_id}")
|
||||
|
||||
# 获取聊天管理器实例
|
||||
chat_manager = get_chat_manager()
|
||||
private_context_block = ""
|
||||
group_context_blocks = []
|
||||
|
||||
# --- 1. 优先处理私聊上下文 ---
|
||||
# 获取与目标用户的私聊流ID
|
||||
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
|
||||
|
||||
# 如果存在私聊流且不是当前聊天流
|
||||
if private_stream_id and private_stream_id != chat_stream.stream_id:
|
||||
logger.debug(f"[S4U] Found private chat with target user: {private_stream_id}")
|
||||
try:
|
||||
# 定义需要获取消息的用户ID列表(目标用户和机器人自己)
|
||||
user_ids_to_fetch = [str(user_id), str(global_config.bot.qq_account)]
|
||||
|
||||
# 从指定私聊流中获取双方的消息
|
||||
messages_by_stream = await get_user_messages_from_streams(
|
||||
user_ids=user_ids_to_fetch,
|
||||
stream_ids=[private_stream_id],
|
||||
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 3天
|
||||
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 最近3天的消息
|
||||
limit_per_stream=cross_context_config.s4u_limit,
|
||||
)
|
||||
|
||||
# 如果获取到了私聊消息
|
||||
if private_messages := messages_by_stream.get(private_stream_id):
|
||||
chat_name = await chat_manager.get_stream_name(private_stream_id) or "私聊"
|
||||
title = f'[以下是您与"{chat_name}"的近期私聊记录]\n'
|
||||
|
||||
# 格式化消息为可读字符串
|
||||
formatted, _ = await build_readable_messages_with_id(private_messages, timestamp_mode="relative")
|
||||
private_context_block = f"{title}{formatted}"
|
||||
logger.debug(f"[S4U] Generated private context block of length {len(private_context_block)}.")
|
||||
@@ -168,18 +86,23 @@ async def build_cross_context_s4u(
|
||||
|
||||
# --- 2. 处理其他群聊上下文 ---
|
||||
streams_to_scan = []
|
||||
# 根据全局S4U配置确定要扫描的聊天范围
|
||||
|
||||
# 根据S4U配置模式(白名单/黑名单)确定要扫描的聊天范围
|
||||
if cross_context_config.s4u_mode == "whitelist":
|
||||
# 白名单模式:只扫描在白名单中的聊天
|
||||
for chat_str in cross_context_config.s4u_whitelist_chats:
|
||||
try:
|
||||
platform, chat_type, chat_raw_id = chat_str.split(":")
|
||||
is_group = chat_type == "group"
|
||||
stream_id = chat_manager.get_stream_id(platform, chat_raw_id, is_group=is_group)
|
||||
|
||||
# 排除当前聊和私聊
|
||||
if stream_id and stream_id != chat_stream.stream_id and stream_id != private_stream_id:
|
||||
streams_to_scan.append(stream_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的S4U白名单格式: {chat_str}")
|
||||
else: # blacklist mode
|
||||
else: # 黑名单模式
|
||||
# 黑名单模式:扫描所有聊天,除了黑名单中的和私聊
|
||||
blacklisted_streams = {private_stream_id}
|
||||
for chat_str in cross_context_config.s4u_blacklist_chats:
|
||||
try:
|
||||
@@ -190,6 +113,8 @@ async def build_cross_context_s4u(
|
||||
blacklisted_streams.add(stream_id)
|
||||
except ValueError:
|
||||
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
|
||||
|
||||
# 将不在黑名单中的流添加到扫描列表
|
||||
streams_to_scan.extend(
|
||||
stream_id for stream_id in chat_manager.streams
|
||||
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams
|
||||
@@ -197,12 +122,113 @@ async def build_cross_context_s4u(
|
||||
|
||||
logger.debug(f"[S4U] Found {len(streams_to_scan)} group streams to scan.")
|
||||
|
||||
if streams_to_scan:
|
||||
# 获取目标用户在这些群聊中的消息
|
||||
messages_by_stream = await get_user_messages_from_streams(
|
||||
user_ids=[str(user_id)],
|
||||
stream_ids=streams_to_scan,
|
||||
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 最近3天
|
||||
limit_per_stream=cross_context_config.s4u_limit,
|
||||
)
|
||||
|
||||
all_group_messages = []
|
||||
# 将所有群聊消息聚合,并附带最新时间戳
|
||||
for stream_id, user_messages in messages_by_stream.items():
|
||||
if user_messages:
|
||||
latest_timestamp = max(msg.get("time", 0) for msg in user_messages)
|
||||
all_group_messages.append(
|
||||
{"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp}
|
||||
)
|
||||
|
||||
# 按最新消息时间倒序排序
|
||||
all_group_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
|
||||
|
||||
# 计算群聊上下文的额度
|
||||
remaining_limit = cross_context_config.s4u_stream_limit - (1 if private_context_block else 0)
|
||||
limited_group_messages = all_group_messages[:remaining_limit]
|
||||
|
||||
# 格式化每个群聊的消息
|
||||
for item in limited_group_messages:
|
||||
try:
|
||||
chat_name = await chat_manager.get_stream_name(item["stream_id"]) or "未知群聊"
|
||||
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
title = f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n'
|
||||
formatted, _ = await build_readable_messages_with_id(item["messages"], timestamp_mode="relative")
|
||||
group_context_blocks.append(f"{title}{formatted}")
|
||||
except Exception as e:
|
||||
logger.error(f"S4U模式下格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
|
||||
|
||||
# --- 3. 组合最终上下文 ---
|
||||
# 如果没有任何上下文内容,则返回空
|
||||
if not private_context_block and not group_context_blocks:
|
||||
logger.debug("[S4U] No context blocks were generated. Returning empty string.")
|
||||
return ""
|
||||
|
||||
final_context_parts = []
|
||||
# 添加私聊部分
|
||||
if private_context_block:
|
||||
final_context_parts.append(private_context_block)
|
||||
# 添加群聊部分
|
||||
if group_context_blocks:
|
||||
group_context_str = "\n\n".join(group_context_blocks)
|
||||
final_context_parts.append(f"### 其他群聊中的聊天记录\n{group_context_str}")
|
||||
|
||||
# 组合最终的上下文字符串
|
||||
final_context = "\n\n".join(final_context_parts) + "\n"
|
||||
logger.debug(f"[S4U] Successfully generated S4U context. Total length: {len(final_context)}.")
|
||||
return final_context
|
||||
|
||||
async def build_cross_context_for_user(
|
||||
user_id: str,
|
||||
platform: str,
|
||||
limit_per_stream: int,
|
||||
stream_limit: int,
|
||||
) -> str:
|
||||
"""
|
||||
构建指定用户的跨群聊/私聊上下文(简化版API)。
|
||||
"""
|
||||
logger.debug(f"[S4U_SIMPLE] Starting simplified S4U context build for user {user_id} on {platform}.")
|
||||
|
||||
if not global_config or not global_config.cross_context or not global_config.bot:
|
||||
logger.error("全局配置尚未初始化或缺少关键配置,无法构建S4U上下文。")
|
||||
return ""
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
private_context_block = ""
|
||||
group_context_blocks = []
|
||||
|
||||
# --- 1. 优先处理私聊上下文 ---
|
||||
private_stream_id = chat_manager.get_stream_id(platform, user_id, is_group=False)
|
||||
if private_stream_id:
|
||||
try:
|
||||
user_ids_to_fetch = [str(user_id), str(global_config.bot.qq_account)]
|
||||
messages_by_stream = await get_user_messages_from_streams(
|
||||
user_ids=user_ids_to_fetch,
|
||||
stream_ids=[private_stream_id],
|
||||
timestamp_after=time.time() - (3 * 24 * 60 * 60),
|
||||
limit_per_stream=limit_per_stream,
|
||||
)
|
||||
if private_messages := messages_by_stream.get(private_stream_id):
|
||||
chat_name = await chat_manager.get_stream_name(private_stream_id) or "私聊"
|
||||
title = f'[以下是您与"{chat_name}"的近期私聊记录]\n'
|
||||
formatted, _ = await build_readable_messages_with_id(private_messages, timestamp_mode="relative")
|
||||
private_context_block = f"{title}{formatted}"
|
||||
except Exception as e:
|
||||
logger.error(f"[S4U_SIMPLE] 处理私聊记录失败: {e}")
|
||||
|
||||
# --- 2. 处理其他群聊上下文 ---
|
||||
streams_to_scan = [
|
||||
stream_id for stream_id in chat_manager.streams
|
||||
if stream_id != private_stream_id
|
||||
]
|
||||
|
||||
if streams_to_scan:
|
||||
messages_by_stream = await get_user_messages_from_streams(
|
||||
user_ids=[str(user_id)],
|
||||
stream_ids=streams_to_scan,
|
||||
timestamp_after=time.time() - (3 * 24 * 60 * 60),
|
||||
limit_per_stream=cross_context_config.s4u_limit,
|
||||
limit_per_stream=limit_per_stream,
|
||||
)
|
||||
|
||||
all_group_messages = []
|
||||
@@ -215,23 +241,21 @@ async def build_cross_context_s4u(
|
||||
|
||||
all_group_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
|
||||
|
||||
# 计算群聊的额度
|
||||
remaining_limit = cross_context_config.s4u_stream_limit - (1 if private_context_block else 0)
|
||||
remaining_limit = stream_limit - (1 if private_context_block else 0)
|
||||
limited_group_messages = all_group_messages[:remaining_limit]
|
||||
|
||||
for item in limited_group_messages:
|
||||
try:
|
||||
chat_name = await chat_manager.get_stream_name(item["stream_id"]) or "未知群聊"
|
||||
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
|
||||
user_name = user_id # 简化处理
|
||||
title = f'[以下是"{user_name}"在"{chat_name}"的近期发言]\n'
|
||||
formatted, _ = await build_readable_messages_with_id(item["messages"], timestamp_mode="relative")
|
||||
group_context_blocks.append(f"{title}{formatted}")
|
||||
except Exception as e:
|
||||
logger.error(f"S4U模式下格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
|
||||
logger.error(f"[S4U_SIMPLE] 格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
|
||||
|
||||
# --- 3. 组合最终上下文 ---
|
||||
if not private_context_block and not group_context_blocks:
|
||||
logger.debug("[S4U] No context blocks were generated. Returning empty string.")
|
||||
return ""
|
||||
|
||||
final_context_parts = []
|
||||
@@ -242,116 +266,5 @@ async def build_cross_context_s4u(
|
||||
final_context_parts.append(f"### 其他群聊中的聊天记录\n{group_context_str}")
|
||||
|
||||
final_context = "\n\n".join(final_context_parts) + "\n"
|
||||
logger.debug(f"[S4U] Successfully generated S4U context. Total length: {len(final_context)}.")
|
||||
logger.debug(f"[S4U_SIMPLE] Successfully generated context for user {user_id}. Total length: {len(final_context)}.")
|
||||
return final_context
|
||||
|
||||
|
||||
async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
|
||||
"""
|
||||
根据互通组的名称,构建该组的聊天上下文。
|
||||
支持黑白名单模式,并以分块形式返回每个聊天的消息。
|
||||
|
||||
Args:
|
||||
group_name: 互通组的名称。
|
||||
limit_per_chat: 每个聊天最多获取的消息条数。
|
||||
total_limit: 返回的总消息条数上限。
|
||||
|
||||
Returns:
|
||||
如果找到匹配的组并获取到消息,则返回一个包含聊天记录的字符串;否则返回 None。
|
||||
"""
|
||||
cross_context_config = global_config.cross_context
|
||||
if not (cross_context_config and cross_context_config.enable):
|
||||
return None
|
||||
|
||||
target_group = next((g for g in cross_context_config.groups if g.name == group_name), None)
|
||||
|
||||
if not target_group:
|
||||
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
|
||||
return None
|
||||
|
||||
chat_manager = get_chat_manager()
|
||||
|
||||
# 1. 根据黑白名单模式确定要处理的聊天列表
|
||||
chat_infos_to_fetch = []
|
||||
if target_group.mode == "blacklist":
|
||||
blacklisted_ids = {tuple(info[:2]) for info in target_group.chat_ids}
|
||||
for stream in chat_manager.streams.values():
|
||||
is_group = stream.group_info is not None
|
||||
chat_type = "group" if is_group else "private"
|
||||
|
||||
if is_group and stream.group_info:
|
||||
raw_id = stream.group_info.group_id
|
||||
elif not is_group and stream.user_info:
|
||||
raw_id = stream.user_info.user_id
|
||||
else:
|
||||
continue
|
||||
|
||||
if (chat_type, str(raw_id)) not in blacklisted_ids:
|
||||
chat_infos_to_fetch.append([chat_type, str(raw_id)])
|
||||
else: # whitelist mode
|
||||
chat_infos_to_fetch = target_group.chat_ids
|
||||
|
||||
# 2. 获取所有相关消息
|
||||
all_messages = []
|
||||
for chat_info in chat_infos_to_fetch:
|
||||
chat_type, chat_raw_id = chat_info[0], chat_info[1]
|
||||
is_group = chat_type == "group"
|
||||
|
||||
# 查找 stream
|
||||
found_stream = None
|
||||
for stream in chat_manager.streams.values():
|
||||
if is_group:
|
||||
if stream.group_info and stream.group_info.group_id == chat_raw_id:
|
||||
found_stream = stream
|
||||
break
|
||||
else: # private
|
||||
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
|
||||
found_stream = stream
|
||||
break
|
||||
if not found_stream:
|
||||
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
|
||||
continue
|
||||
stream_id = found_stream.stream_id
|
||||
|
||||
try:
|
||||
messages = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=limit_per_chat,
|
||||
)
|
||||
if messages:
|
||||
# 为每条消息附加 stream_id 以便后续分组
|
||||
for msg in messages:
|
||||
msg["_stream_id"] = stream_id
|
||||
all_messages.extend(messages)
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
|
||||
|
||||
if not all_messages:
|
||||
return None
|
||||
|
||||
# 3. 应用总数限制
|
||||
all_messages.sort(key=lambda x: x.get("time", 0))
|
||||
if len(all_messages) > total_limit:
|
||||
all_messages = all_messages[-total_limit:]
|
||||
|
||||
# 4. 按聊天分组并格式化
|
||||
messages_by_stream = {}
|
||||
for msg in all_messages:
|
||||
stream_id = msg.get("_stream_id")
|
||||
if stream_id not in messages_by_stream:
|
||||
messages_by_stream[stream_id] = []
|
||||
messages_by_stream[stream_id].append(msg)
|
||||
|
||||
cross_context_messages = []
|
||||
for stream_id, messages in messages_by_stream.items():
|
||||
if messages:
|
||||
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
|
||||
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
|
||||
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
|
||||
|
||||
if not cross_context_messages:
|
||||
return None
|
||||
|
||||
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
"enable_image": ConfigField(type=bool, default=False, description="是否启用说说配图"),
|
||||
"enable_ai_image": ConfigField(type=bool, default=False, description="是否启用AI生成配图"),
|
||||
"enable_reply": ConfigField(type=bool, default=True, description="完成后是否回复"),
|
||||
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量"),
|
||||
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量(1-4张)"),
|
||||
"image_number": ConfigField(type=int, default=1, description="本地配图数量(1-9张)"),
|
||||
"image_directory": ConfigField(
|
||||
type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录"
|
||||
@@ -83,6 +83,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
||||
"http_fallback_port": ConfigField(type=int, default=9999, description="备用Cookie获取服务的端口"),
|
||||
"napcat_token": ConfigField(type=str, default="", description="Napcat服务的认证Token(可选)"),
|
||||
},
|
||||
"cross_context": {
|
||||
"user_id": ConfigField(type=str, default="", description="用于获取互通上下文的目标用户QQ号"),
|
||||
},
|
||||
}
|
||||
|
||||
permission_nodes: list[PermissionNodeField] = [
|
||||
|
||||
@@ -274,7 +274,7 @@ class ContentService:
|
||||
await asyncio.sleep(2)
|
||||
return None
|
||||
|
||||
async def generate_story_from_activity(self, activity: str) -> str:
|
||||
async def generate_story_from_activity(self, activity: str, context: str | None = None) -> str:
|
||||
"""
|
||||
根据当前的日程活动生成一条QQ空间说说。
|
||||
|
||||
@@ -350,6 +350,9 @@ class ContentService:
|
||||
- 鼓励你多描述日常生活相关的生产活动和消遣,展现真实,而不是浮在空中。
|
||||
"""
|
||||
|
||||
# 如果有上下文,则加入到prompt中
|
||||
if context:
|
||||
prompt += f"\n作为参考,这里有一些最近的聊天记录:\n---\n{context}\n---"
|
||||
# 添加历史记录避免重复
|
||||
prompt += "\n\n---历史说说记录---\n"
|
||||
history_block = await get_send_history(qq_account)
|
||||
|
||||
@@ -4,13 +4,17 @@
|
||||
"""
|
||||
|
||||
import base64
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import llm_api, config_api
|
||||
|
||||
logger = get_logger("MaiZone.ImageService")
|
||||
|
||||
@@ -40,7 +44,14 @@ class ImageService:
|
||||
api_key = str(self.get_config("models.siliconflow_apikey", ""))
|
||||
image_dir = str(self.get_config("send.image_directory", "./data/plugins/maizone_refactored/images"))
|
||||
image_num_raw = self.get_config("send.ai_image_number", 1)
|
||||
image_num = int(image_num_raw if image_num_raw is not None else 1)
|
||||
|
||||
# 安全地处理图片数量配置,并限制在API允许的范围内
|
||||
try:
|
||||
image_num = int(image_num_raw) if image_num_raw not in [None, ""] else 1
|
||||
image_num = max(1, min(image_num, 4)) # SiliconFlow API限制:1 <= batch_size <= 4
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的图片数量配置: {image_num_raw},使用默认值1")
|
||||
image_num = 1
|
||||
|
||||
if not enable_ai_image:
|
||||
return True # 未启用AI配图,视为成功
|
||||
@@ -52,49 +63,191 @@ class ImageService:
|
||||
# 确保图片目录存在
|
||||
Path(image_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成图片提示词
|
||||
image_prompt = await self._generate_image_prompt(story)
|
||||
if not image_prompt:
|
||||
logger.error("生成图片提示词失败")
|
||||
return False
|
||||
|
||||
logger.info(f"正在为说说生成 {image_num} 张AI配图...")
|
||||
return await self._call_siliconflow_api(api_key, story, image_dir, image_num)
|
||||
return await self._call_siliconflow_api(api_key, image_prompt, image_dir, image_num)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理AI配图时发生异常: {e}")
|
||||
return False
|
||||
|
||||
async def _call_siliconflow_api(self, api_key: str, story: str, image_dir: str, batch_size: int) -> bool:
|
||||
async def _generate_image_prompt(self, story_content: str) -> str:
|
||||
"""
|
||||
使用LLM生成图片提示词,基于说说内容。
|
||||
|
||||
:param story_content: 说说内容
|
||||
:return: 生成的图片提示词,失败时返回空字符串
|
||||
"""
|
||||
try:
|
||||
# 获取配置
|
||||
identity = config_api.get_global_config("personality.identity", "年龄为19岁,是女孩子,身高为160cm,黑色短发")
|
||||
enable_ref = bool(self.get_config("models.image_ref", True))
|
||||
|
||||
# 构建提示词
|
||||
prompt = f"""
|
||||
请根据以下QQ空间说说内容配图,并构建生成配图的风格和prompt。
|
||||
说说主人信息:'{identity}'。
|
||||
说说内容:'{story_content}'。
|
||||
请注意:仅回复用于生成图片的prompt,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||
"""
|
||||
if enable_ref:
|
||||
prompt += "说说主人的人设参考图片将随同提示词一起发送给生图AI,可使用'in the style of'或'根据图中人物'等描述引导生成风格"
|
||||
|
||||
# 获取模型配置
|
||||
models = llm_api.get_available_models()
|
||||
prompt_model = self.get_config("models.text_model", "replyer")
|
||||
model_config = models.get(prompt_model)
|
||||
|
||||
if not model_config:
|
||||
logger.error(f"找不到模型配置: {prompt_model}")
|
||||
return ""
|
||||
|
||||
# 调用LLM生成提示词
|
||||
logger.info("正在生成图片提示词...")
|
||||
success, image_prompt, reasoning, model_name = await llm_api.generate_with_model(
|
||||
prompt=prompt,
|
||||
model_config=model_config,
|
||||
request_type="story.generate",
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(f'成功生成图片提示词: {image_prompt}')
|
||||
return image_prompt
|
||||
else:
|
||||
logger.error('生成图片提示词失败')
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成图片提示词时发生异常: {e}")
|
||||
return ""
|
||||
|
||||
async def _call_siliconflow_api(self, api_key: str, image_prompt: str, image_dir: str, batch_size: int) -> bool:
|
||||
"""
|
||||
调用硅基流动(SiliconFlow)的API来生成图片。
|
||||
|
||||
:param api_key: SiliconFlow API密钥。
|
||||
:param story: 用于生成图片的文本内容(说说)。
|
||||
:param image_prompt: 用于生成图片的提示词。
|
||||
:param image_dir: 图片保存目录。
|
||||
:param batch_size: 生成图片的数量。
|
||||
:param batch_size: 生成图片的数量(1-4)。
|
||||
:return: API调用是否成功。
|
||||
"""
|
||||
url = "https://api.siliconflow.cn/v1/images/generations"
|
||||
headers = {
|
||||
"accept": "application/json",
|
||||
"authorization": f"Bearer {api_key}",
|
||||
"content-type": "application/json",
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
payload = {"prompt": story, "n": batch_size, "response_format": "b64_json", "style": "cinematic-default"}
|
||||
data = {
|
||||
"model": "Kwai-Kolors/Kolors",
|
||||
"prompt": image_prompt,
|
||||
"negative_prompt": "lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality, "
|
||||
"normal quality, jpeg artifacts, signature, watermark, username, blurry",
|
||||
"seed": random.randint(1, 9999999999),
|
||||
"batch_size": batch_size,
|
||||
}
|
||||
|
||||
# 检查是否启用参考图片
|
||||
enable_ref = bool(self.get_config("models.image_ref", True))
|
||||
if enable_ref:
|
||||
# 修复:使用Path对象正确获取父目录
|
||||
parent_dir = Path(image_dir).parent
|
||||
ref_images = list(parent_dir.glob("done_ref.*"))
|
||||
if ref_images:
|
||||
try:
|
||||
image = Image.open(ref_images[0])
|
||||
encoded_image = self._encode_image_to_base64(image)
|
||||
if encoded_image: # 只有在编码成功时才添加
|
||||
data["image"] = encoded_image
|
||||
logger.info("已添加参考图片到生成参数")
|
||||
except Exception as e:
|
||||
logger.warning(f"加载参考图片失败: {e}")
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=payload, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
for i, img_data in enumerate(data.get("data", [])):
|
||||
b64_json = img_data.get("b64_json")
|
||||
if b64_json:
|
||||
image_bytes = base64.b64decode(b64_json)
|
||||
file_path = Path(image_dir) / f"image_{i + 1}.png"
|
||||
async with aiofiles.open(file_path, "wb") as f:
|
||||
await f.write(image_bytes)
|
||||
logger.info(f"成功保存AI图片到: {file_path}")
|
||||
return True
|
||||
else:
|
||||
# 发送生成请求
|
||||
async with session.post(url, json=data, headers=headers) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"AI生图API请求失败,状态码: {response.status}, 错误信息: {error_text}")
|
||||
logger.error(f'生成图片出错,错误码[{response.status}]')
|
||||
logger.error(f'错误响应: {error_text}')
|
||||
return False
|
||||
|
||||
json_data = await response.json()
|
||||
image_urls = [img["url"] for img in json_data["images"]]
|
||||
|
||||
success_count = 0
|
||||
# 下载并保存图片
|
||||
for i, img_url in enumerate(image_urls):
|
||||
try:
|
||||
# 下载图片
|
||||
async with session.get(img_url) as img_response:
|
||||
img_response.raise_for_status()
|
||||
img_data = await img_response.read()
|
||||
|
||||
# 处理图片
|
||||
try:
|
||||
image = Image.open(BytesIO(img_data))
|
||||
|
||||
# 保存图片为PNG格式(确保兼容性)
|
||||
filename = f"image_{i}.png"
|
||||
save_path = Path(image_dir) / filename
|
||||
|
||||
# 转换为RGB模式如果必要(避免RGBA等模式的问题)
|
||||
if image.mode in ('RGBA', 'LA', 'P'):
|
||||
background = Image.new('RGB', image.size, (255, 255, 255))
|
||||
background.paste(image, mask=image.split()[-1] if image.mode == 'RGBA' else None)
|
||||
image = background
|
||||
|
||||
image.save(save_path, format='PNG')
|
||||
logger.info(f"图片已保存至: {save_path}")
|
||||
success_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {str(e)}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载第{i+1}张图片失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 只要至少有一张图片成功就返回True
|
||||
return success_count > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用AI生图API时发生异常: {e}")
|
||||
return False
|
||||
|
||||
def _encode_image_to_base64(self, img: Image.Image) -> str:
|
||||
"""
|
||||
将PIL.Image对象编码为base64 data URL
|
||||
|
||||
:param img: PIL图片对象
|
||||
:return: base64 data URL字符串,失败时返回空字符串
|
||||
"""
|
||||
try:
|
||||
# 强制转换为PNG格式,因为SiliconFlow API要求data:image/png
|
||||
buffer = BytesIO()
|
||||
|
||||
# 转换为RGB模式如果必要
|
||||
if img.mode in ('RGBA', 'LA', 'P'):
|
||||
background = Image.new('RGB', img.size, (255, 255, 255))
|
||||
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
|
||||
img = background
|
||||
|
||||
# 保存为PNG格式
|
||||
img.save(buffer, format="PNG")
|
||||
byte_data = buffer.getvalue()
|
||||
|
||||
# Base64编码,使用固定的data:image/png
|
||||
encoded_string = base64.b64encode(byte_data).decode("utf-8")
|
||||
return f"data:image/png;base64,{encoded_string}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"编码图片为base64失败: {e}")
|
||||
return ""
|
||||
@@ -19,7 +19,8 @@ import json5
|
||||
import orjson
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import config_api, cross_context_api, person_api
|
||||
from src.plugin_system.apis import config_api, person_api
|
||||
from src.plugin_system.apis import cross_context_api
|
||||
|
||||
from .content_service import ContentService
|
||||
from .cookie_service import CookieService
|
||||
@@ -60,13 +61,32 @@ class QZoneService:
|
||||
self.processing_comments = set()
|
||||
|
||||
# --- Public Methods (High-Level Business Logic) ---
|
||||
async def _get_cross_context(self) -> str:
|
||||
"""获取并构建跨群聊上下文"""
|
||||
context = ""
|
||||
user_id = self.get_config("cross_context.user_id")
|
||||
|
||||
if user_id:
|
||||
logger.info(f"检测到互通组用户ID: {user_id},准备获取上下文...")
|
||||
try:
|
||||
context = await cross_context_api.build_cross_context_for_user(
|
||||
user_id=user_id,
|
||||
platform="QQ", # 硬编码为QQ
|
||||
limit_per_stream=10,
|
||||
stream_limit=3,
|
||||
)
|
||||
if context:
|
||||
logger.info("成功获取到互通组上下文。")
|
||||
else:
|
||||
logger.info("未获取到有效的互通组上下文。")
|
||||
except Exception as e:
|
||||
logger.error(f"获取互通组上下文时发生异常: {e}")
|
||||
return context
|
||||
|
||||
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
|
||||
"""发送一条说说"""
|
||||
# --- 获取互通组上下文 ---
|
||||
context = await self._get_intercom_context(stream_id) if stream_id else None
|
||||
|
||||
story = await self.content_service.generate_story(topic, context=context)
|
||||
cross_context = await self._get_cross_context()
|
||||
story = await self.content_service.generate_story(topic, context=cross_context)
|
||||
if not story:
|
||||
return {"success": False, "message": "生成说说内容失败"}
|
||||
|
||||
@@ -91,7 +111,8 @@ class QZoneService:
|
||||
|
||||
async def send_feed_from_activity(self, activity: str) -> dict[str, Any]:
|
||||
"""根据日程活动发送一条说说"""
|
||||
story = await self.content_service.generate_story_from_activity(activity)
|
||||
cross_context = await self._get_cross_context()
|
||||
story = await self.content_service.generate_story_from_activity(activity, context=cross_context)
|
||||
if not story:
|
||||
return {"success": False, "message": "根据活动生成说说内容失败"}
|
||||
|
||||
@@ -302,12 +323,6 @@ class QZoneService:
|
||||
|
||||
# --- Internal Helper Methods ---
|
||||
|
||||
async def _get_intercom_context(self, stream_id: str) -> str | None:
|
||||
"""
|
||||
获取互通组的聊天上下文。
|
||||
"""
|
||||
# 实际的逻辑已迁移到 cross_context_api
|
||||
return await cross_context_api.get_intercom_group_context("maizone_context_group")
|
||||
|
||||
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
|
||||
"""处理对自己说说的评论并进行回复"""
|
||||
|
||||
Reference in New Issue
Block a user