refactor(chat): 重构图片在聊天记录中的处理与表示方式
为了简化LLM的上下文并提高代码可维护性,对聊天记录中图片的处理方式进行了彻底重构。

旧系统使用 [图片1] 等占位符,并在消息头部附加一个独立的图片描述映射块。这种方式结构复杂,容易造成上下文分离。

新系统将图片描述直接内联到消息文本中,格式为 `[图片:一只猫]`,使聊天记录对LLM更加自然和易于理解。

主要变更:
- **消息构建 (`chat_message_builder`):** 在构建可读消息时,异步将数据库中的 `[picid:...]` 标签直接替换为完整的 `[图片:描述]`。
- **废弃映射:** 移除了独立的图片映射信息块 (`build_pic_mapping_info` 函数),所有信息都在消息内联。
- **图片处理 (`utils_image`):** `process_image` 流程现在同步返回完整的描述字符串,并增强了VLM调用的重试逻辑和缓存机制,提高了健壮性。
- **消息存储 (`storage`):** 在消息存入数据库前,将 `[图片:描述]` 转换为 `[picid:...]`,以保持存储规范化。
- **修复:** 增加了多处空值检查,提高了代码的稳定性。

这不比之前稳定好用多了😋😋😋
This commit is contained in:
@@ -86,7 +86,7 @@ class MaiEmoji:
|
|||||||
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
|
logger.debug(f"[初始化] 正在使用Pillow获取格式: {self.filename}")
|
||||||
try:
|
try:
|
||||||
with Image.open(io.BytesIO(image_bytes)) as img:
|
with Image.open(io.BytesIO(image_bytes)) as img:
|
||||||
self.format = img.format.lower() # type: ignore
|
self.format = (img.format or "jpeg").lower()
|
||||||
logger.debug(f"[初始化] 格式获取成功: {self.format}")
|
logger.debug(f"[初始化] 格式获取成功: {self.format}")
|
||||||
except Exception as pil_error:
|
except Exception as pil_error:
|
||||||
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
|
logger.error(f"[初始化错误] Pillow无法处理图片 ({self.filename}): {pil_error}")
|
||||||
@@ -327,7 +327,7 @@ async def clear_temp_emoji() -> None:
|
|||||||
):
|
):
|
||||||
if os.path.exists(need_clear):
|
if os.path.exists(need_clear):
|
||||||
files = os.listdir(need_clear)
|
files = os.listdir(need_clear)
|
||||||
# 如果文件数超过100就全部删除
|
# 如果文件数超过1000就全部删除
|
||||||
if len(files) > 1000:
|
if len(files) > 1000:
|
||||||
for filename in files:
|
for filename in files:
|
||||||
file_path = os.path.join(need_clear, filename)
|
file_path = os.path.join(need_clear, filename)
|
||||||
@@ -439,12 +439,12 @@ class EmojiManager:
|
|||||||
stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
stmt = select(Emoji).where(Emoji.emoji_hash == emoji_hash)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
emoji_update = result.scalar_one_or_none()
|
emoji_update = result.scalar_one_or_none()
|
||||||
if emoji_update is None:
|
if emoji_update:
|
||||||
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
|
||||||
else:
|
|
||||||
emoji_update.usage_count += 1
|
emoji_update.usage_count += 1
|
||||||
emoji_update.last_used_time = time.time() # Update last used time
|
emoji_update.last_used_time = time.time() # Update last used time
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
else:
|
||||||
|
logger.error(f"记录表情使用失败: 未找到 hash 为 {emoji_hash} 的表情包")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"记录表情使用失败: {e!s}")
|
logger.error(f"记录表情使用失败: {e!s}")
|
||||||
|
|
||||||
@@ -469,7 +469,7 @@ class EmojiManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
# 2. 根据全局配置决定候选表情包的数量
|
# 2. 根据全局配置决定候选表情包的数量
|
||||||
max_candidates = global_config.emoji.max_emoji_for_llm_select
|
max_candidates = global_config.emoji.max_context_emojis
|
||||||
|
|
||||||
# 如果配置为0或者大于等于总数,则选择所有表情包
|
# 如果配置为0或者大于等于总数,则选择所有表情包
|
||||||
if max_candidates <= 0 or max_candidates >= len(all_emojis):
|
if max_candidates <= 0 or max_candidates >= len(all_emojis):
|
||||||
@@ -943,11 +943,7 @@ class EmojiManager:
|
|||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
image_format = (
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
|
||||||
Image.open(io.BytesIO(image_bytes)).format.lower()
|
|
||||||
if Image.open(io.BytesIO(image_bytes)).format
|
|
||||||
else "jpeg"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
# 2. 检查数据库中是否已存在该表情包的描述,实现复用
|
||||||
existing_description = None
|
existing_description = None
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
import json
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
@@ -44,7 +44,9 @@ class MessageStorage:
|
|||||||
|
|
||||||
if processed_plain_text:
|
if processed_plain_text:
|
||||||
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
processed_plain_text = await MessageStorage.replace_image_descriptions(processed_plain_text)
|
||||||
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
|
# 增加对None的防御性处理
|
||||||
|
safe_processed_plain_text = processed_plain_text or ""
|
||||||
|
filtered_processed_plain_text = re.sub(pattern, "", safe_processed_plain_text, flags=re.DOTALL)
|
||||||
else:
|
else:
|
||||||
filtered_processed_plain_text = ""
|
filtered_processed_plain_text = ""
|
||||||
|
|
||||||
@@ -55,9 +57,7 @@ class MessageStorage:
|
|||||||
else:
|
else:
|
||||||
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
# 如果没有设置display_message,使用processed_plain_text作为显示消息
|
||||||
filtered_display_message = (
|
filtered_display_message = (
|
||||||
re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL)
|
re.sub(pattern, "", (message.processed_plain_text or ""), flags=re.DOTALL)
|
||||||
if message.processed_plain_text
|
|
||||||
else ""
|
|
||||||
)
|
)
|
||||||
interest_value = 0
|
interest_value = 0
|
||||||
is_mentioned = False
|
is_mentioned = False
|
||||||
@@ -103,7 +103,7 @@ class MessageStorage:
|
|||||||
|
|
||||||
new_message = Messages(
|
new_message = Messages(
|
||||||
message_id=msg_id,
|
message_id=msg_id,
|
||||||
time=float(message.message_info.time),
|
time=float(message.message_info.time or time.time()),
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
is_mentioned=is_mentioned,
|
is_mentioned=is_mentioned,
|
||||||
@@ -196,29 +196,48 @@ class MessageStorage:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def replace_image_descriptions(text: str) -> str:
|
async def replace_image_descriptions(text: str) -> str:
|
||||||
"""将[图片:描述]替换为[picid:image_id]"""
|
"""异步地将文本中的所有[图片:描述]标记替换为[picid:image_id]"""
|
||||||
# 先检查文本中是否有图片标记
|
|
||||||
pattern = r"\[图片:([^\]]+)\]"
|
pattern = r"\[图片:([^\]]+)\]"
|
||||||
matches = list(re.finditer(pattern, text))
|
|
||||||
|
# 如果没有匹配项,提前返回以提高效率
|
||||||
if not matches:
|
if not re.search(pattern, text):
|
||||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
|
||||||
return text
|
return text
|
||||||
|
|
||||||
async def replace_match(match):
|
# re.sub不支持异步替换函数,所以我们需要手动迭代和替换
|
||||||
|
new_text = []
|
||||||
|
last_end = 0
|
||||||
|
for match in re.finditer(pattern, text):
|
||||||
|
# 添加上一个匹配到当前匹配之间的文本
|
||||||
|
new_text.append(text[last_end:match.start()])
|
||||||
|
|
||||||
description = match.group(1).strip()
|
description = match.group(1).strip()
|
||||||
|
replacement = match.group(0) # 默认情况下,替换为原始匹配文本
|
||||||
try:
|
try:
|
||||||
from src.common.database.sqlalchemy_models import get_db_session
|
|
||||||
|
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
image_record = (
|
# 查询数据库以找到具有该描述的最新图片记录
|
||||||
await session.execute(
|
result = await session.execute(
|
||||||
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
|
select(Images.image_id)
|
||||||
)
|
.where(Images.description == description)
|
||||||
).scalar()
|
.order_by(desc(Images.timestamp))
|
||||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
.limit(1)
|
||||||
except Exception:
|
)
|
||||||
return match.group(0)
|
image_id = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if image_id:
|
||||||
|
replacement = f"[picid:{image_id}]"
|
||||||
|
logger.debug(f"成功将描述 '{description[:20]}...' 替换为 picid '{image_id}'")
|
||||||
|
else:
|
||||||
|
logger.warning(f"无法为描述 '{description[:20]}...' 找到对应的picid,将保留原始标记")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"替换图片描述时查询数据库失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
new_text.append(replacement)
|
||||||
|
last_end = match.end()
|
||||||
|
|
||||||
|
# 添加最后一个匹配到字符串末尾的文本
|
||||||
|
new_text.append(text[last_end:])
|
||||||
|
|
||||||
|
return "".join(new_text)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_message_interest_value(
|
async def update_message_interest_value(
|
||||||
|
|||||||
@@ -547,6 +547,36 @@ async def _build_readable_messages_internal(
|
|||||||
if pic_id_mapping is None:
|
if pic_id_mapping is None:
|
||||||
pic_id_mapping = {}
|
pic_id_mapping = {}
|
||||||
current_pic_counter = pic_counter
|
current_pic_counter = pic_counter
|
||||||
|
|
||||||
|
# --- 异步图片ID处理器 (修复核心问题) ---
|
||||||
|
async def process_pic_ids(content: str) -> str:
|
||||||
|
"""异步处理内容中的图片ID,将其直接替换为[图片:描述]格式"""
|
||||||
|
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||||
|
matches = list(re.finditer(pic_pattern, content))
|
||||||
|
if not matches:
|
||||||
|
return content
|
||||||
|
|
||||||
|
new_content = ""
|
||||||
|
last_end = 0
|
||||||
|
for match in matches:
|
||||||
|
new_content += content[last_end : match.start()]
|
||||||
|
pic_id = match.group(1)
|
||||||
|
description = "[图片内容未知]"
|
||||||
|
try:
|
||||||
|
async with get_db_session() as session:
|
||||||
|
result = await session.execute(select(Images.description).where(Images.image_id == pic_id))
|
||||||
|
desc_scalar = result.scalar_one_or_none()
|
||||||
|
if desc_scalar and desc_scalar.strip():
|
||||||
|
description = f"[图片:{desc_scalar}]"
|
||||||
|
else:
|
||||||
|
description = "[图片内容未知]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"[chat_message_builder] 查询图片 {pic_id} 描述失败: {e}")
|
||||||
|
description = "[图片内容未知]"
|
||||||
|
new_content += description
|
||||||
|
last_end = match.end()
|
||||||
|
new_content += content[last_end:]
|
||||||
|
return new_content
|
||||||
|
|
||||||
# 创建时间戳到消息ID的映射,用于在消息前添加[id]标识符
|
# 创建时间戳到消息ID的映射,用于在消息前添加[id]标识符
|
||||||
timestamp_to_id = {}
|
timestamp_to_id = {}
|
||||||
@@ -557,25 +587,6 @@ async def _build_readable_messages_internal(
|
|||||||
if timestamp is not None:
|
if timestamp is not None:
|
||||||
timestamp_to_id[timestamp] = item.get("id", "")
|
timestamp_to_id[timestamp] = item.get("id", "")
|
||||||
|
|
||||||
def process_pic_ids(content: str) -> str:
|
|
||||||
"""处理内容中的图片ID,将其替换为[图片x]格式"""
|
|
||||||
nonlocal current_pic_counter
|
|
||||||
|
|
||||||
# 匹配 [picid:xxxxx] 格式
|
|
||||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
|
||||||
|
|
||||||
def replace_pic_id(match):
|
|
||||||
nonlocal current_pic_counter
|
|
||||||
pic_id = match.group(1)
|
|
||||||
|
|
||||||
if pic_id not in pic_id_mapping:
|
|
||||||
pic_id_mapping[pic_id] = f"图片{current_pic_counter}"
|
|
||||||
current_pic_counter += 1
|
|
||||||
|
|
||||||
return f"[{pic_id_mapping[pic_id]}]"
|
|
||||||
|
|
||||||
return re.sub(pic_pattern, replace_pic_id, content)
|
|
||||||
|
|
||||||
# 1 & 2: 获取发送者信息并提取消息组件
|
# 1 & 2: 获取发送者信息并提取消息组件
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
# 检查是否是动作记录
|
# 检查是否是动作记录
|
||||||
@@ -583,8 +594,8 @@ async def _build_readable_messages_internal(
|
|||||||
is_action = True
|
is_action = True
|
||||||
timestamp: float = msg.get("time") # type: ignore
|
timestamp: float = msg.get("time") # type: ignore
|
||||||
content = msg.get("display_message", "")
|
content = msg.get("display_message", "")
|
||||||
# 对于动作记录,也处理图片ID
|
if show_pic:
|
||||||
content = process_pic_ids(content)
|
content = await process_pic_ids(content)
|
||||||
message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action))
|
message_details_raw.append((timestamp, global_config.bot.nickname, content, is_action))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -619,7 +630,7 @@ async def _build_readable_messages_internal(
|
|||||||
|
|
||||||
# 处理图片ID
|
# 处理图片ID
|
||||||
if show_pic:
|
if show_pic:
|
||||||
content = process_pic_ids(content)
|
content = await process_pic_ids(content)
|
||||||
|
|
||||||
# 检查必要信息是否存在
|
# 检查必要信息是否存在
|
||||||
if not all([platform, user_id, timestamp is not None]):
|
if not all([platform, user_id, timestamp is not None]):
|
||||||
@@ -808,43 +819,12 @@ async def _build_readable_messages_internal(
|
|||||||
current_pic_counter,
|
current_pic_counter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str:
|
async def build_pic_mapping_info(pic_id_mapping: dict[str, str]) -> str:
|
||||||
# sourcery skip: use-contextlib-suppress
|
|
||||||
"""
|
"""
|
||||||
构建图片映射信息字符串,显示图片的具体描述内容
|
此函数已废弃,因为图片描述现在被内联处理。
|
||||||
|
保留此函数以确保向后兼容性,但它将始终返回一个空字符串。
|
||||||
Args:
|
|
||||||
pic_id_mapping: 图片ID到显示名称的映射字典
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
格式化的映射信息字符串
|
|
||||||
"""
|
"""
|
||||||
if not pic_id_mapping:
|
return ""
|
||||||
return ""
|
|
||||||
|
|
||||||
mapping_lines = []
|
|
||||||
|
|
||||||
# 按图片编号排序
|
|
||||||
sorted_items = sorted(pic_id_mapping.items(), key=lambda x: int(x[1].replace("图片", "")))
|
|
||||||
|
|
||||||
for pic_id, display_name in sorted_items:
|
|
||||||
# 从数据库中获取图片描述
|
|
||||||
description = "[图片内容未知]" # 默认描述
|
|
||||||
try:
|
|
||||||
async with get_db_session() as session:
|
|
||||||
result = await session.execute(select(Images).where(Images.image_id == pic_id))
|
|
||||||
image = result.scalar_one_or_none()
|
|
||||||
if image and hasattr(image, "description") and image.description:
|
|
||||||
description = image.description
|
|
||||||
except Exception as e:
|
|
||||||
# 如果查询失败,保持默认描述
|
|
||||||
logger.debug(f"[chat_message_builder] 查询图片描述失败: {e}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
mapping_lines.append(f"[{display_name}] 的内容:{description}")
|
|
||||||
|
|
||||||
return "\n".join(mapping_lines)
|
|
||||||
|
|
||||||
|
|
||||||
def build_readable_actions(actions: list[dict[str, Any]]) -> str:
|
def build_readable_actions(actions: list[dict[str, Any]]) -> str:
|
||||||
@@ -932,13 +912,9 @@ async def build_readable_messages_with_list(
|
|||||||
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
将消息列表转换为可读的文本格式,并返回原始(时间戳, 昵称, 内容)列表。
|
||||||
允许通过参数控制格式化行为。
|
允许通过参数控制格式化行为。
|
||||||
"""
|
"""
|
||||||
formatted_string, details_list, pic_id_mapping, _ = await _build_readable_messages_internal(
|
formatted_string, details_list, _, _ = await _build_readable_messages_internal(
|
||||||
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||||
)
|
)
|
||||||
|
|
||||||
if pic_mapping_info := await build_pic_mapping_info(pic_id_mapping):
|
|
||||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
|
||||||
|
|
||||||
return formatted_string, details_list
|
return formatted_string, details_list
|
||||||
|
|
||||||
|
|
||||||
@@ -970,12 +946,6 @@ async def build_readable_messages_with_id(
|
|||||||
message_id_list=message_id_list,
|
message_id_list=message_id_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果存在图片映射信息,附加之
|
|
||||||
if pic_mapping_info := await build_pic_mapping_info({}):
|
|
||||||
# 如果当前没有图片映射则不附加
|
|
||||||
if pic_mapping_info:
|
|
||||||
formatted_string = f"{pic_mapping_info}\n\n{formatted_string}"
|
|
||||||
|
|
||||||
return formatted_string, message_id_list
|
return formatted_string, message_id_list
|
||||||
|
|
||||||
|
|
||||||
@@ -1078,7 +1048,7 @@ async def build_readable_messages(
|
|||||||
|
|
||||||
if read_mark <= 0:
|
if read_mark <= 0:
|
||||||
# 没有有效的 read_mark,直接格式化所有消息
|
# 没有有效的 read_mark,直接格式化所有消息
|
||||||
formatted_string, _, pic_id_mapping, _ = await _build_readable_messages_internal(
|
formatted_string, _, _, _ = await _build_readable_messages_internal(
|
||||||
copy_messages,
|
copy_messages,
|
||||||
replace_bot_name,
|
replace_bot_name,
|
||||||
merge_messages,
|
merge_messages,
|
||||||
@@ -1088,12 +1058,7 @@ async def build_readable_messages(
|
|||||||
message_id_list=message_id_list,
|
message_id_list=message_id_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成图片映射信息并添加到最前面
|
return formatted_string
|
||||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
|
||||||
if pic_mapping_info:
|
|
||||||
return f"{pic_mapping_info}\n\n{formatted_string}"
|
|
||||||
else:
|
|
||||||
return formatted_string
|
|
||||||
else:
|
else:
|
||||||
# 按 read_mark 分割消息
|
# 按 read_mark 分割消息
|
||||||
messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark]
|
messages_before_mark = [msg for msg in copy_messages if msg.get("time", 0) <= read_mark]
|
||||||
@@ -1128,23 +1093,15 @@ async def build_readable_messages(
|
|||||||
)
|
)
|
||||||
|
|
||||||
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"
|
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"
|
||||||
|
|
||||||
# 生成图片映射信息
|
|
||||||
if pic_id_mapping:
|
|
||||||
pic_mapping_info = f"图片信息:\n{await build_pic_mapping_info(pic_id_mapping)}\n聊天记录信息:\n"
|
|
||||||
else:
|
|
||||||
pic_mapping_info = "聊天记录信息:\n"
|
|
||||||
|
|
||||||
# 组合结果
|
# 组合结果
|
||||||
result_parts = []
|
result_parts = []
|
||||||
if pic_mapping_info:
|
|
||||||
result_parts.extend((pic_mapping_info, "\n"))
|
|
||||||
if formatted_before and formatted_after:
|
if formatted_before and formatted_after:
|
||||||
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
result_parts.extend([formatted_before, read_mark_line, formatted_after])
|
||||||
elif formatted_before:
|
elif formatted_before:
|
||||||
result_parts.extend([formatted_before, read_mark_line])
|
result_parts.extend([formatted_before, read_mark_line])
|
||||||
elif formatted_after:
|
elif formatted_after:
|
||||||
result_parts.extend([read_mark_line, formatted_after])
|
result_parts.extend([read_mark_line.strip(), formatted_after])
|
||||||
else:
|
else:
|
||||||
result_parts.append(read_mark_line.strip())
|
result_parts.append(read_mark_line.strip())
|
||||||
|
|
||||||
@@ -1164,28 +1121,9 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
|||||||
current_char = ord("A")
|
current_char = ord("A")
|
||||||
output_lines = []
|
output_lines = []
|
||||||
|
|
||||||
# 图片ID映射字典
|
# This function builds anonymous messages, so we don't need full descriptions.
|
||||||
pic_id_mapping = {}
|
# The existing placeholder logic is sufficient.
|
||||||
pic_counter = 1
|
# However, to maintain consistency, we will adapt it slightly.
|
||||||
|
|
||||||
def process_pic_ids(content: str) -> str:
|
|
||||||
"""处理内容中的图片ID,将其替换为[图片x]格式"""
|
|
||||||
nonlocal pic_counter
|
|
||||||
|
|
||||||
# 匹配 [picid:xxxxx] 格式
|
|
||||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
|
||||||
|
|
||||||
def replace_pic_id(match):
|
|
||||||
nonlocal pic_counter
|
|
||||||
pic_id = match.group(1)
|
|
||||||
|
|
||||||
if pic_id not in pic_id_mapping:
|
|
||||||
pic_id_mapping[pic_id] = f"图片{pic_counter}"
|
|
||||||
pic_counter += 1
|
|
||||||
|
|
||||||
return f"[{pic_id_mapping[pic_id]}]"
|
|
||||||
|
|
||||||
return re.sub(pic_pattern, replace_pic_id, content)
|
|
||||||
|
|
||||||
def get_anon_name(platform, user_id):
|
def get_anon_name(platform, user_id):
|
||||||
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
|
# print(f"get_anon_name: platform:{platform}, user_id:{user_id}")
|
||||||
@@ -1222,8 +1160,8 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
|||||||
if "ⁿ" in content:
|
if "ⁿ" in content:
|
||||||
content = content.replace("ⁿ", "")
|
content = content.replace("ⁿ", "")
|
||||||
|
|
||||||
# 处理图片ID
|
# For anonymous messages, we just replace with a placeholder.
|
||||||
content = process_pic_ids(content)
|
content = re.sub(r"\[picid:([^\]]+)\]", "[图片]", content)
|
||||||
|
|
||||||
# if not all([platform, user_id, timestamp is not None]):
|
# if not all([platform, user_id, timestamp is not None]):
|
||||||
# continue
|
# continue
|
||||||
@@ -1252,15 +1190,8 @@ async def build_anonymous_messages(messages: list[dict[str, Any]]) -> str:
|
|||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 在最前面添加图片映射信息
|
# Since we are not generating a pic_mapping_info block, just join and return.
|
||||||
final_output_lines = []
|
formatted_string = "".join(output_lines).strip()
|
||||||
pic_mapping_info = await build_pic_mapping_info(pic_id_mapping)
|
|
||||||
if pic_mapping_info:
|
|
||||||
final_output_lines.append(pic_mapping_info)
|
|
||||||
final_output_lines.append("\n\n")
|
|
||||||
|
|
||||||
final_output_lines.extend(output_lines)
|
|
||||||
formatted_string = "".join(final_output_lines).strip()
|
|
||||||
return formatted_string
|
return formatted_string
|
||||||
|
|
||||||
|
|
||||||
@@ -1295,4 +1226,4 @@ async def get_person_id_list(messages: list[dict[str, Any]]) -> list[str]:
|
|||||||
if person_id := get_person_id(platform, user_id):
|
if person_id := get_person_id(platform, user_id):
|
||||||
person_ids_set.add(person_id)
|
person_ids_set.add(person_id)
|
||||||
|
|
||||||
return list(person_ids_set) # 将集合转换为列表返回
|
return list(person_ids_set)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import base64
|
import base64
|
||||||
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
@@ -215,94 +216,75 @@ class ImageManager:
|
|||||||
return "[表情包(处理失败)]"
|
return "[表情包(处理失败)]"
|
||||||
|
|
||||||
async def get_image_description(self, image_base64: str) -> str:
|
async def get_image_description(self, image_base64: str) -> str:
|
||||||
"""获取普通图片描述,优先使用Images表中的缓存数据"""
|
"""获取普通图片描述,采用同步识别+缓存策略"""
|
||||||
try:
|
try:
|
||||||
# 计算图片哈希
|
# 1. 计算图片哈希
|
||||||
if isinstance(image_base64, str):
|
if isinstance(image_base64, str):
|
||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
|
# 2. 优先查询 Images 表缓存
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 优先检查Images表中是否已有完整的描述
|
|
||||||
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||||
existing_image = result.scalar()
|
existing_image = result.scalar()
|
||||||
if existing_image:
|
if existing_image and existing_image.description:
|
||||||
# 更新计数
|
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
|
||||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
return f"[图片:{existing_image.description}]"
|
||||||
existing_image.count += 1
|
|
||||||
else:
|
# 3. 其次查询 ImageDescriptions 表缓存
|
||||||
existing_image.count = 1
|
if cached_description := await self._get_description_from_db(image_hash, "image"):
|
||||||
|
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
|
||||||
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
# 如果已有描述,直接返回
|
# 4. 如果都未命中,则同步调用VLM生成新描述
|
||||||
if existing_image.description:
|
logger.info(f"[新图片识别] 无缓存 (Hash: {image_hash[:8]}...),调用VLM生成描述")
|
||||||
await session.commit()
|
description = None
|
||||||
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description}...")
|
prompt = global_config.custom_prompt.image_prompt
|
||||||
return f"[图片:{existing_image.description}]"
|
logger.info(f"[识图VLM调用] Prompt: {prompt}")
|
||||||
|
for i in range(3): # 重试3次
|
||||||
# 如果没有描述,继续在当前会话中操作
|
try:
|
||||||
if cached_description := await self._get_description_from_db(image_hash, "image"):
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "jpeg").lower()
|
||||||
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description}...")
|
logger.info(f"[VLM调用] 正在为图片生成描述 (第 {i+1}/3 次)...")
|
||||||
return f"[图片:{cached_description}]"
|
description, response_tuple = await self.vlm.generate_response_for_image(
|
||||||
|
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
||||||
# 调用AI获取描述
|
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
|
||||||
prompt = global_config.custom_prompt.image_prompt
|
|
||||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
|
||||||
description, _ = await self.vlm.generate_response_for_image(
|
|
||||||
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
|
|
||||||
)
|
|
||||||
|
|
||||||
if description is None:
|
|
||||||
logger.warning("AI未能生成图片描述")
|
|
||||||
return "[图片(描述生成失败)]"
|
|
||||||
|
|
||||||
# 保存图片和描述
|
|
||||||
current_timestamp = time.time()
|
|
||||||
filename = f"{int(current_timestamp)}_{image_hash[:8]}.{image_format}"
|
|
||||||
image_dir = os.path.join(self.IMAGE_DIR, "image")
|
|
||||||
os.makedirs(image_dir, exist_ok=True)
|
|
||||||
file_path = os.path.join(image_dir, filename)
|
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
|
||||||
f.write(image_bytes)
|
|
||||||
|
|
||||||
# 保存到数据库,补充缺失字段
|
|
||||||
if existing_image:
|
|
||||||
existing_image.path = file_path
|
|
||||||
existing_image.description = description
|
|
||||||
existing_image.timestamp = current_timestamp
|
|
||||||
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
|
|
||||||
existing_image.image_id = str(uuid.uuid4())
|
|
||||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
|
||||||
existing_image.vlm_processed = True
|
|
||||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
|
||||||
else:
|
|
||||||
new_img = Images(
|
|
||||||
image_id=str(uuid.uuid4()),
|
|
||||||
emoji_hash=image_hash,
|
|
||||||
path=file_path,
|
|
||||||
type="image",
|
|
||||||
description=description,
|
|
||||||
timestamp=current_timestamp,
|
|
||||||
vlm_processed=True,
|
|
||||||
count=1,
|
|
||||||
)
|
)
|
||||||
session.add(new_img)
|
# response_tuple is (reasoning, model_name, tool_calls)
|
||||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
model_name_used = response_tuple[1]
|
||||||
|
logger.info(f"[VLM调用成功] 使用模型: {model_name_used}")
|
||||||
|
if description and description.strip():
|
||||||
|
break # 成功获取描述则跳出循环
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"VLM调用失败 (第 {i+1}/3 次): {e}", exc_info=True)
|
||||||
|
|
||||||
|
if i < 2: # 如果不是最后一次,则等待1秒
|
||||||
|
logger.warning(f"识图失败,将在1秒后重试...")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
|
if not description or not description.strip():
|
||||||
|
logger.warning("VLM未能生成有效描述")
|
||||||
|
return "[图片(描述生成失败)]"
|
||||||
|
|
||||||
|
logger.info(f"[VLM完成] 图片描述生成: {description[:50]}...")
|
||||||
|
|
||||||
|
# 5. 将新描述存入两个缓存表
|
||||||
|
await self._save_description_to_db(image_hash, description, "image")
|
||||||
|
async with get_db_session() as session:
|
||||||
|
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||||
|
existing_image_for_update = result.scalar()
|
||||||
|
if existing_image_for_update:
|
||||||
|
existing_image_for_update.description = description
|
||||||
|
existing_image_for_update.vlm_processed = True
|
||||||
|
logger.debug(f"[数据库] 为现有图片记录补充描述: {image_hash[:8]}...")
|
||||||
|
# 注意:这里不创建新的Images记录,因为process_image会负责创建
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
logger.info(f"新生成的图片描述已存入缓存 (Hash: {image_hash[:8]}...)")
|
||||||
|
|
||||||
# 保存描述到ImageDescriptions表作为备用缓存
|
|
||||||
await self._save_description_to_db(image_hash, description, "image")
|
|
||||||
|
|
||||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
|
||||||
return f"[图片:{description}]"
|
|
||||||
|
|
||||||
logger.info(f"[VLM完成] 图片描述生成: {description}...")
|
|
||||||
return f"[图片:{description}]"
|
return f"[图片:{description}]"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取图片描述失败: {e!s}")
|
logger.error(f"获取图片描述时发生严重错误: {e!s}", exc_info=True)
|
||||||
return "[图片(处理失败)]"
|
return "[图片(处理失败)]"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -427,96 +409,75 @@ class ImageManager:
|
|||||||
return None # 其他错误也返回None
|
return None # 其他错误也返回None
|
||||||
|
|
||||||
async def process_image(self, image_base64: str) -> tuple[str, str]:
|
async def process_image(self, image_base64: str) -> tuple[str, str]:
|
||||||
# sourcery skip: hoist-if-from-if
|
"""处理图片并返回图片ID和描述,采用同步识别流程"""
|
||||||
"""处理图片并返回图片ID和描述
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_base64: 图片的base64编码
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[str, str]: (图片ID, 描述)
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 生成图片ID
|
|
||||||
# 计算图片哈希
|
|
||||||
# 确保base64字符串只包含ASCII字符
|
|
||||||
if isinstance(image_base64, str):
|
if isinstance(image_base64, str):
|
||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
|
image_id = ""
|
||||||
|
description = ""
|
||||||
|
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
result = await session.execute(select(Images).where(Images.emoji_hash == image_hash))
|
||||||
existing_image = result.scalar()
|
existing_image = result.scalar()
|
||||||
if existing_image:
|
|
||||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
|
||||||
if (
|
|
||||||
not hasattr(existing_image, "image_id")
|
|
||||||
or not existing_image.image_id
|
|
||||||
or not hasattr(existing_image, "count")
|
|
||||||
or existing_image.count is None
|
|
||||||
or not hasattr(existing_image, "vlm_processed")
|
|
||||||
or existing_image.vlm_processed is None
|
|
||||||
):
|
|
||||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
|
||||||
if not existing_image.image_id:
|
|
||||||
existing_image.image_id = str(uuid.uuid4())
|
|
||||||
if existing_image.count is None:
|
|
||||||
existing_image.count = 0
|
|
||||||
if existing_image.vlm_processed is None:
|
|
||||||
existing_image.vlm_processed = False
|
|
||||||
|
|
||||||
|
if existing_image and existing_image.image_id:
|
||||||
|
image_id = existing_image.image_id
|
||||||
existing_image.count += 1
|
existing_image.count += 1
|
||||||
await session.commit()
|
logger.debug(f"图片记录已存在 (ID: {image_id}),使用次数 +1")
|
||||||
|
|
||||||
# 如果已有描述,直接返回
|
|
||||||
if existing_image.description and existing_image.description.strip():
|
if existing_image.description and existing_image.description.strip():
|
||||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
description = f"[图片:{existing_image.description}]"
|
||||||
|
logger.debug("缓存命中,直接返回数据库中已有的完整描述")
|
||||||
|
return image_id, description
|
||||||
else:
|
else:
|
||||||
# 同步处理图片描述
|
logger.warning(f"图片记录 (ID: {image_id}) 描述为空,将同步生成")
|
||||||
description = await self.get_image_description(image_base64)
|
description = await self.get_image_description(image_base64)
|
||||||
# 更新数据库中的描述
|
|
||||||
existing_image.description = description.replace("[图片:", "").replace("]", "")
|
existing_image.description = description.replace("[图片:", "").replace("]", "")
|
||||||
existing_image.vlm_processed = True
|
existing_image.vlm_processed = True
|
||||||
await session.commit()
|
else:
|
||||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
logger.debug(f"新图片 (Hash: {image_hash[:8]}...),将同步生成描述并创建新记录")
|
||||||
|
image_id = str(uuid.uuid4())
|
||||||
|
description = await self.get_image_description(image_base64)
|
||||||
|
|
||||||
# print(f"图片不存在: {image_hash}")
|
# 如果描述生成失败,则不存入数据库,直接返回失败信息
|
||||||
image_id = str(uuid.uuid4())
|
if "(处理失败)" in description or "(描述生成失败)" in description:
|
||||||
|
logger.warning("图片描述生成失败,不创建数据库记录,直接返回失败信息。")
|
||||||
|
return "", description
|
||||||
|
|
||||||
# 同步获取图片描述
|
clean_description = description.replace("[图片:", "").replace("]", "")
|
||||||
description = await self.get_image_description(image_base64)
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower()
|
||||||
clean_description = description.replace("[图片:", "").replace("]", "")
|
filename = f"{image_id}.{image_format}"
|
||||||
|
image_dir = os.path.join(self.IMAGE_DIR, "images")
|
||||||
|
os.makedirs(image_dir, exist_ok=True)
|
||||||
|
file_path = os.path.join(image_dir, filename)
|
||||||
|
|
||||||
# 保存新图片
|
with open(file_path, "wb") as f:
|
||||||
current_timestamp = time.time()
|
f.write(image_bytes)
|
||||||
image_dir = os.path.join(self.IMAGE_DIR, "images")
|
|
||||||
os.makedirs(image_dir, exist_ok=True)
|
|
||||||
filename = f"{image_id}.png"
|
|
||||||
file_path = os.path.join(image_dir, filename)
|
|
||||||
|
|
||||||
# 保存文件
|
new_img = Images(
|
||||||
with open(file_path, "wb") as f:
|
image_id=image_id,
|
||||||
f.write(image_bytes)
|
emoji_hash=image_hash,
|
||||||
|
path=file_path,
|
||||||
|
type="image",
|
||||||
|
description=clean_description,
|
||||||
|
timestamp=time.time(),
|
||||||
|
vlm_processed=True,
|
||||||
|
count=1,
|
||||||
|
)
|
||||||
|
session.add(new_img)
|
||||||
|
logger.info(f"新图片记录已创建 (ID: {image_id})")
|
||||||
|
|
||||||
# 保存到数据库
|
|
||||||
new_img = Images(
|
|
||||||
image_id=image_id,
|
|
||||||
emoji_hash=image_hash,
|
|
||||||
path=file_path,
|
|
||||||
type="image",
|
|
||||||
description=clean_description,
|
|
||||||
timestamp=current_timestamp,
|
|
||||||
vlm_processed=True,
|
|
||||||
count=1,
|
|
||||||
)
|
|
||||||
session.add(new_img)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
return image_id, f"[picid:{image_id}]"
|
# 无论是新图片还是旧图片,只要成功获取描述,就直接返回描述
|
||||||
|
return image_id, description
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理图片失败: {e!s}")
|
logger.error(f"处理图片时发生严重错误: {e!s}", exc_info=True)
|
||||||
return "", "[图片]"
|
return "", "[图片(处理失败)]"
|
||||||
|
|
||||||
|
|
||||||
# 创建全局单例
|
# 创建全局单例
|
||||||
|
|||||||
Reference in New Issue
Block a user