Files
Mofox-Core/src/chat/message_receive/storage.py
2025-09-22 12:33:59 +08:00

222 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import traceback
import orjson
from typing import Union
from src.common.database.sqlalchemy_models import Messages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from src.common.database.sqlalchemy_database_api import get_db_session
from sqlalchemy import select, update, desc
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
"""将关键词列表序列化为JSON字符串"""
if isinstance(keywords, list):
return orjson.dumps(keywords).decode("utf-8")
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
"""将JSON字符串反序列化为关键词列表"""
if not keywords_str:
return []
try:
return orjson.loads(keywords_str)
except (orjson.JSONDecodeError, TypeError):
return []
@staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
processed_plain_text = message.processed_plain_text
if processed_plain_text:
processed_plain_text = MessageStorage.replace_image_descriptions(processed_plain_text)
filtered_processed_plain_text = re.sub(pattern, "", processed_plain_text, flags=re.DOTALL)
else:
filtered_processed_plain_text = ""
if isinstance(message, MessageSending):
display_message = message.display_message
if display_message:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else:
# 如果没有设置display_message使用processed_plain_text作为显示消息
filtered_display_message = re.sub(pattern, "", message.processed_plain_text, flags=re.DOTALL) if message.processed_plain_text else ""
interest_value = 0
is_mentioned = False
reply_to = message.reply_to
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_notify = False
is_command = False
key_words = ""
key_words_lite = ""
else:
filtered_display_message = ""
interest_value = message.interest_value
is_mentioned = message.is_mentioned
reply_to = ""
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
# 将priority_info字典序列化为JSON字符串以便存储到数据库的Text字段
priority_info_json = orjson.dumps(priority_info).decode("utf-8") if priority_info else None
# 获取数据库会话
new_message = Messages(
message_id=msg_id,
time=float(message.message_info.time),
chat_id=chat_stream.stream_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
memorized_times=message.memorized_times,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info_json,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
key_words=key_words,
key_words_lite=key_words_lite,
)
with get_db_session() as session:
session.add(new_message)
session.commit()
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
traceback.print_exc()
@staticmethod
async def update_message(message):
"""更新消息ID"""
try:
mmc_message_id = message.message_info.message_id
qq_message_id = None
logger.debug(f"尝试更新消息ID: {mmc_message_id}, 消息段类型: {message.message_segment.type}")
# 根据消息段类型提取message_id
if message.message_segment.type == "notify":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "text":
qq_message_id = message.message_segment.data.get("id")
elif message.message_segment.type == "reply":
qq_message_id = message.message_segment.data.get("id")
if qq_message_id:
logger.debug(f"从reply消息段获取到消息ID: {qq_message_id}")
elif message.message_segment.type == "adapter_response":
logger.debug("适配器响应消息不需要更新ID")
return
elif message.message_segment.type == "adapter_command":
logger.debug("适配器命令消息不需要更新ID")
return
else:
logger.debug(f"未知的消息段类型: {message.message_segment.type}跳过ID更新")
return
if not qq_message_id:
logger.debug(f"消息段类型 {message.message_segment.type} 中未找到有效的message_id跳过更新")
logger.debug(f"消息段数据: {message.message_segment.data}")
return
# 使用上下文管理器确保session正确管理
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
matched_message = session.execute(
select(Messages).where(Messages.message_id == mmc_message_id).order_by(desc(Messages.time))
).scalar()
if matched_message:
session.execute(
update(Messages).where(Messages.id == matched_message.id).values(message_id=qq_message_id)
)
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.warning(f"未找到匹配的消息记录: {mmc_message_id}")
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
logger.error(
f"消息信息: message_id={getattr(message.message_info, 'message_id', 'N/A')}, "
f"segment_type={getattr(message.message_segment, 'type', 'N/A')}"
)
@staticmethod
def replace_image_descriptions(text: str) -> str:
"""将[图片:描述]替换为[picid:image_id]"""
# 先检查文本中是否有图片标记
pattern = r"\[图片:([^\]]+)\]"
matches = re.findall(pattern, text)
if not matches:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
description = match.group(1).strip()
try:
from src.common.database.sqlalchemy_models import get_db_session
with get_db_session() as session:
image_record = session.execute(
select(Images).where(Images.description == description).order_by(desc(Images.timestamp))
).scalar()
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)
return re.sub(r"\[图片:([^\]]+)\]", replace_match, text)