refactor(bot): 使用统一方法转换消息为数据库对象,简化代码逻辑
This commit is contained in:
@@ -144,6 +144,134 @@ class MessageRecv(Message):
|
||||
def update_chat_stream(self, chat_stream: "ChatStream"):
|
||||
self.chat_stream = chat_stream
|
||||
|
||||
def to_database_message(self) -> "DatabaseMessages":
|
||||
"""将 MessageRecv 转换为 DatabaseMessages 对象
|
||||
|
||||
Returns:
|
||||
DatabaseMessages: 数据库消息对象
|
||||
"""
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
import time
|
||||
|
||||
message_info = self.message_info
|
||||
msg_user_info = getattr(message_info, "user_info", None)
|
||||
stream_user_info = getattr(self.chat_stream, "user_info", None) if self.chat_stream else None
|
||||
group_info = getattr(self.chat_stream, "group_info", None) if self.chat_stream else None
|
||||
|
||||
message_id = message_info.message_id or ""
|
||||
message_time = message_info.time if hasattr(message_info, "time") and message_info.time is not None else time.time()
|
||||
is_mentioned = None
|
||||
if isinstance(self.is_mentioned, bool):
|
||||
is_mentioned = self.is_mentioned
|
||||
elif isinstance(self.is_mentioned, int | float):
|
||||
is_mentioned = self.is_mentioned != 0
|
||||
|
||||
# 提取用户信息
|
||||
user_id = ""
|
||||
user_nickname = ""
|
||||
user_cardname = None
|
||||
user_platform = ""
|
||||
if msg_user_info:
|
||||
user_id = str(getattr(msg_user_info, "user_id", "") or "")
|
||||
user_nickname = getattr(msg_user_info, "user_nickname", "") or ""
|
||||
user_cardname = getattr(msg_user_info, "user_cardname", None)
|
||||
user_platform = getattr(msg_user_info, "platform", "") or ""
|
||||
elif stream_user_info:
|
||||
user_id = str(getattr(stream_user_info, "user_id", "") or "")
|
||||
user_nickname = getattr(stream_user_info, "user_nickname", "") or ""
|
||||
user_cardname = getattr(stream_user_info, "user_cardname", None)
|
||||
user_platform = getattr(stream_user_info, "platform", "") or ""
|
||||
|
||||
# 提取聊天流信息
|
||||
chat_user_id = str(getattr(stream_user_info, "user_id", "") or "") if stream_user_info else ""
|
||||
chat_user_nickname = getattr(stream_user_info, "user_nickname", "") or "" if stream_user_info else ""
|
||||
chat_user_cardname = getattr(stream_user_info, "user_cardname", None) if stream_user_info else None
|
||||
chat_user_platform = getattr(stream_user_info, "platform", "") or "" if stream_user_info else ""
|
||||
|
||||
group_id = getattr(group_info, "group_id", None) if group_info else None
|
||||
group_name = getattr(group_info, "group_name", None) if group_info else None
|
||||
group_platform = getattr(group_info, "platform", None) if group_info else None
|
||||
|
||||
# 准备 additional_config
|
||||
additional_config_str = None
|
||||
try:
|
||||
import orjson
|
||||
|
||||
additional_config_data = {}
|
||||
|
||||
# 首先获取adapter传递的additional_config
|
||||
if hasattr(message_info, 'additional_config') and message_info.additional_config:
|
||||
if isinstance(message_info.additional_config, dict):
|
||||
additional_config_data = message_info.additional_config.copy()
|
||||
elif isinstance(message_info.additional_config, str):
|
||||
try:
|
||||
additional_config_data = orjson.loads(message_info.additional_config)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法解析 additional_config JSON: {e}")
|
||||
additional_config_data = {}
|
||||
|
||||
# 添加notice相关标志
|
||||
if self.is_notify:
|
||||
additional_config_data["is_notice"] = True
|
||||
additional_config_data["notice_type"] = self.notice_type or "unknown"
|
||||
additional_config_data["is_public_notice"] = bool(self.is_public_notice)
|
||||
|
||||
# 添加format_info到additional_config中
|
||||
if hasattr(message_info, 'format_info') and message_info.format_info:
|
||||
try:
|
||||
format_info_dict = message_info.format_info.to_dict()
|
||||
additional_config_data["format_info"] = format_info_dict
|
||||
logger.debug(f"[message.py] 嵌入 format_info 到 additional_config: {format_info_dict}")
|
||||
except Exception as e:
|
||||
logger.warning(f"将 format_info 转换为字典失败: {e}")
|
||||
|
||||
# 序列化为JSON字符串
|
||||
if additional_config_data:
|
||||
additional_config_str = orjson.dumps(additional_config_data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"准备 additional_config 失败: {e}")
|
||||
|
||||
# 创建数据库消息对象
|
||||
db_message = DatabaseMessages(
|
||||
message_id=message_id,
|
||||
time=float(message_time),
|
||||
chat_id=self.chat_stream.stream_id if self.chat_stream else "",
|
||||
processed_plain_text=self.processed_plain_text,
|
||||
display_message=self.processed_plain_text,
|
||||
is_mentioned=is_mentioned,
|
||||
is_at=bool(self.is_at) if self.is_at is not None else None,
|
||||
is_emoji=bool(self.is_emoji),
|
||||
is_picid=bool(self.is_picid),
|
||||
is_command=bool(self.is_command),
|
||||
is_notify=bool(self.is_notify),
|
||||
is_public_notice=bool(self.is_public_notice),
|
||||
notice_type=self.notice_type,
|
||||
additional_config=additional_config_str,
|
||||
user_id=user_id,
|
||||
user_nickname=user_nickname,
|
||||
user_cardname=user_cardname,
|
||||
user_platform=user_platform,
|
||||
chat_info_stream_id=self.chat_stream.stream_id if self.chat_stream else "",
|
||||
chat_info_platform=self.chat_stream.platform if self.chat_stream else "",
|
||||
chat_info_create_time=float(self.chat_stream.create_time) if self.chat_stream else 0.0,
|
||||
chat_info_last_active_time=float(self.chat_stream.last_active_time) if self.chat_stream else 0.0,
|
||||
chat_info_user_id=chat_user_id,
|
||||
chat_info_user_nickname=chat_user_nickname,
|
||||
chat_info_user_cardname=chat_user_cardname,
|
||||
chat_info_user_platform=chat_user_platform,
|
||||
chat_info_group_id=group_id,
|
||||
chat_info_group_name=group_name,
|
||||
chat_info_group_platform=group_platform,
|
||||
)
|
||||
|
||||
# 同步兴趣度等衍生属性
|
||||
db_message.interest_value = getattr(self, "interest_value", 0.0)
|
||||
setattr(db_message, "should_reply", getattr(self, "should_reply", False))
|
||||
setattr(db_message, "should_act", getattr(self, "should_act", False))
|
||||
|
||||
return db_message
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本
|
||||
|
||||
@@ -479,64 +607,9 @@ class MessageSending(MessageProcessBase):
|
||||
return self.message_info.group_info is None or self.message_info.group_info.group_id is None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSet:
|
||||
"""消息集合类,可以存储多个发送消息"""
|
||||
|
||||
def __init__(self, chat_stream: "ChatStream", message_id: str):
|
||||
self.chat_stream = chat_stream
|
||||
self.message_id = message_id
|
||||
self.messages: list[MessageSending] = []
|
||||
self.time = round(time.time(), 3) # 保留3位小数
|
||||
|
||||
def add_message(self, message: MessageSending) -> None:
|
||||
"""添加消息到集合"""
|
||||
if not isinstance(message, MessageSending):
|
||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> MessageSending | None:
|
||||
"""通过索引获取消息"""
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> MessageSending | None:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].message_info.time < target_time: # type: ignore
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
return self.messages[left]
|
||||
|
||||
def clear_messages(self) -> None:
|
||||
"""清空所有消息"""
|
||||
self.messages.clear()
|
||||
|
||||
def remove_message(self, message: MessageSending) -> bool:
|
||||
"""移除指定消息"""
|
||||
if message in self.messages:
|
||||
self.messages.remove(message)
|
||||
return True
|
||||
return False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"MessageSet(id={self.message_id}, count={len(self.messages)})"
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.messages)
|
||||
|
||||
|
||||
def message_recv_from_dict(message_dict: dict) -> MessageRecv:
|
||||
return MessageRecv(message_dict)
|
||||
|
||||
|
||||
def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
||||
"""从数据库字典创建MessageRecv实例"""
|
||||
# 转换扁平的数据库字典为嵌套结构
|
||||
|
||||
Reference in New Issue
Block a user