This commit is contained in:
tt-P607
2025-11-29 02:06:33 +08:00
7 changed files with 352 additions and 289 deletions

View File

@@ -702,28 +702,6 @@ class WebSearchConfig(ValidatedConfigBase):
search_strategy: Literal["fallback", "single", "parallel"] = Field(default="single", description="搜索策略")
class ContextGroup(ValidatedConfigBase):
"""
上下文共享组配置
定义了一个聊天上下文的共享范围和规则。
"""
name: str = Field(..., description="共享组的名称,用于唯一标识一个共享组")
mode: Literal["whitelist", "blacklist"] = Field(
default="whitelist",
description="共享模式。'whitelist'表示仅共享chat_ids中列出的聊天'blacklist'表示共享除chat_ids中列出的所有聊天。",
)
default_limit: int = Field(
default=5,
description="'blacklist'模式下,对于未明确指定数量的聊天,默认获取的消息条数。",
)
chat_ids: list[list[str]] = Field(
...,
description='定义组内成员的列表。格式为 [["type", "id", "limit"(可选)]]。type为"group""private"id为群号或用户IDlimit为可选的消息条数。',
)
class MaizoneContextGroup(ValidatedConfigBase):
"""QQ空间专用互通组配置"""
@@ -739,8 +717,6 @@ class CrossContextConfig(ValidatedConfigBase):
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
# --- Normal模式: 共享组配置 ---
groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
# --- S4U模式: 用户中心上下文检索 ---
s4u_mode: Literal["whitelist", "blacklist"] = Field(
default="whitelist",

View File

@@ -392,7 +392,7 @@ MoFox_Bot(第三方修改版)
全部组件已成功启动!
=========================================================
🌐 项目地址: https://github.com/MoFox-Studio/MoFox-Core
🏠 官方项目: https://github.com/MaiM-with-u/MaiBot
🏠 官方项目: https://github.com/Mai-with-u/MaiBot
=========================================================
这是基于原版MMC的社区改版包含增强功能和优化(同时也有更多的'特性')
=========================================================

View File

@@ -4,6 +4,7 @@
import time
from typing import Any, TYPE_CHECKING
from src.common.message_repository import find_messages
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
@@ -13,7 +14,6 @@ from src.chat.utils.chat_message_builder import (
from src.common.logger import get_logger
from src.common.message_repository import get_user_messages_from_streams
from src.config.config import global_config
from src.config.official_configs import ContextGroup
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
@@ -21,111 +21,6 @@ if TYPE_CHECKING:
logger = get_logger("cross_context_api")
async def get_context_group(chat_id: str) -> ContextGroup | None:
"""
获取当前聊天所在的共享组
"""
current_stream = await get_chat_manager().get_stream(chat_id)
if not current_stream:
return None
is_group = current_stream.group_info is not None
if not is_group and not current_stream.user_info:
return None
if is_group:
assert current_stream.group_info is not None
current_chat_raw_id = current_stream.group_info.group_id
elif current_stream.user_info:
current_chat_raw_id = current_stream.user_info.user_id
else:
return None
current_type = "group" if is_group else "private"
for group in global_config.cross_context.groups:
for chat_info in group.chat_ids:
if len(chat_info) >= 2:
chat_type, chat_raw_id = chat_info[0], chat_info[1]
if chat_type == current_type and str(chat_raw_id) == str(current_chat_raw_id):
# 排除maizone专用组
if group.name == "maizone_context_group":
continue
return group
return None
async def build_cross_context_normal(chat_stream: "ChatStream", context_group: ContextGroup) -> str:
"""
构建跨群聊/私聊上下文 (Normal模式)。
根据共享组的配置(白名单或黑名单模式),获取相关聊天的近期消息,并格式化为字符串。
Args:
chat_stream: 当前的聊天流对象。
context_group: 当前聊天所在的上下文共享组配置。
Returns:
一个包含格式化后的跨上下文消息的字符串,如果无消息则为空字符串。
"""
cross_context_messages = []
chat_manager = get_chat_manager()
chat_infos_to_fetch = []
if context_group.mode == "blacklist":
# 黑名单模式:获取所有聊天,并排除在 chat_ids 中定义过的聊天
blacklisted_ids = {tuple(info[:2]) for info in context_group.chat_ids}
for stream_id, stream in chat_manager.streams.items():
is_group = stream.group_info is not None
chat_type = "group" if is_group else "private"
# 安全地获取 raw_id
if is_group and stream.group_info:
raw_id = stream.group_info.group_id
elif not is_group and stream.user_info:
raw_id = stream.user_info.user_id
else:
continue # 如果缺少关键信息则跳过
# 如果当前聊天不在黑名单中,则添加到待获取列表
if (chat_type, str(raw_id)) not in blacklisted_ids:
chat_infos_to_fetch.append([chat_type, str(raw_id), str(context_group.default_limit)])
else:
# 白名单模式:直接使用配置中定义的 chat_ids
chat_infos_to_fetch = context_group.chat_ids
# 遍历待获取列表,抓取并格式化消息
for chat_info in chat_infos_to_fetch:
chat_type, chat_raw_id, limit_str = (
chat_info[0],
chat_info[1],
chat_info[2] if len(chat_info) > 2 else str(context_group.default_limit),
)
limit = int(limit_str)
is_group = chat_type == "group"
stream_id = chat_manager.get_stream_id(chat_stream.platform, chat_raw_id, is_group=is_group)
if not stream_id or stream_id == chat_stream.stream_id:
continue
try:
messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=limit,
)
if messages:
chat_name = await chat_manager.get_stream_name(stream_id) or chat_raw_id
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
continue
if not cross_context_messages:
return ""
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"
async def build_cross_context_s4u(
chat_stream: "ChatStream",
target_user_info: dict[str, Any] | None,
@@ -134,32 +29,55 @@ async def build_cross_context_s4u(
构建跨群聊/私聊上下文 (S4U模式)。
优先展示目标用户的私聊记录(双向),其次按时间顺序展示其他群聊记录。
"""
# 记录S4U上下文构建开始
logger.debug("[S4U] Starting S4U context build.")
# 检查全局配置是否存在且包含必要部分
if not global_config or not global_config.cross_context or not global_config.bot:
logger.error("全局配置尚未初始化或缺少关键配置无法构建S4U上下文。")
return ""
# 获取跨上下文配置
cross_context_config = global_config.cross_context
# 检查目标用户信息和用户ID是否存在
if not target_user_info or not (user_id := target_user_info.get("user_id")):
logger.warning(f"[S4U] Failed: target_user_info ({target_user_info}) or user_id is missing.")
return ""
# 记录目标用户ID
logger.debug(f"[S4U] Target user ID: {user_id}")
# 获取聊天管理器实例
chat_manager = get_chat_manager()
private_context_block = ""
group_context_blocks = []
# --- 1. 优先处理私聊上下文 ---
# 获取与目标用户的私聊流ID
private_stream_id = chat_manager.get_stream_id(chat_stream.platform, user_id, is_group=False)
# 如果存在私聊流且不是当前聊天流
if private_stream_id and private_stream_id != chat_stream.stream_id:
logger.debug(f"[S4U] Found private chat with target user: {private_stream_id}")
try:
# 定义需要获取消息的用户ID列表目标用户和机器人自己
user_ids_to_fetch = [str(user_id), str(global_config.bot.qq_account)]
# 从指定私聊流中获取双方的消息
messages_by_stream = await get_user_messages_from_streams(
user_ids=user_ids_to_fetch,
stream_ids=[private_stream_id],
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 3天
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 最近3天的消息
limit_per_stream=cross_context_config.s4u_limit,
)
# 如果获取到了私聊消息
if private_messages := messages_by_stream.get(private_stream_id):
chat_name = await chat_manager.get_stream_name(private_stream_id) or "私聊"
title = f'[以下是您与"{chat_name}"的近期私聊记录]\n'
# 格式化消息为可读字符串
formatted, _ = await build_readable_messages_with_id(private_messages, timestamp_mode="relative")
private_context_block = f"{title}{formatted}"
logger.debug(f"[S4U] Generated private context block of length {len(private_context_block)}.")
@@ -168,18 +86,23 @@ async def build_cross_context_s4u(
# --- 2. 处理其他群聊上下文 ---
streams_to_scan = []
# 根据全局S4U配置确定要扫描的聊天范围
# 根据S4U配置模式白名单/黑名单)确定要扫描的聊天范围
if cross_context_config.s4u_mode == "whitelist":
# 白名单模式:只扫描在白名单中的聊天
for chat_str in cross_context_config.s4u_whitelist_chats:
try:
platform, chat_type, chat_raw_id = chat_str.split(":")
is_group = chat_type == "group"
stream_id = chat_manager.get_stream_id(platform, chat_raw_id, is_group=is_group)
# 排除当前聊和私聊
if stream_id and stream_id != chat_stream.stream_id and stream_id != private_stream_id:
streams_to_scan.append(stream_id)
except ValueError:
logger.warning(f"无效的S4U白名单格式: {chat_str}")
else: # blacklist mode
else: # 黑名单模式
# 黑名单模式:扫描所有聊天,除了黑名单中的和私聊
blacklisted_streams = {private_stream_id}
for chat_str in cross_context_config.s4u_blacklist_chats:
try:
@@ -190,6 +113,8 @@ async def build_cross_context_s4u(
blacklisted_streams.add(stream_id)
except ValueError:
logger.warning(f"无效的S4U黑名单格式: {chat_str}")
# 将不在黑名单中的流添加到扫描列表
streams_to_scan.extend(
stream_id for stream_id in chat_manager.streams
if stream_id != chat_stream.stream_id and stream_id not in blacklisted_streams
@@ -197,12 +122,113 @@ async def build_cross_context_s4u(
logger.debug(f"[S4U] Found {len(streams_to_scan)} group streams to scan.")
if streams_to_scan:
# 获取目标用户在这些群聊中的消息
messages_by_stream = await get_user_messages_from_streams(
user_ids=[str(user_id)],
stream_ids=streams_to_scan,
timestamp_after=time.time() - (3 * 24 * 60 * 60), # 最近3天
limit_per_stream=cross_context_config.s4u_limit,
)
all_group_messages = []
# 将所有群聊消息聚合,并附带最新时间戳
for stream_id, user_messages in messages_by_stream.items():
if user_messages:
latest_timestamp = max(msg.get("time", 0) for msg in user_messages)
all_group_messages.append(
{"stream_id": stream_id, "messages": user_messages, "latest_timestamp": latest_timestamp}
)
# 按最新消息时间倒序排序
all_group_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
# 计算群聊上下文的额度
remaining_limit = cross_context_config.s4u_stream_limit - (1 if private_context_block else 0)
limited_group_messages = all_group_messages[:remaining_limit]
# 格式化每个群聊的消息
for item in limited_group_messages:
try:
chat_name = await chat_manager.get_stream_name(item["stream_id"]) or "未知群聊"
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
title = f'[以下是"{user_name}""{chat_name}"的近期发言]\n'
formatted, _ = await build_readable_messages_with_id(item["messages"], timestamp_mode="relative")
group_context_blocks.append(f"{title}{formatted}")
except Exception as e:
logger.error(f"S4U模式下格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
# --- 3. 组合最终上下文 ---
# 如果没有任何上下文内容,则返回空
if not private_context_block and not group_context_blocks:
logger.debug("[S4U] No context blocks were generated. Returning empty string.")
return ""
final_context_parts = []
# 添加私聊部分
if private_context_block:
final_context_parts.append(private_context_block)
# 添加群聊部分
if group_context_blocks:
group_context_str = "\n\n".join(group_context_blocks)
final_context_parts.append(f"### 其他群聊中的聊天记录\n{group_context_str}")
# 组合最终的上下文字符串
final_context = "\n\n".join(final_context_parts) + "\n"
logger.debug(f"[S4U] Successfully generated S4U context. Total length: {len(final_context)}.")
return final_context
async def build_cross_context_for_user(
user_id: str,
platform: str,
limit_per_stream: int,
stream_limit: int,
) -> str:
"""
构建指定用户的跨群聊/私聊上下文简化版API
"""
logger.debug(f"[S4U_SIMPLE] Starting simplified S4U context build for user {user_id} on {platform}.")
if not global_config or not global_config.cross_context or not global_config.bot:
logger.error("全局配置尚未初始化或缺少关键配置无法构建S4U上下文。")
return ""
chat_manager = get_chat_manager()
private_context_block = ""
group_context_blocks = []
# --- 1. 优先处理私聊上下文 ---
private_stream_id = chat_manager.get_stream_id(platform, user_id, is_group=False)
if private_stream_id:
try:
user_ids_to_fetch = [str(user_id), str(global_config.bot.qq_account)]
messages_by_stream = await get_user_messages_from_streams(
user_ids=user_ids_to_fetch,
stream_ids=[private_stream_id],
timestamp_after=time.time() - (3 * 24 * 60 * 60),
limit_per_stream=limit_per_stream,
)
if private_messages := messages_by_stream.get(private_stream_id):
chat_name = await chat_manager.get_stream_name(private_stream_id) or "私聊"
title = f'[以下是您与"{chat_name}"的近期私聊记录]\n'
formatted, _ = await build_readable_messages_with_id(private_messages, timestamp_mode="relative")
private_context_block = f"{title}{formatted}"
except Exception as e:
logger.error(f"[S4U_SIMPLE] 处理私聊记录失败: {e}")
# --- 2. 处理其他群聊上下文 ---
streams_to_scan = [
stream_id for stream_id in chat_manager.streams
if stream_id != private_stream_id
]
if streams_to_scan:
messages_by_stream = await get_user_messages_from_streams(
user_ids=[str(user_id)],
stream_ids=streams_to_scan,
timestamp_after=time.time() - (3 * 24 * 60 * 60),
limit_per_stream=cross_context_config.s4u_limit,
limit_per_stream=limit_per_stream,
)
all_group_messages = []
@@ -215,23 +241,21 @@ async def build_cross_context_s4u(
all_group_messages.sort(key=lambda x: x["latest_timestamp"], reverse=True)
# 计算群聊的额度
remaining_limit = cross_context_config.s4u_stream_limit - (1 if private_context_block else 0)
remaining_limit = stream_limit - (1 if private_context_block else 0)
limited_group_messages = all_group_messages[:remaining_limit]
for item in limited_group_messages:
try:
chat_name = await chat_manager.get_stream_name(item["stream_id"]) or "未知群聊"
user_name = target_user_info.get("person_name") or target_user_info.get("user_nickname") or user_id
user_name = user_id # 简化处理
title = f'[以下是"{user_name}""{chat_name}"的近期发言]\n'
formatted, _ = await build_readable_messages_with_id(item["messages"], timestamp_mode="relative")
group_context_blocks.append(f"{title}{formatted}")
except Exception as e:
logger.error(f"S4U模式下格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
logger.error(f"[S4U_SIMPLE] 格式化群聊消息失败 (stream: {item['stream_id']}): {e}")
# --- 3. 组合最终上下文 ---
if not private_context_block and not group_context_blocks:
logger.debug("[S4U] No context blocks were generated. Returning empty string.")
return ""
final_context_parts = []
@@ -242,116 +266,5 @@ async def build_cross_context_s4u(
final_context_parts.append(f"### 其他群聊中的聊天记录\n{group_context_str}")
final_context = "\n\n".join(final_context_parts) + "\n"
logger.debug(f"[S4U] Successfully generated S4U context. Total length: {len(final_context)}.")
logger.debug(f"[S4U_SIMPLE] Successfully generated context for user {user_id}. Total length: {len(final_context)}.")
return final_context
async def get_intercom_group_context(group_name: str, limit_per_chat: int = 20, total_limit: int = 100) -> str | None:
"""
根据互通组的名称,构建该组的聊天上下文。
支持黑白名单模式,并以分块形式返回每个聊天的消息。
Args:
group_name: 互通组的名称。
limit_per_chat: 每个聊天最多获取的消息条数。
total_limit: 返回的总消息条数上限。
Returns:
如果找到匹配的组并获取到消息,则返回一个包含聊天记录的字符串;否则返回 None。
"""
cross_context_config = global_config.cross_context
if not (cross_context_config and cross_context_config.enable):
return None
target_group = next((g for g in cross_context_config.groups if g.name == group_name), None)
if not target_group:
logger.error(f"在 cross_context 配置中未找到名为 '{group_name}' 的组。")
return None
chat_manager = get_chat_manager()
# 1. 根据黑白名单模式确定要处理的聊天列表
chat_infos_to_fetch = []
if target_group.mode == "blacklist":
blacklisted_ids = {tuple(info[:2]) for info in target_group.chat_ids}
for stream in chat_manager.streams.values():
is_group = stream.group_info is not None
chat_type = "group" if is_group else "private"
if is_group and stream.group_info:
raw_id = stream.group_info.group_id
elif not is_group and stream.user_info:
raw_id = stream.user_info.user_id
else:
continue
if (chat_type, str(raw_id)) not in blacklisted_ids:
chat_infos_to_fetch.append([chat_type, str(raw_id)])
else: # whitelist mode
chat_infos_to_fetch = target_group.chat_ids
# 2. 获取所有相关消息
all_messages = []
for chat_info in chat_infos_to_fetch:
chat_type, chat_raw_id = chat_info[0], chat_info[1]
is_group = chat_type == "group"
# 查找 stream
found_stream = None
for stream in chat_manager.streams.values():
if is_group:
if stream.group_info and stream.group_info.group_id == chat_raw_id:
found_stream = stream
break
else: # private
if stream.user_info and stream.user_info.user_id == chat_raw_id and not stream.group_info:
found_stream = stream
break
if not found_stream:
logger.warning(f"在已加载的聊天流中找不到ID为 {chat_raw_id} 的聊天。")
continue
stream_id = found_stream.stream_id
try:
messages = await get_raw_msg_before_timestamp_with_chat(
chat_id=stream_id,
timestamp=time.time(),
limit=limit_per_chat,
)
if messages:
# 为每条消息附加 stream_id 以便后续分组
for msg in messages:
msg["_stream_id"] = stream_id
all_messages.extend(messages)
except Exception as e:
logger.error(f"获取聊天 {chat_raw_id} 的消息失败: {e}")
if not all_messages:
return None
# 3. 应用总数限制
all_messages.sort(key=lambda x: x.get("time", 0))
if len(all_messages) > total_limit:
all_messages = all_messages[-total_limit:]
# 4. 按聊天分组并格式化
messages_by_stream = {}
for msg in all_messages:
stream_id = msg.get("_stream_id")
if stream_id not in messages_by_stream:
messages_by_stream[stream_id] = []
messages_by_stream[stream_id].append(msg)
cross_context_messages = []
for stream_id, messages in messages_by_stream.items():
if messages:
chat_name = await chat_manager.get_stream_name(stream_id) or "未知聊天"
formatted_messages, _ = await build_readable_messages_with_id(messages, timestamp_mode="relative")
cross_context_messages.append(f'[以下是来自"{chat_name}"的近期消息]\n{formatted_messages}')
if not cross_context_messages:
return None
return "# 跨上下文参考\n" + "\n\n".join(cross_context_messages) + "\n"

View File

@@ -51,7 +51,7 @@ class MaiZoneRefactoredPlugin(BasePlugin):
"enable_image": ConfigField(type=bool, default=False, description="是否启用说说配图"),
"enable_ai_image": ConfigField(type=bool, default=False, description="是否启用AI生成配图"),
"enable_reply": ConfigField(type=bool, default=True, description="完成后是否回复"),
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量"),
"ai_image_number": ConfigField(type=int, default=1, description="AI生成图片数量1-4张"),
"image_number": ConfigField(type=int, default=1, description="本地配图数量1-9张"),
"image_directory": ConfigField(
type=str, default=(Path(__file__).parent / "images").as_posix(), description="图片存储目录"
@@ -83,6 +83,9 @@ class MaiZoneRefactoredPlugin(BasePlugin):
"http_fallback_port": ConfigField(type=int, default=9999, description="备用Cookie获取服务的端口"),
"napcat_token": ConfigField(type=str, default="", description="Napcat服务的认证Token可选"),
},
"cross_context": {
"user_id": ConfigField(type=str, default="", description="用于获取互通上下文的目标用户QQ号"),
},
}
permission_nodes: list[PermissionNodeField] = [

View File

@@ -274,7 +274,7 @@ class ContentService:
await asyncio.sleep(2)
return None
async def generate_story_from_activity(self, activity: str) -> str:
async def generate_story_from_activity(self, activity: str, context: str | None = None) -> str:
"""
根据当前的日程活动生成一条QQ空间说说。
@@ -350,6 +350,9 @@ class ContentService:
- 鼓励你多描述日常生活相关的生产活动和消遣,展现真实,而不是浮在空中。
"""
# 如果有上下文则加入到prompt中
if context:
prompt += f"\n作为参考,这里有一些最近的聊天记录:\n---\n{context}\n---"
# 添加历史记录避免重复
prompt += "\n\n---历史说说记录---\n"
history_block = await get_send_history(qq_account)

View File

@@ -4,13 +4,17 @@
"""
import base64
import random
from collections.abc import Callable
from pathlib import Path
from io import BytesIO
from PIL import Image
import aiofiles
import aiohttp
from src.common.logger import get_logger
from src.plugin_system.apis import llm_api, config_api
logger = get_logger("MaiZone.ImageService")
@@ -40,7 +44,14 @@ class ImageService:
api_key = str(self.get_config("models.siliconflow_apikey", ""))
image_dir = str(self.get_config("send.image_directory", "./data/plugins/maizone_refactored/images"))
image_num_raw = self.get_config("send.ai_image_number", 1)
image_num = int(image_num_raw if image_num_raw is not None else 1)
# 安全地处理图片数量配置并限制在API允许的范围内
try:
image_num = int(image_num_raw) if image_num_raw not in [None, ""] else 1
image_num = max(1, min(image_num, 4)) # SiliconFlow API限制1 <= batch_size <= 4
except (ValueError, TypeError):
logger.warning(f"无效的图片数量配置: {image_num_raw}使用默认值1")
image_num = 1
if not enable_ai_image:
return True # 未启用AI配图视为成功
@@ -52,49 +63,191 @@ class ImageService:
# 确保图片目录存在
Path(image_dir).mkdir(parents=True, exist_ok=True)
# 生成图片提示词
image_prompt = await self._generate_image_prompt(story)
if not image_prompt:
logger.error("生成图片提示词失败")
return False
logger.info(f"正在为说说生成 {image_num} 张AI配图...")
return await self._call_siliconflow_api(api_key, story, image_dir, image_num)
return await self._call_siliconflow_api(api_key, image_prompt, image_dir, image_num)
except Exception as e:
logger.error(f"处理AI配图时发生异常: {e}")
return False
async def _call_siliconflow_api(self, api_key: str, story: str, image_dir: str, batch_size: int) -> bool:
async def _generate_image_prompt(self, story_content: str) -> str:
"""
使用LLM生成图片提示词基于说说内容。
:param story_content: 说说内容
:return: 生成的图片提示词,失败时返回空字符串
"""
try:
# 获取配置
identity = config_api.get_global_config("personality.identity", "年龄为19岁,是女孩子,身高为160cm,黑色短发")
enable_ref = bool(self.get_config("models.image_ref", True))
# 构建提示词
prompt = f"""
请根据以下QQ空间说说内容配图并构建生成配图的风格和prompt。
说说主人信息:'{identity}'
说说内容:'{story_content}'
请注意仅回复用于生成图片的prompt不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。
"""
if enable_ref:
prompt += "说说主人的人设参考图片将随同提示词一起发送给生图AI可使用'in the style of''根据图中人物'等描述引导生成风格"
# 获取模型配置
models = llm_api.get_available_models()
prompt_model = self.get_config("models.text_model", "replyer")
model_config = models.get(prompt_model)
if not model_config:
logger.error(f"找不到模型配置: {prompt_model}")
return ""
# 调用LLM生成提示词
logger.info("正在生成图片提示词...")
success, image_prompt, reasoning, model_name = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="story.generate",
temperature=0.3,
max_tokens=1000
)
if success:
logger.info(f'成功生成图片提示词: {image_prompt}')
return image_prompt
else:
logger.error('生成图片提示词失败')
return ""
except Exception as e:
logger.error(f"生成图片提示词时发生异常: {e}")
return ""
async def _call_siliconflow_api(self, api_key: str, image_prompt: str, image_dir: str, batch_size: int) -> bool:
"""
调用硅基流动SiliconFlow的API来生成图片。
:param api_key: SiliconFlow API密钥。
:param story: 用于生成图片的文本内容(说说)
:param image_prompt: 用于生成图片的提示词
:param image_dir: 图片保存目录。
:param batch_size: 生成图片的数量。
:param batch_size: 生成图片的数量1-4
:return: API调用是否成功。
"""
url = "https://api.siliconflow.cn/v1/images/generations"
headers = {
"accept": "application/json",
"authorization": f"Bearer {api_key}",
"content-type": "application/json",
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
payload = {"prompt": story, "n": batch_size, "response_format": "b64_json", "style": "cinematic-default"}
data = {
"model": "Kwai-Kolors/Kolors",
"prompt": image_prompt,
"negative_prompt": "lowres, bad anatomy, bad hands, text, error, cropped, worst quality, low quality, "
"normal quality, jpeg artifacts, signature, watermark, username, blurry",
"seed": random.randint(1, 9999999999),
"batch_size": batch_size,
}
# 检查是否启用参考图片
enable_ref = bool(self.get_config("models.image_ref", True))
if enable_ref:
# 修复使用Path对象正确获取父目录
parent_dir = Path(image_dir).parent
ref_images = list(parent_dir.glob("done_ref.*"))
if ref_images:
try:
image = Image.open(ref_images[0])
encoded_image = self._encode_image_to_base64(image)
if encoded_image: # 只有在编码成功时才添加
data["image"] = encoded_image
logger.info("已添加参考图片到生成参数")
except Exception as e:
logger.warning(f"加载参考图片失败: {e}")
try:
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload, headers=headers) as response:
if response.status == 200:
data = await response.json()
for i, img_data in enumerate(data.get("data", [])):
b64_json = img_data.get("b64_json")
if b64_json:
image_bytes = base64.b64decode(b64_json)
file_path = Path(image_dir) / f"image_{i + 1}.png"
async with aiofiles.open(file_path, "wb") as f:
await f.write(image_bytes)
logger.info(f"成功保存AI图片到: {file_path}")
return True
else:
# 发送生成请求
async with session.post(url, json=data, headers=headers) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"AI生图API请求失败状态码: {response.status}, 错误信息: {error_text}")
logger.error(f'生成图片出错,错误码[{response.status}]')
logger.error(f'错误响应: {error_text}')
return False
json_data = await response.json()
image_urls = [img["url"] for img in json_data["images"]]
success_count = 0
# 下载并保存图片
for i, img_url in enumerate(image_urls):
try:
# 下载图片
async with session.get(img_url) as img_response:
img_response.raise_for_status()
img_data = await img_response.read()
# 处理图片
try:
image = Image.open(BytesIO(img_data))
# 保存图片为PNG格式确保兼容性
filename = f"image_{i}.png"
save_path = Path(image_dir) / filename
# 转换为RGB模式如果必要避免RGBA等模式的问题
if image.mode in ('RGBA', 'LA', 'P'):
background = Image.new('RGB', image.size, (255, 255, 255))
background.paste(image, mask=image.split()[-1] if image.mode == 'RGBA' else None)
image = background
image.save(save_path, format='PNG')
logger.info(f"图片已保存至: {save_path}")
success_count += 1
except Exception as e:
logger.error(f"处理图片失败: {str(e)}")
continue
except Exception as e:
logger.error(f"下载第{i+1}张图片失败: {str(e)}")
continue
# 只要至少有一张图片成功就返回True
return success_count > 0
except Exception as e:
logger.error(f"调用AI生图API时发生异常: {e}")
return False
def _encode_image_to_base64(self, img: Image.Image) -> str:
"""
将PIL.Image对象编码为base64 data URL
:param img: PIL图片对象
:return: base64 data URL字符串失败时返回空字符串
"""
try:
# 强制转换为PNG格式因为SiliconFlow API要求data:image/png
buffer = BytesIO()
# 转换为RGB模式如果必要
if img.mode in ('RGBA', 'LA', 'P'):
background = Image.new('RGB', img.size, (255, 255, 255))
background.paste(img, mask=img.split()[-1] if img.mode == 'RGBA' else None)
img = background
# 保存为PNG格式
img.save(buffer, format="PNG")
byte_data = buffer.getvalue()
# Base64编码使用固定的data:image/png
encoded_string = base64.b64encode(byte_data).decode("utf-8")
return f"data:image/png;base64,{encoded_string}"
except Exception as e:
logger.error(f"编码图片为base64失败: {e}")
return ""

View File

@@ -19,7 +19,8 @@ import json5
import orjson
from src.common.logger import get_logger
from src.plugin_system.apis import config_api, cross_context_api, person_api
from src.plugin_system.apis import config_api, person_api
from src.plugin_system.apis import cross_context_api
from .content_service import ContentService
from .cookie_service import CookieService
@@ -60,13 +61,32 @@ class QZoneService:
self.processing_comments = set()
# --- Public Methods (High-Level Business Logic) ---
async def _get_cross_context(self) -> str:
"""获取并构建跨群聊上下文"""
context = ""
user_id = self.get_config("cross_context.user_id")
if user_id:
logger.info(f"检测到互通组用户ID: {user_id},准备获取上下文...")
try:
context = await cross_context_api.build_cross_context_for_user(
user_id=user_id,
platform="QQ", # 硬编码为QQ
limit_per_stream=10,
stream_limit=3,
)
if context:
logger.info("成功获取到互通组上下文。")
else:
logger.info("未获取到有效的互通组上下文。")
except Exception as e:
logger.error(f"获取互通组上下文时发生异常: {e}")
return context
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
"""发送一条说说"""
# --- 获取互通组上下文 ---
context = await self._get_intercom_context(stream_id) if stream_id else None
story = await self.content_service.generate_story(topic, context=context)
cross_context = await self._get_cross_context()
story = await self.content_service.generate_story(topic, context=cross_context)
if not story:
return {"success": False, "message": "生成说说内容失败"}
@@ -91,7 +111,8 @@ class QZoneService:
async def send_feed_from_activity(self, activity: str) -> dict[str, Any]:
"""根据日程活动发送一条说说"""
story = await self.content_service.generate_story_from_activity(activity)
cross_context = await self._get_cross_context()
story = await self.content_service.generate_story_from_activity(activity, context=cross_context)
if not story:
return {"success": False, "message": "根据活动生成说说内容失败"}
@@ -302,12 +323,6 @@ class QZoneService:
# --- Internal Helper Methods ---
async def _get_intercom_context(self, stream_id: str) -> str | None:
"""
获取互通组的聊天上下文。
"""
# 实际的逻辑已迁移到 cross_context_api
return await cross_context_api.get_intercom_group_context("maizone_context_group")
async def _reply_to_own_feed_comments(self, feed: dict, api_client: dict):
"""处理对自己说说的评论并进行回复"""