rufffffff
This commit is contained in:
@@ -22,14 +22,14 @@ logger = get_logger("message_storage")
|
||||
class MessageStorageBatcher:
|
||||
"""
|
||||
消息存储批处理器
|
||||
|
||||
|
||||
优化: 将消息缓存一段时间后批量写入数据库,减少数据库连接池压力
|
||||
"""
|
||||
|
||||
def __init__(self, batch_size: int = 50, flush_interval: float = 5.0):
|
||||
"""
|
||||
初始化批处理器
|
||||
|
||||
|
||||
Args:
|
||||
batch_size: 批量大小,达到此数量立即写入
|
||||
flush_interval: 自动刷新间隔(秒)
|
||||
@@ -51,7 +51,7 @@ class MessageStorageBatcher:
|
||||
async def stop(self):
|
||||
"""停止批处理器"""
|
||||
self._running = False
|
||||
|
||||
|
||||
if self._flush_task:
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
@@ -67,7 +67,7 @@ class MessageStorageBatcher:
|
||||
async def add_message(self, message_data: dict):
|
||||
"""
|
||||
添加消息到批处理队列
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 包含消息对象和chat_stream的字典
|
||||
{
|
||||
@@ -97,23 +97,23 @@ class MessageStorageBatcher:
|
||||
|
||||
start_time = time.time()
|
||||
success_count = 0
|
||||
|
||||
|
||||
try:
|
||||
# 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT
|
||||
messages_dicts = []
|
||||
|
||||
|
||||
for msg_data in messages_to_store:
|
||||
try:
|
||||
message_dict = await self._prepare_message_dict(
|
||||
msg_data['message'],
|
||||
msg_data['chat_stream']
|
||||
msg_data["message"],
|
||||
msg_data["chat_stream"]
|
||||
)
|
||||
if message_dict:
|
||||
messages_dicts.append(message_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"准备消息数据失败: {e}")
|
||||
continue
|
||||
|
||||
|
||||
# 批量写入数据库 - 使用高效的批量INSERT
|
||||
if messages_dicts:
|
||||
from sqlalchemy import insert
|
||||
@@ -122,7 +122,7 @@ class MessageStorageBatcher:
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
success_count = len(messages_dicts)
|
||||
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(
|
||||
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
|
||||
@@ -134,18 +134,18 @@ class MessageStorageBatcher:
|
||||
|
||||
async def _prepare_message_dict(self, message, chat_stream):
|
||||
"""准备消息字典数据(用于批量INSERT)
|
||||
|
||||
|
||||
这个方法准备字典而不是ORM对象,性能更高
|
||||
"""
|
||||
message_obj = await self._prepare_message_object(message, chat_stream)
|
||||
if message_obj is None:
|
||||
return None
|
||||
|
||||
|
||||
# 将ORM对象转换为字典(只包含列字段)
|
||||
message_dict = {}
|
||||
for column in Messages.__table__.columns:
|
||||
message_dict[column.name] = getattr(message_obj, column.name)
|
||||
|
||||
|
||||
return message_dict
|
||||
|
||||
async def _prepare_message_object(self, message, chat_stream):
|
||||
@@ -251,12 +251,12 @@ class MessageStorageBatcher:
|
||||
is_picid = message.is_picid
|
||||
is_notify = message.is_notify
|
||||
is_command = message.is_command
|
||||
is_public_notice = getattr(message, 'is_public_notice', False)
|
||||
notice_type = getattr(message, 'notice_type', None)
|
||||
actions = getattr(message, 'actions', None)
|
||||
should_reply = getattr(message, 'should_reply', None)
|
||||
should_act = getattr(message, 'should_act', None)
|
||||
additional_config = getattr(message, 'additional_config', None)
|
||||
is_public_notice = getattr(message, "is_public_notice", False)
|
||||
notice_type = getattr(message, "notice_type", None)
|
||||
actions = getattr(message, "actions", None)
|
||||
should_reply = getattr(message, "should_reply", None)
|
||||
should_act = getattr(message, "should_act", None)
|
||||
additional_config = getattr(message, "additional_config", None)
|
||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||
|
||||
@@ -349,7 +349,7 @@ class MessageStorageBatcher:
|
||||
|
||||
|
||||
# 全局批处理器实例
|
||||
_message_storage_batcher: Optional[MessageStorageBatcher] = None
|
||||
_message_storage_batcher: MessageStorageBatcher | None = None
|
||||
_message_update_batcher: Optional["MessageUpdateBatcher"] = None
|
||||
|
||||
|
||||
@@ -367,7 +367,7 @@ def get_message_storage_batcher() -> MessageStorageBatcher:
|
||||
class MessageUpdateBatcher:
|
||||
"""
|
||||
消息更新批处理器
|
||||
|
||||
|
||||
优化: 将多个消息ID更新操作批量处理,减少数据库连接次数
|
||||
"""
|
||||
|
||||
@@ -478,7 +478,7 @@ class MessageStorage:
|
||||
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
|
||||
"""
|
||||
存储消息到数据库
|
||||
|
||||
|
||||
Args:
|
||||
message: 消息对象
|
||||
chat_stream: 聊天流对象
|
||||
@@ -488,11 +488,11 @@ class MessageStorage:
|
||||
if use_batch:
|
||||
batcher = get_message_storage_batcher()
|
||||
await batcher.add_message({
|
||||
'message': message,
|
||||
'chat_stream': chat_stream
|
||||
"message": message,
|
||||
"chat_stream": chat_stream
|
||||
})
|
||||
return
|
||||
|
||||
|
||||
# 直接写入模式(保留用于特殊场景)
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
@@ -676,9 +676,9 @@ class MessageStorage:
|
||||
async def update_message(message_data: dict, use_batch: bool = True):
|
||||
"""
|
||||
更新消息ID(从消息字典)
|
||||
|
||||
|
||||
优化: 添加批处理选项,将多个更新操作合并,减少数据库连接
|
||||
|
||||
|
||||
Args:
|
||||
message_data: 消息数据字典
|
||||
use_batch: 是否使用批处理(默认True)
|
||||
|
||||
Reference in New Issue
Block a user