better:优化了图片展示形式
This commit is contained in:
@@ -137,7 +137,7 @@ class HeartFCMessageReceiver:
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message_data: Dict[str, Any]) -> None:
|
||||
async def process_message(self, message: MessageRecv) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
@@ -153,7 +153,6 @@ class HeartFCMessageReceiver:
|
||||
message = None
|
||||
try:
|
||||
# 1. 消息解析与初始化
|
||||
message = MessageRecv(message_data)
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
messageinfo = message.message_info
|
||||
@@ -166,7 +165,6 @@ class HeartFCMessageReceiver:
|
||||
|
||||
subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id)
|
||||
message.update_chat_stream(chat)
|
||||
await message.process()
|
||||
|
||||
# 3. 过滤检查
|
||||
if _check_ban_words(message.processed_plain_text, chat, userinfo) or _check_ban_regex(
|
||||
|
||||
@@ -48,8 +48,8 @@ def init_prompt():
|
||||
请以json格式输出,例如:
|
||||
|
||||
{{
|
||||
"用户A": "昵称",
|
||||
"用户B": "对你的态度",
|
||||
"用户A": "ta的昵称",
|
||||
"用户B": "ta对你的态度",
|
||||
"用户C": "你和ta最近做的事",
|
||||
"用户D": "你对ta的印象",
|
||||
}}
|
||||
|
||||
@@ -51,8 +51,6 @@ class ChatBot:
|
||||
async def _process_commands_with_new_system(self, message: MessageRecv):
|
||||
"""使用新插件系统处理命令"""
|
||||
try:
|
||||
if not message.processed_plain_text:
|
||||
await message.process()
|
||||
|
||||
text = message.processed_plain_text
|
||||
|
||||
@@ -179,11 +177,11 @@ class ChatBot:
|
||||
# 禁止PFC,进入普通的心流消息处理逻辑
|
||||
else:
|
||||
logger.debug("进入普通心流私聊处理")
|
||||
await self.heartflow_message_receiver.process_message(message_data)
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
# 群聊默认进入心流消息处理逻辑
|
||||
else:
|
||||
logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}")
|
||||
await self.heartflow_message_receiver.process_message(message_data)
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
if template_group_name:
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
|
||||
@@ -101,15 +101,11 @@ class MessageRecv(Message):
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
"""
|
||||
# print(f"message_dict: {message_dict}")
|
||||
self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
|
||||
|
||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
self.raw_message = message_dict.get("raw_message")
|
||||
|
||||
# 处理消息内容
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "") # 初始化为空字符串
|
||||
self.detailed_plain_text = message_dict.get("detailed_plain_text", "") # 初始化为空字符串
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||
self.detailed_plain_text = message_dict.get("detailed_plain_text", "")
|
||||
self.is_emoji = False
|
||||
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
@@ -123,33 +119,36 @@ class MessageRecv(Message):
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
self.detailed_plain_text = self._generate_detailed_text()
|
||||
|
||||
async def _process_single_segment(self, seg: Seg) -> str:
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
seg: 要处理的消息段
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if seg.type == "text":
|
||||
return seg.data
|
||||
elif seg.type == "image":
|
||||
if segment.type == "text":
|
||||
return segment.data
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await get_image_manager().get_image_description(seg.data)
|
||||
if isinstance(segment.data, str):
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif seg.type == "emoji":
|
||||
elif segment.type == "emoji":
|
||||
self.is_emoji = True
|
||||
if isinstance(seg.data, str):
|
||||
return await get_image_manager().get_emoji_description(seg.data)
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
return f"[{segment.type}:{str(segment.data)}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
return f"[处理失败的{seg.type}消息]"
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from src.person_info.person_info import PersonInfoManager, get_person_info_manag
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
from rich.traceback import install
|
||||
from src.common.database.database_model import ActionRecords
|
||||
from src.common.database.database_model import Images
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -157,7 +158,9 @@ def _build_readable_messages_internal(
|
||||
merge_messages: bool = False,
|
||||
timestamp_mode: str = "relative",
|
||||
truncate: bool = False,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]]]:
|
||||
pic_id_mapping: Dict[str, str] = None,
|
||||
pic_counter: int = 1,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
|
||||
@@ -167,15 +170,41 @@ def _build_readable_messages_internal(
|
||||
merge_messages: 是否合并来自同一用户的连续消息。
|
||||
timestamp_mode: 时间戳的显示模式 ('relative', 'absolute', etc.)。传递给 translate_timestamp_to_human_readable。
|
||||
truncate: 是否根据消息的新旧程度截断过长的消息内容。
|
||||
pic_id_mapping: 图片ID映射字典,如果为None则创建新的
|
||||
pic_counter: 图片计数器起始值
|
||||
|
||||
Returns:
|
||||
包含格式化消息的字符串和原始消息详情列表 (时间戳, 发送者名称, 内容) 的元组。
|
||||
包含格式化消息的字符串、原始消息详情列表、图片映射字典和更新后的计数器的元组。
|
||||
"""
|
||||
if not messages:
|
||||
return "", []
|
||||
return "", [], pic_id_mapping or {}, pic_counter
|
||||
|
||||
message_details_raw: List[Tuple[float, str, str]] = []
|
||||
|
||||
# 使用传入的映射字典,如果没有则创建新的
|
||||
if pic_id_mapping is None:
|
||||
pic_id_mapping = {}
|
||||
current_pic_counter = pic_counter
|
||||
|
||||
def process_pic_ids(content: str) -> str:
|
||||
"""处理内容中的图片ID,将其替换为[图片x]格式"""
|
||||
nonlocal current_pic_counter
|
||||
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r'\[picid:([^\]]+)\]'
|
||||
|
||||
def replace_pic_id(match):
|
||||
nonlocal current_pic_counter
|
||||
pic_id = match.group(1)
|
||||
|
||||
if pic_id not in pic_id_mapping:
|
||||
pic_id_mapping[pic_id] = f"图片{current_pic_counter}"
|
||||
current_pic_counter += 1
|
||||
|
||||
return f"[{pic_id_mapping[pic_id]}]"
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, content)
|
||||
|
||||
# 1 & 2: 获取发送者信息并提取消息组件
|
||||
for msg in messages:
|
||||
# 检查是否是动作记录
|
||||
@@ -183,7 +212,8 @@ def _build_readable_messages_internal(
|
||||
is_action = True
|
||||
timestamp = msg.get("time")
|
||||
content = msg.get("display_message", "")
|
||||
# 对于动作记录,直接使用内容
|
||||
# 对于动作记录,也处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action))
|
||||
continue
|
||||
|
||||
@@ -215,6 +245,9 @@ def _build_readable_messages_internal(
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
|
||||
# 检查必要信息是否存在
|
||||
if not all([platform, user_id, timestamp is not None]):
|
||||
continue
|
||||
@@ -277,7 +310,7 @@ def _build_readable_messages_internal(
|
||||
message_details_raw.append((timestamp, person_name, content, False))
|
||||
|
||||
if not message_details_raw:
|
||||
return "", []
|
||||
return "", [], pic_id_mapping, current_pic_counter
|
||||
|
||||
message_details_raw.sort(key=lambda x: x[0]) # 按时间戳(第一个元素)升序排序,越早的消息排在前面
|
||||
|
||||
@@ -285,10 +318,6 @@ def _build_readable_messages_internal(
|
||||
message_details_with_flags = []
|
||||
for timestamp, name, content, is_action in message_details_raw:
|
||||
message_details_with_flags.append((timestamp, name, content, is_action))
|
||||
# print(f"content:{content}")
|
||||
# print(f"is_action:{is_action}")
|
||||
|
||||
# print(f"message_details_with_flags:{message_details_with_flags}")
|
||||
|
||||
# 应用截断逻辑 (如果 truncate 为 True)
|
||||
message_details: List[Tuple[float, str, str, bool]] = []
|
||||
@@ -326,8 +355,6 @@ def _build_readable_messages_internal(
|
||||
# 如果不截断,直接使用原始列表
|
||||
message_details = message_details_with_flags
|
||||
|
||||
# print(f"message_details:{message_details}")
|
||||
|
||||
# 3: 合并连续消息 (如果 merge_messages 为 True)
|
||||
merged_messages = []
|
||||
if merge_messages and message_details:
|
||||
@@ -388,6 +415,7 @@ def _build_readable_messages_internal(
|
||||
|
||||
# 4 & 5: 格式化为字符串
|
||||
output_lines = []
|
||||
|
||||
for _i, merged in enumerate(merged_messages):
|
||||
# 使用指定的 timestamp_mode 格式化时间
|
||||
readable_time = translate_timestamp_to_human_readable(merged["start_time"], mode=timestamp_mode)
|
||||
@@ -416,9 +444,44 @@ def _build_readable_messages_internal(
|
||||
# 移除可能的多余换行,然后合并
|
||||
formatted_string = "".join(output_lines).strip()
|
||||
|
||||
# 返回格式化后的字符串和 *应用截断后* 的 message_details 列表
|
||||
# 注意:如果外部调用者需要原始未截断的内容,可能需要调整返回策略
|
||||
return formatted_string, [(t, n, c) for t, n, c, is_action in message_details if not is_action]
|
||||
# 返回格式化后的字符串、消息详情列表、图片映射字典和更新后的计数器
|
||||
return formatted_string, [(t, n, c) for t, n, c, is_action in message_details if not is_action], pic_id_mapping, current_pic_counter
|
||||
|
||||
|
||||
def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||
"""
|
||||
构建图片映射信息字符串,显示图片的具体描述内容
|
||||
|
||||
Args:
|
||||
pic_id_mapping: 图片ID到显示名称的映射字典
|
||||
|
||||
Returns:
|
||||
格式化的映射信息字符串
|
||||
"""
|
||||
if not pic_id_mapping:
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
mapping_lines = []
|
||||
|
||||
# 按图片编号排序
|
||||
sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", "")))
|
||||
|
||||
for pic_id, display_name in sorted_items:
|
||||
# 从数据库中获取图片描述
|
||||
description = "内容正在阅读"
|
||||
try:
|
||||
image = Images.get_or_none(Images.image_id == pic_id)
|
||||
if image and image.description:
|
||||
description = image.description
|
||||
except Exception as e:
|
||||
# 如果查询失败,保持默认描述
|
||||
pass
|
||||
|
||||
mapping_lines.append(f"[{display_name}] 的内容:{description}")
|
||||
|
||||
return "\n".join(mapping_lines)
|
||||
|
||||
|
||||
async def build_readable_messages_with_list(
|
||||
@@ -432,9 +495,15 @@ async def build_readable_messages_with_list(
|
||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||
允许通过参数控制格式化行为。
|
||||
"""
|
||||
formatted_string, details_list = _build_readable_messages_internal(
|
||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
|
||||
return formatted_string, details_list
|
||||
|
||||
|
||||
@@ -503,53 +572,56 @@ def build_readable_messages(
|
||||
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
|
||||
# for message in messages:
|
||||
# print(f"message:{message}")
|
||||
|
||||
formatted_string, _ = _build_readable_messages_internal(
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
)
|
||||
|
||||
# print(f"formatted_string:{formatted_string}")
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
return f"{pic_mapping_info}\n\n{formatted_string}"
|
||||
else:
|
||||
return formatted_string
|
||||
else:
|
||||
# 按 read_mark 分割消息
|
||||
messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark]
|
||||
messages_after_mark = [msg for msg in copy_messages if msg.get("time", 0) > read_mark]
|
||||
|
||||
# for message in messages_before_mark:
|
||||
# print(f"message:{message}")
|
||||
# 共享的图片映射字典和计数器
|
||||
pic_id_mapping = {}
|
||||
pic_counter = 1
|
||||
|
||||
# for message in messages_after_mark:
|
||||
# print(f"message:{message}")
|
||||
|
||||
# 分别格式化
|
||||
formatted_before, _ = _build_readable_messages_internal(
|
||||
messages_before_mark, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
# 分别格式化,但使用共享的图片映射
|
||||
formatted_before, _, pic_id_mapping, pic_counter = _build_readable_messages_internal(
|
||||
messages_before_mark, replace_bot_name, merge_messages, timestamp_mode, truncate,
|
||||
pic_id_mapping, pic_counter
|
||||
)
|
||||
formatted_after, _ = _build_readable_messages_internal(
|
||||
messages_after_mark,
|
||||
replace_bot_name,
|
||||
merge_messages,
|
||||
timestamp_mode,
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False,
|
||||
pic_id_mapping, pic_counter
|
||||
)
|
||||
|
||||
# print(f"formatted_before:{formatted_before}")
|
||||
# print(f"formatted_after:{formatted_after}")
|
||||
|
||||
read_mark_line = "\n--- 以上消息是你已经看过---\n--- 请关注以下未读的新消息---\n"
|
||||
|
||||
# 生成图片映射信息
|
||||
pic_mapping_info = f"图片信息:\n{build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
||||
|
||||
# 组合结果
|
||||
result_parts = []
|
||||
if pic_mapping_info:
|
||||
result_parts.append(pic_mapping_info)
|
||||
result_parts.append("\n")
|
||||
|
||||
if formatted_before and formatted_after:
|
||||
return f"{formatted_before}{read_mark_line}{formatted_after}"
|
||||
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
||||
elif formatted_before:
|
||||
return f"{formatted_before}{read_mark_line}"
|
||||
result_parts.extend([formatted_before, read_mark_line])
|
||||
elif formatted_after:
|
||||
return f"{read_mark_line}{formatted_after}"
|
||||
result_parts.extend([read_mark_line, formatted_after])
|
||||
else:
|
||||
return read_mark_line.strip()
|
||||
result_parts.append(read_mark_line.strip())
|
||||
|
||||
return "".join(result_parts)
|
||||
|
||||
|
||||
async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
@@ -565,6 +637,29 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
current_char = ord("A")
|
||||
output_lines = []
|
||||
|
||||
# 图片ID映射字典
|
||||
pic_id_mapping = {}
|
||||
pic_counter = 1
|
||||
|
||||
def process_pic_ids(content: str) -> str:
|
||||
"""处理内容中的图片ID,将其替换为[图片x]格式"""
|
||||
nonlocal pic_counter
|
||||
|
||||
# 匹配 [picid:xxxxx] 格式
|
||||
pic_pattern = r'\[picid:([^\]]+)\]'
|
||||
|
||||
def replace_pic_id(match):
|
||||
nonlocal pic_counter
|
||||
pic_id = match.group(1)
|
||||
|
||||
if pic_id not in pic_id_mapping:
|
||||
pic_id_mapping[pic_id] = f"图片{pic_counter}"
|
||||
pic_counter += 1
|
||||
|
||||
return f"[{pic_id_mapping[pic_id]}]"
|
||||
|
||||
return re.sub(pic_pattern, replace_pic_id, content)
|
||||
|
||||
def get_anon_name(platform, user_id):
|
||||
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
|
||||
# print(f"global_config.bot.qq_account:{global_config.bot.qq_account}")
|
||||
@@ -599,6 +694,9 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
|
||||
# if not all([platform, user_id, timestamp is not None]):
|
||||
# continue
|
||||
|
||||
@@ -650,7 +748,15 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
formatted_string = "".join(output_lines).strip()
|
||||
# 在最前面添加图片映射信息
|
||||
final_output_lines = []
|
||||
pic_mapping_info = build_pic_mapping_info(pic_id_mapping)
|
||||
if pic_mapping_info:
|
||||
final_output_lines.append(pic_mapping_info)
|
||||
final_output_lines.append("\n\n")
|
||||
|
||||
final_output_lines.extend(output_lines)
|
||||
formatted_string = "".join(final_output_lines).strip()
|
||||
return formatted_string
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,12 @@ import base64
|
||||
import os
|
||||
import time
|
||||
import hashlib
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
import io
|
||||
import numpy as np
|
||||
import asyncio
|
||||
|
||||
|
||||
from src.common.database.database import db
|
||||
@@ -360,6 +362,125 @@ class ImageManager:
|
||||
logger.error(f"GIF转换失败: {str(e)}", exc_info=True) # 记录详细错误信息
|
||||
return None # 其他错误也返回None
|
||||
|
||||
async def process_image(self, image_base64: str) -> Tuple[str, str]:
|
||||
"""处理图片并返回图片ID和描述
|
||||
|
||||
Args:
|
||||
image_base64: 图片的base64编码
|
||||
|
||||
Returns:
|
||||
Tuple[str, str]: (图片ID, 描述)
|
||||
"""
|
||||
try:
|
||||
# 生成图片ID
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 检查图片是否已存在
|
||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
||||
|
||||
if existing_image:
|
||||
# print(f"图片已存在: {existing_image.image_id}")
|
||||
# print(f"图片描述: {existing_image.description}")
|
||||
# print(f"图片计数: {existing_image.count}")
|
||||
# 更新计数
|
||||
existing_image.count += 1
|
||||
existing_image.save()
|
||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
||||
else:
|
||||
# print(f"图片不存在: {image_hash}")
|
||||
image_id = str(uuid.uuid4())
|
||||
|
||||
# 保存新图片
|
||||
current_timestamp = time.time()
|
||||
image_dir = os.path.join(self.IMAGE_DIR, "images")
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
filename = f"{image_id}.png"
|
||||
file_path = os.path.join(image_dir, filename)
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
Images.create(
|
||||
image_id=image_id,
|
||||
emoji_hash=image_hash,
|
||||
path=file_path,
|
||||
base64=image_base64,
|
||||
type="image",
|
||||
timestamp=current_timestamp,
|
||||
vlm_processed=False
|
||||
)
|
||||
|
||||
# 启动异步VLM处理
|
||||
asyncio.create_task(self._process_image_with_vlm(image_id, image_base64))
|
||||
|
||||
return image_id, f"[picid:{image_id}]"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理图片失败: {str(e)}")
|
||||
return "", "[图片]"
|
||||
|
||||
async def _process_image_with_vlm(self, image_id: str, image_base64: str) -> None:
|
||||
"""使用VLM处理图片并更新数据库
|
||||
|
||||
Args:
|
||||
image_id: 图片ID
|
||||
image_base64: 图片的base64编码
|
||||
"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 先检查缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
logger.debug(f"VLM处理时发现缓存描述: {cached_description}")
|
||||
# 更新数据库
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
image.description = cached_description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
return
|
||||
|
||||
# 获取图片格式
|
||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
|
||||
|
||||
# 构建prompt
|
||||
prompt = """请用中文描述这张图片的内容。如果有文字,请把文字描述概括出来,请留意其主题,直观感受,输出为一段平文本,最多50字"""
|
||||
|
||||
# 获取VLM描述
|
||||
description, _ = await self._llm.generate_response_for_image(
|
||||
prompt,
|
||||
image_base64,
|
||||
image_format
|
||||
)
|
||||
|
||||
if description is None:
|
||||
logger.warning("VLM未能生成图片描述")
|
||||
description = "无法生成描述"
|
||||
|
||||
# 再次检查缓存,防止并发写入时重复生成
|
||||
cached_description = self._get_description_from_db(image_hash, "image")
|
||||
if cached_description:
|
||||
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
|
||||
description = cached_description
|
||||
|
||||
# 更新数据库
|
||||
image = Images.get(Images.image_id == image_id)
|
||||
image.description = description
|
||||
image.vlm_processed = True
|
||||
image.save()
|
||||
|
||||
# 保存描述到ImageDescriptions表
|
||||
self._save_description_to_db(image_hash, description, "image")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"VLM处理图片失败: {str(e)}")
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
image_manager = None
|
||||
|
||||
@@ -185,14 +185,17 @@ class Images(BaseModel):
|
||||
用于存储图像信息的模型。
|
||||
"""
|
||||
|
||||
image_id = TextField(default="") # 图片唯一ID
|
||||
emoji_hash = TextField(index=True) # 图像的哈希值
|
||||
description = TextField(null=True) # 图像的描述
|
||||
path = TextField(unique=True) # 图像文件的路径
|
||||
base64 = TextField() # 图片的base64编码
|
||||
count = IntegerField(default=1) # 图片被引用的次数
|
||||
timestamp = FloatField() # 时间戳
|
||||
type = TextField() # 图像类型,例如 "emoji"
|
||||
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
|
||||
|
||||
class Meta:
|
||||
# database = db # 继承自 BaseModel
|
||||
table_name = "images"
|
||||
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class MuteAction(BaseAction):
|
||||
|
||||
# Action参数定义
|
||||
action_parameters = {
|
||||
"target": "禁言对象,必填,输入你要禁言的对象的名字",
|
||||
"target": "禁言对象,必填,输入你要禁言的对象的名字,请仔细思考不要弄错禁言对象",
|
||||
"duration": "禁言时长,必填,输入你要禁言的时长(秒),单位为秒,必须为数字",
|
||||
"reason": "禁言理由,可选",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user