diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 21937f36c..5ace4a1d4 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -613,22 +613,46 @@ class MessageStorage: try: async with get_db_session() as session: - # 构建批量更新映射,提高数据库批量操作效率 - mappings: list[dict[str, Any]] = [] - for message_id, interest_value in interest_map.items(): - mapping = {"message_id": message_id, "interest_value": interest_value} - if reply_map and message_id in reply_map: - mapping["should_reply"] = reply_map[message_id] - mappings.append(mapping) + # 注意:SQLAlchemy 2.0 对 ORM update + executemany 会走 + # “Bulk UPDATE by Primary Key” 路径,要求每行参数包含主键(Messages.id)。 + # 这里我们按 message_id 更新,因此使用 Core Table + bindparam。 + from sqlalchemy import bindparam, update - # 使用 bulk 操作替代逐条 UPDATE,大幅减少数据库往返 - if mappings: - await session.execute( - update(Messages), - mappings, + messages_table = Messages.__table__ + + interest_mappings: list[dict[str, Any]] = [ + {"b_message_id": message_id, "b_interest_value": interest_value} + for message_id, interest_value in interest_map.items() + ] + + if interest_mappings: + stmt_interest = ( + update(messages_table) + .where(messages_table.c.message_id == bindparam("b_message_id")) + .values(interest_value=bindparam("b_interest_value")) ) - await session.commit() - logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") + await session.execute(stmt_interest, interest_mappings) + + if reply_map: + reply_mappings: list[dict[str, Any]] = [ + {"b_message_id": message_id, "b_should_reply": should_reply} + for message_id, should_reply in reply_map.items() + if message_id in interest_map + ] + if reply_mappings and len(reply_mappings) != len(reply_map): + logger.debug( + f"批量更新 should_reply 过滤了 {len(reply_map) - len(reply_mappings)} 条不在兴趣度更新集合中的记录" + ) + if reply_mappings: + stmt_reply = ( + update(messages_table) + .where(messages_table.c.message_id == bindparam("b_message_id")) + .values(should_reply=bindparam("b_should_reply")) + ) + await session.execute(stmt_reply, reply_mappings) + + await session.commit() + logger.debug(f"批量更新兴趣度 {len(interest_map)} 条记录") except Exception as e: logger.error(f"批量更新消息兴趣度失败: {e}") raise