Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -99,37 +99,55 @@ class MessageStorageBatcher:
|
|||||||
success_count = 0
|
success_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 准备所有消息对象
|
# 🔧 优化:准备字典数据而不是ORM对象,使用批量INSERT
|
||||||
messages_objects = []
|
messages_dicts = []
|
||||||
|
|
||||||
for msg_data in messages_to_store:
|
for msg_data in messages_to_store:
|
||||||
try:
|
try:
|
||||||
message_obj = await self._prepare_message_object(
|
message_dict = await self._prepare_message_dict(
|
||||||
msg_data['message'],
|
msg_data['message'],
|
||||||
msg_data['chat_stream']
|
msg_data['chat_stream']
|
||||||
)
|
)
|
||||||
if message_obj:
|
if message_dict:
|
||||||
messages_objects.append(message_obj)
|
messages_dicts.append(message_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"准备消息对象失败: {e}")
|
logger.error(f"准备消息数据失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 批量写入数据库
|
# 批量写入数据库 - 使用高效的批量INSERT
|
||||||
if messages_objects:
|
if messages_dicts:
|
||||||
|
from sqlalchemy import insert
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
session.add_all(messages_objects)
|
stmt = insert(Messages).values(messages_dicts)
|
||||||
|
await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
success_count = len(messages_objects)
|
success_count = len(messages_dicts)
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
elapsed = time.time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
|
f"批量存储了 {success_count}/{len(messages_to_store)} 条消息 "
|
||||||
f"(耗时: {elapsed:.3f}秒)"
|
f"(耗时: {elapsed:.3f}秒, 平均 {elapsed/max(success_count,1)*1000:.2f}ms/条)"
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"批量存储消息失败: {e}", exc_info=True)
|
logger.error(f"批量存储消息失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
async def _prepare_message_dict(self, message, chat_stream):
|
||||||
|
"""准备消息字典数据(用于批量INSERT)
|
||||||
|
|
||||||
|
这个方法准备字典而不是ORM对象,性能更高
|
||||||
|
"""
|
||||||
|
message_obj = await self._prepare_message_object(message, chat_stream)
|
||||||
|
if message_obj is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 将ORM对象转换为字典(只包含列字段)
|
||||||
|
message_dict = {}
|
||||||
|
for column in Messages.__table__.columns:
|
||||||
|
message_dict[column.name] = getattr(message_obj, column.name)
|
||||||
|
|
||||||
|
return message_dict
|
||||||
|
|
||||||
async def _prepare_message_object(self, message, chat_stream):
|
async def _prepare_message_object(self, message, chat_stream):
|
||||||
"""准备消息对象(从原 store_message 逻辑提取)"""
|
"""准备消息对象(从原 store_message 逻辑提取)"""
|
||||||
try:
|
try:
|
||||||
@@ -161,6 +179,12 @@ class MessageStorageBatcher:
|
|||||||
is_picid = message.is_picid or False
|
is_picid = message.is_picid or False
|
||||||
is_notify = message.is_notify or False
|
is_notify = message.is_notify or False
|
||||||
is_command = message.is_command or False
|
is_command = message.is_command or False
|
||||||
|
is_public_notice = message.is_public_notice or False
|
||||||
|
notice_type = message.notice_type
|
||||||
|
actions = message.actions
|
||||||
|
should_reply = message.should_reply
|
||||||
|
should_act = message.should_act
|
||||||
|
additional_config = message.additional_config
|
||||||
key_words = ""
|
key_words = ""
|
||||||
key_words_lite = ""
|
key_words_lite = ""
|
||||||
memorized_times = 0
|
memorized_times = 0
|
||||||
@@ -208,6 +232,12 @@ class MessageStorageBatcher:
|
|||||||
is_picid = False
|
is_picid = False
|
||||||
is_notify = False
|
is_notify = False
|
||||||
is_command = False
|
is_command = False
|
||||||
|
is_public_notice = False
|
||||||
|
notice_type = None
|
||||||
|
actions = None
|
||||||
|
should_reply = None
|
||||||
|
should_act = None
|
||||||
|
additional_config = None
|
||||||
key_words = ""
|
key_words = ""
|
||||||
key_words_lite = ""
|
key_words_lite = ""
|
||||||
else:
|
else:
|
||||||
@@ -221,6 +251,12 @@ class MessageStorageBatcher:
|
|||||||
is_picid = message.is_picid
|
is_picid = message.is_picid
|
||||||
is_notify = message.is_notify
|
is_notify = message.is_notify
|
||||||
is_command = message.is_command
|
is_command = message.is_command
|
||||||
|
is_public_notice = getattr(message, 'is_public_notice', False)
|
||||||
|
notice_type = getattr(message, 'notice_type', None)
|
||||||
|
actions = getattr(message, 'actions', None)
|
||||||
|
should_reply = getattr(message, 'should_reply', None)
|
||||||
|
should_act = getattr(message, 'should_act', None)
|
||||||
|
additional_config = getattr(message, 'additional_config', None)
|
||||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
key_words = MessageStorage._serialize_keywords(message.key_words)
|
||||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
||||||
|
|
||||||
@@ -282,10 +318,16 @@ class MessageStorageBatcher:
|
|||||||
interest_value=interest_value,
|
interest_value=interest_value,
|
||||||
priority_mode=priority_mode,
|
priority_mode=priority_mode,
|
||||||
priority_info=priority_info_json,
|
priority_info=priority_info_json,
|
||||||
|
additional_config=additional_config,
|
||||||
is_emoji=is_emoji,
|
is_emoji=is_emoji,
|
||||||
is_picid=is_picid,
|
is_picid=is_picid,
|
||||||
is_notify=is_notify,
|
is_notify=is_notify,
|
||||||
is_command=is_command,
|
is_command=is_command,
|
||||||
|
is_public_notice=is_public_notice,
|
||||||
|
notice_type=notice_type,
|
||||||
|
actions=actions,
|
||||||
|
should_reply=should_reply,
|
||||||
|
should_act=should_act,
|
||||||
key_words=key_words,
|
key_words=key_words,
|
||||||
key_words_lite=key_words_lite,
|
key_words_lite=key_words_lite,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -266,7 +266,14 @@ class CRUDBase:
|
|||||||
await session.refresh(instance)
|
await session.refresh(instance)
|
||||||
# 注意:commit在get_db_session的context manager退出时自动执行
|
# 注意:commit在get_db_session的context manager退出时自动执行
|
||||||
# 但为了明确性,这里不需要显式commit
|
# 但为了明确性,这里不需要显式commit
|
||||||
return instance
|
|
||||||
|
# 注意:create不清除缓存,因为:
|
||||||
|
# 1. 新记录不会影响已有的单条查询缓存(get/get_by)
|
||||||
|
# 2. get_multi的缓存会自然过期(TTL机制)
|
||||||
|
# 3. 清除所有缓存代价太大,影响性能
|
||||||
|
# 如果需要强一致性,应该在查询时设置use_cache=False
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
async def update(
|
async def update(
|
||||||
self,
|
self,
|
||||||
@@ -460,7 +467,14 @@ class CRUDBase:
|
|||||||
for instance in instances:
|
for instance in instances:
|
||||||
await session.refresh(instance)
|
await session.refresh(instance)
|
||||||
|
|
||||||
return instances
|
# 批量创建的缓存策略:
|
||||||
|
# bulk_create通常用于批量导入场景,此时清除缓存是合理的
|
||||||
|
# 因为可能创建大量记录,缓存的列表查询会明显过期
|
||||||
|
cache = await get_cache()
|
||||||
|
await cache.clear()
|
||||||
|
logger.info(f"批量创建{len(instances)}条{self.model_name}记录后已清除缓存")
|
||||||
|
|
||||||
|
return instances
|
||||||
|
|
||||||
async def bulk_update(
|
async def bulk_update(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -393,8 +393,10 @@ class AdaptiveBatchScheduler:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""批量执行更新操作"""
|
"""批量执行更新操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
for op in operations:
|
results = []
|
||||||
try:
|
try:
|
||||||
|
# 🔧 修复:收集所有操作后一次性commit,而不是循环中多次commit
|
||||||
|
for op in operations:
|
||||||
# 构建更新语句
|
# 构建更新语句
|
||||||
stmt = update(op.model_class)
|
stmt = update(op.model_class)
|
||||||
for key, value in op.conditions.items():
|
for key, value in op.conditions.items():
|
||||||
@@ -404,23 +406,29 @@ class AdaptiveBatchScheduler:
|
|||||||
if op.data:
|
if op.data:
|
||||||
stmt = stmt.values(**op.data)
|
stmt = stmt.values(**op.data)
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新(但不commit)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
await session.commit()
|
results.append((op, result.rowcount))
|
||||||
|
|
||||||
# 设置结果
|
# 所有操作成功后,一次性commit
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# 设置所有操作的结果
|
||||||
|
for op, rowcount in results:
|
||||||
if op.future and not op.future.done():
|
if op.future and not op.future.done():
|
||||||
op.future.set_result(result.rowcount)
|
op.future.set_result(rowcount)
|
||||||
|
|
||||||
if op.callback:
|
if op.callback:
|
||||||
try:
|
try:
|
||||||
op.callback(result.rowcount)
|
op.callback(rowcount)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"回调执行失败: {e}")
|
logger.warning(f"回调执行失败: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新失败: {e}", exc_info=True)
|
logger.error(f"批量更新失败: {e}", exc_info=True)
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
|
# 所有操作都失败
|
||||||
|
for op in operations:
|
||||||
if op.future and not op.future.done():
|
if op.future and not op.future.done():
|
||||||
op.future.set_exception(e)
|
op.future.set_exception(e)
|
||||||
|
|
||||||
@@ -430,31 +438,39 @@ class AdaptiveBatchScheduler:
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""批量执行删除操作"""
|
"""批量执行删除操作"""
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
for op in operations:
|
results = []
|
||||||
try:
|
try:
|
||||||
|
# 🔧 修复:收集所有操作后一次性commit,而不是循环中多次commit
|
||||||
|
for op in operations:
|
||||||
# 构建删除语句
|
# 构建删除语句
|
||||||
stmt = delete(op.model_class)
|
stmt = delete(op.model_class)
|
||||||
for key, value in op.conditions.items():
|
for key, value in op.conditions.items():
|
||||||
attr = getattr(op.model_class, key)
|
attr = getattr(op.model_class, key)
|
||||||
stmt = stmt.where(attr == value)
|
stmt = stmt.where(attr == value)
|
||||||
|
|
||||||
# 执行删除
|
# 执行删除(但不commit)
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
await session.commit()
|
results.append((op, result.rowcount))
|
||||||
|
|
||||||
# 设置结果
|
# 所有操作成功后,一次性commit
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
# 设置所有操作的结果
|
||||||
|
for op, rowcount in results:
|
||||||
if op.future and not op.future.done():
|
if op.future and not op.future.done():
|
||||||
op.future.set_result(result.rowcount)
|
op.future.set_result(rowcount)
|
||||||
|
|
||||||
if op.callback:
|
if op.callback:
|
||||||
try:
|
try:
|
||||||
op.callback(result.rowcount)
|
op.callback(rowcount)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"回调执行失败: {e}")
|
logger.warning(f"回调执行失败: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除失败: {e}", exc_info=True)
|
logger.error(f"批量删除失败: {e}", exc_info=True)
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
|
# 所有操作都失败
|
||||||
|
for op in operations:
|
||||||
if op.future and not op.future.done():
|
if op.future and not op.future.done():
|
||||||
op.future.set_exception(e)
|
op.future.set_exception(e)
|
||||||
|
|
||||||
|
|||||||
@@ -59,6 +59,20 @@ class ProactiveThinkingReplyHandler(BaseEventHandler):
|
|||||||
logger.debug("[主动思考事件] reply_reset_enabled 为 False,跳过重置")
|
logger.debug("[主动思考事件] reply_reset_enabled 为 False,跳过重置")
|
||||||
return HandlerResult(success=True, continue_process=True, message=None)
|
return HandlerResult(success=True, continue_process=True, message=None)
|
||||||
|
|
||||||
|
# 检查白名单/黑名单(获取 stream_config 进行验证)
|
||||||
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
chat_manager = get_chat_manager()
|
||||||
|
chat_stream = await chat_manager.get_stream(stream_id)
|
||||||
|
|
||||||
|
if chat_stream:
|
||||||
|
stream_config = chat_stream.get_raw_id()
|
||||||
|
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
|
||||||
|
logger.debug(f"[主动思考事件] 聊天流 {stream_id} ({stream_config}) 不在白名单中,跳过重置")
|
||||||
|
return HandlerResult(success=True, continue_process=True, message=None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[主动思考事件] 白名单检查时出错: {e}")
|
||||||
|
|
||||||
# 检查是否被暂停
|
# 检查是否被暂停
|
||||||
was_paused = await proactive_thinking_scheduler.is_paused(stream_id)
|
was_paused = await proactive_thinking_scheduler.is_paused(stream_id)
|
||||||
logger.debug(f"[主动思考事件] 聊天流 {stream_id} 暂停状态: {was_paused}")
|
logger.debug(f"[主动思考事件] 聊天流 {stream_id} 暂停状态: {was_paused}")
|
||||||
|
|||||||
@@ -541,10 +541,32 @@ async def execute_proactive_thinking(stream_id: str):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 0. 前置检查
|
# 0. 前置检查
|
||||||
|
# 0.1 检查白名单/黑名单
|
||||||
|
# 从 stream_id 获取 stream_config 字符串进行验证
|
||||||
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
chat_manager = get_chat_manager()
|
||||||
|
chat_stream = await chat_manager.get_stream(stream_id)
|
||||||
|
|
||||||
|
if chat_stream:
|
||||||
|
# 使用 ChatStream 的 get_raw_id() 方法获取配置字符串
|
||||||
|
stream_config = chat_stream.get_raw_id()
|
||||||
|
|
||||||
|
# 执行白名单/黑名单检查
|
||||||
|
if not proactive_thinking_scheduler._check_whitelist_blacklist(stream_config):
|
||||||
|
logger.debug(f"聊天流 {stream_id} ({stream_config}) 未通过白名单/黑名单检查,跳过主动思考")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logger.warning(f"无法获取聊天流 {stream_id} 的信息,跳过白名单检查")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"白名单检查时出错: {e},继续执行")
|
||||||
|
|
||||||
|
# 0.2 检查安静时段
|
||||||
if proactive_thinking_scheduler._is_in_quiet_hours():
|
if proactive_thinking_scheduler._is_in_quiet_hours():
|
||||||
logger.debug("安静时段,跳过")
|
logger.debug("安静时段,跳过")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# 0.3 检查每日限制
|
||||||
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
|
if not proactive_thinking_scheduler._check_daily_limit(stream_id):
|
||||||
logger.debug("今日发言达上限")
|
logger.debug("今日发言达上限")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -19,11 +19,13 @@ from .src.recv_handler.meta_event_handler import meta_event_handler
|
|||||||
from .src.recv_handler.notice_handler import notice_handler
|
from .src.recv_handler.notice_handler import notice_handler
|
||||||
from .src.response_pool import check_timeout_response, put_response
|
from .src.response_pool import check_timeout_response, put_response
|
||||||
from .src.send_handler import send_handler
|
from .src.send_handler import send_handler
|
||||||
|
from .src.stream_router import stream_router
|
||||||
from .src.websocket_manager import websocket_manager
|
from .src.websocket_manager import websocket_manager
|
||||||
|
|
||||||
logger = get_logger("napcat_adapter")
|
logger = get_logger("napcat_adapter")
|
||||||
|
|
||||||
message_queue = asyncio.Queue()
|
# 旧的全局消息队列已被流路由器替代
|
||||||
|
# message_queue = asyncio.Queue()
|
||||||
|
|
||||||
|
|
||||||
def get_classes_in_module(module):
|
def get_classes_in_module(module):
|
||||||
@@ -64,7 +66,8 @@ async def message_recv(server_connection: Server.ServerConnection):
|
|||||||
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
# 处理完整消息(可能是重组后的,也可能是原本就完整的)
|
||||||
post_type = decoded_raw_message.get("post_type")
|
post_type = decoded_raw_message.get("post_type")
|
||||||
if post_type in ["meta_event", "message", "notice"]:
|
if post_type in ["meta_event", "message", "notice"]:
|
||||||
await message_queue.put(decoded_raw_message)
|
# 使用流路由器路由消息到对应的聊天流
|
||||||
|
await stream_router.route_message(decoded_raw_message)
|
||||||
elif post_type is None:
|
elif post_type is None:
|
||||||
await put_response(decoded_raw_message)
|
await put_response(decoded_raw_message)
|
||||||
|
|
||||||
@@ -76,61 +79,11 @@ async def message_recv(server_connection: Server.ServerConnection):
|
|||||||
logger.debug(f"原始消息: {raw_message[:500]}...")
|
logger.debug(f"原始消息: {raw_message[:500]}...")
|
||||||
|
|
||||||
|
|
||||||
async def message_process():
|
# 旧的单消费者消息处理循环已被流路由器替代
|
||||||
"""消息处理主循环"""
|
# 现在每个聊天流都有自己的消费者协程
|
||||||
logger.info("消息处理器已启动")
|
# async def message_process():
|
||||||
try:
|
# """消息处理主循环"""
|
||||||
while True:
|
# ...
|
||||||
try:
|
|
||||||
# 使用超时等待,以便能够响应取消请求
|
|
||||||
message = await asyncio.wait_for(message_queue.get(), timeout=1.0)
|
|
||||||
|
|
||||||
post_type = message.get("post_type")
|
|
||||||
if post_type == "message":
|
|
||||||
await message_handler.handle_raw_message(message)
|
|
||||||
elif post_type == "meta_event":
|
|
||||||
await meta_event_handler.handle_meta_event(message)
|
|
||||||
elif post_type == "notice":
|
|
||||||
await notice_handler.handle_notice(message)
|
|
||||||
else:
|
|
||||||
logger.warning(f"未知的post_type: {post_type}")
|
|
||||||
|
|
||||||
message_queue.task_done()
|
|
||||||
await asyncio.sleep(0.05)
|
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
# 超时是正常的,继续循环
|
|
||||||
continue
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("消息处理器收到取消信号")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理消息时出错: {e}")
|
|
||||||
# 即使出错也标记任务完成,避免队列阻塞
|
|
||||||
try:
|
|
||||||
message_queue.task_done()
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("消息处理器已停止")
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"消息处理器异常: {e}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
logger.info("消息处理器正在清理...")
|
|
||||||
# 清空剩余的队列项目
|
|
||||||
try:
|
|
||||||
while not message_queue.empty():
|
|
||||||
try:
|
|
||||||
message_queue.get_nowait()
|
|
||||||
message_queue.task_done()
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"清理消息队列时出错: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
async def napcat_server(plugin_config: dict):
|
async def napcat_server(plugin_config: dict):
|
||||||
@@ -151,6 +104,12 @@ async def graceful_shutdown():
|
|||||||
try:
|
try:
|
||||||
logger.info("正在关闭adapter...")
|
logger.info("正在关闭adapter...")
|
||||||
|
|
||||||
|
# 停止流路由器
|
||||||
|
try:
|
||||||
|
await stream_router.stop()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"停止流路由器时出错: {e}")
|
||||||
|
|
||||||
# 停止消息重组器的清理任务
|
# 停止消息重组器的清理任务
|
||||||
try:
|
try:
|
||||||
await reassembler.stop_cleanup_task()
|
await reassembler.stop_cleanup_task()
|
||||||
@@ -198,17 +157,6 @@ async def graceful_shutdown():
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Adapter关闭中出现错误: {e}")
|
logger.error(f"Adapter关闭中出现错误: {e}")
|
||||||
finally:
|
|
||||||
# 确保消息队列被清空
|
|
||||||
try:
|
|
||||||
while not message_queue.empty():
|
|
||||||
try:
|
|
||||||
message_queue.get_nowait()
|
|
||||||
message_queue.task_done()
|
|
||||||
except asyncio.QueueEmpty:
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class LauchNapcatAdapterHandler(BaseEventHandler):
|
class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||||
@@ -225,12 +173,16 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
|||||||
logger.info("启动消息重组器...")
|
logger.info("启动消息重组器...")
|
||||||
await reassembler.start_cleanup_task()
|
await reassembler.start_cleanup_task()
|
||||||
|
|
||||||
|
# 启动流路由器
|
||||||
|
logger.info("启动流路由器...")
|
||||||
|
await stream_router.start()
|
||||||
|
|
||||||
logger.info("开始启动Napcat Adapter")
|
logger.info("开始启动Napcat Adapter")
|
||||||
|
|
||||||
# 创建单独的异步任务,防止阻塞主线程
|
# 创建单独的异步任务,防止阻塞主线程
|
||||||
asyncio.create_task(self._start_maibot_connection())
|
asyncio.create_task(self._start_maibot_connection())
|
||||||
asyncio.create_task(napcat_server(self.plugin_config))
|
asyncio.create_task(napcat_server(self.plugin_config))
|
||||||
asyncio.create_task(message_process())
|
# 不再需要 message_process 任务,由流路由器管理消费者
|
||||||
asyncio.create_task(check_timeout_response())
|
asyncio.create_task(check_timeout_response())
|
||||||
|
|
||||||
async def _start_maibot_connection(self):
|
async def _start_maibot_connection(self):
|
||||||
@@ -347,6 +299,12 @@ class NapcatAdapterPlugin(BasePlugin):
|
|||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
"stream_router": {
|
||||||
|
"max_streams": ConfigField(type=int, default=500, description="最大并发流数量"),
|
||||||
|
"stream_timeout": ConfigField(type=int, default=600, description="流不活跃超时时间(秒),超时后自动清理"),
|
||||||
|
"stream_queue_size": ConfigField(type=int, default=100, description="每个流的消息队列大小"),
|
||||||
|
"cleanup_interval": ConfigField(type=int, default=60, description="清理不活跃流的间隔时间(秒)"),
|
||||||
|
},
|
||||||
"features": {
|
"features": {
|
||||||
# 权限设置
|
# 权限设置
|
||||||
"group_list_type": ConfigField(
|
"group_list_type": ConfigField(
|
||||||
@@ -383,7 +341,6 @@ class NapcatAdapterPlugin(BasePlugin):
|
|||||||
"supported_formats": ConfigField(
|
"supported_formats": ConfigField(
|
||||||
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
|
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
|
||||||
),
|
),
|
||||||
# 消息缓冲功能已移除
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -397,7 +354,8 @@ class NapcatAdapterPlugin(BasePlugin):
|
|||||||
"voice": "发送语音设置",
|
"voice": "发送语音设置",
|
||||||
"slicing": "WebSocket消息切片设置",
|
"slicing": "WebSocket消息切片设置",
|
||||||
"debug": "调试设置",
|
"debug": "调试设置",
|
||||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
|
"stream_router": "流路由器设置(按聊天流分配消费者,提升高并发性能)",
|
||||||
|
"features": "功能设置(权限控制、聊天功能、视频处理等)",
|
||||||
}
|
}
|
||||||
|
|
||||||
def register_events(self):
|
def register_events(self):
|
||||||
@@ -444,4 +402,11 @@ class NapcatAdapterPlugin(BasePlugin):
|
|||||||
notice_handler.set_plugin_config(self.config)
|
notice_handler.set_plugin_config(self.config)
|
||||||
# 设置meta_event_handler的插件配置
|
# 设置meta_event_handler的插件配置
|
||||||
meta_event_handler.set_plugin_config(self.config)
|
meta_event_handler.set_plugin_config(self.config)
|
||||||
|
|
||||||
|
# 设置流路由器的配置
|
||||||
|
stream_router.max_streams = config_api.get_plugin_config(self.config, "stream_router.max_streams", 500)
|
||||||
|
stream_router.stream_timeout = config_api.get_plugin_config(self.config, "stream_router.stream_timeout", 600)
|
||||||
|
stream_router.stream_queue_size = config_api.get_plugin_config(self.config, "stream_router.stream_queue_size", 100)
|
||||||
|
stream_router.cleanup_interval = config_api.get_plugin_config(self.config, "stream_router.cleanup_interval", 60)
|
||||||
|
|
||||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||||
|
|||||||
351
src/plugins/built_in/napcat_adapter_plugin/src/stream_router.py
Normal file
351
src/plugins/built_in/napcat_adapter_plugin/src/stream_router.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
"""
|
||||||
|
按聊天流分配消费者的消息路由系统
|
||||||
|
|
||||||
|
核心思想:
|
||||||
|
- 为每个活跃的聊天流(stream_id)创建独立的消息队列和消费者协程
|
||||||
|
- 同一聊天流的消息由同一个 worker 处理,保证顺序性
|
||||||
|
- 不同聊天流的消息并发处理,提高吞吐量
|
||||||
|
- 动态管理流的生命周期,自动清理不活跃的流
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("stream_router")
|
||||||
|
|
||||||
|
|
||||||
|
class StreamConsumer:
|
||||||
|
"""单个聊天流的消息消费者
|
||||||
|
|
||||||
|
维护独立的消息队列和处理协程
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stream_id: str, queue_maxsize: int = 100):
|
||||||
|
self.stream_id = stream_id
|
||||||
|
self.queue: asyncio.Queue = asyncio.Queue(maxsize=queue_maxsize)
|
||||||
|
self.worker_task: Optional[asyncio.Task] = None
|
||||||
|
self.last_active_time = time.time()
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
# 性能统计
|
||||||
|
self.stats = {
|
||||||
|
"total_messages": 0,
|
||||||
|
"total_processing_time": 0.0,
|
||||||
|
"queue_overflow_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""启动消费者"""
|
||||||
|
if not self.is_running:
|
||||||
|
self.is_running = True
|
||||||
|
self.worker_task = asyncio.create_task(self._process_loop())
|
||||||
|
logger.debug(f"Stream Consumer 启动: {self.stream_id}")
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""停止消费者"""
|
||||||
|
self.is_running = False
|
||||||
|
if self.worker_task:
|
||||||
|
self.worker_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.worker_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.debug(f"Stream Consumer 停止: {self.stream_id}")
|
||||||
|
|
||||||
|
async def enqueue(self, message: dict) -> None:
|
||||||
|
"""将消息加入队列"""
|
||||||
|
self.last_active_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用 put_nowait 避免阻塞路由器
|
||||||
|
self.queue.put_nowait(message)
|
||||||
|
except asyncio.QueueFull:
|
||||||
|
self.stats["queue_overflow_count"] += 1
|
||||||
|
logger.warning(
|
||||||
|
f"Stream {self.stream_id} 队列已满 "
|
||||||
|
f"({self.queue.qsize()}/{self.queue.maxsize}),"
|
||||||
|
f"消息被丢弃!溢出次数: {self.stats['queue_overflow_count']}"
|
||||||
|
)
|
||||||
|
# 可选策略:丢弃最旧的消息
|
||||||
|
# try:
|
||||||
|
# self.queue.get_nowait()
|
||||||
|
# self.queue.put_nowait(message)
|
||||||
|
# logger.debug(f"Stream {self.stream_id} 丢弃最旧消息,添加新消息")
|
||||||
|
# except asyncio.QueueEmpty:
|
||||||
|
# pass
|
||||||
|
|
||||||
|
async def _process_loop(self) -> None:
|
||||||
|
"""消息处理循环"""
|
||||||
|
# 延迟导入,避免循环依赖
|
||||||
|
from .recv_handler.message_handler import message_handler
|
||||||
|
from .recv_handler.meta_event_handler import meta_event_handler
|
||||||
|
from .recv_handler.notice_handler import notice_handler
|
||||||
|
|
||||||
|
logger.info(f"Stream {self.stream_id} 处理循环启动")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
try:
|
||||||
|
# 等待消息,1秒超时
|
||||||
|
message = await asyncio.wait_for(
|
||||||
|
self.queue.get(),
|
||||||
|
timeout=1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 处理消息
|
||||||
|
post_type = message.get("post_type")
|
||||||
|
if post_type == "message":
|
||||||
|
await message_handler.handle_raw_message(message)
|
||||||
|
elif post_type == "meta_event":
|
||||||
|
await meta_event_handler.handle_meta_event(message)
|
||||||
|
elif post_type == "notice":
|
||||||
|
await notice_handler.handle_notice(message)
|
||||||
|
else:
|
||||||
|
logger.warning(f"未知的 post_type: {post_type}")
|
||||||
|
|
||||||
|
processing_time = time.time() - start_time
|
||||||
|
|
||||||
|
# 更新统计
|
||||||
|
self.stats["total_messages"] += 1
|
||||||
|
self.stats["total_processing_time"] += processing_time
|
||||||
|
self.last_active_time = time.time()
|
||||||
|
self.queue.task_done()
|
||||||
|
|
||||||
|
# 性能监控(每100条消息输出一次)
|
||||||
|
if self.stats["total_messages"] % 100 == 0:
|
||||||
|
avg_time = self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||||
|
logger.info(
|
||||||
|
f"Stream {self.stream_id[:30]}... 统计: "
|
||||||
|
f"消息数={self.stats['total_messages']}, "
|
||||||
|
f"平均耗时={avg_time:.3f}秒, "
|
||||||
|
f"队列长度={self.queue.qsize()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 动态延迟:队列空时短暂休眠
|
||||||
|
if self.queue.qsize() == 0:
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
# 超时是正常的,继续循环
|
||||||
|
continue
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info(f"Stream {self.stream_id} 处理循环被取消")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Stream {self.stream_id} 处理消息时出错: {e}", exc_info=True)
|
||||||
|
# 继续处理下一条消息
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
logger.info(f"Stream {self.stream_id} 处理循环结束")
|
||||||
|
|
||||||
|
def get_stats(self) -> dict:
|
||||||
|
"""获取性能统计"""
|
||||||
|
avg_time = (
|
||||||
|
self.stats["total_processing_time"] / self.stats["total_messages"]
|
||||||
|
if self.stats["total_messages"] > 0
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"stream_id": self.stream_id,
|
||||||
|
"queue_size": self.queue.qsize(),
|
||||||
|
"total_messages": self.stats["total_messages"],
|
||||||
|
"avg_processing_time": avg_time,
|
||||||
|
"queue_overflow_count": self.stats["queue_overflow_count"],
|
||||||
|
"last_active_time": self.last_active_time,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StreamRouter:
|
||||||
|
"""流路由器
|
||||||
|
|
||||||
|
负责将消息路由到对应的聊天流队列
|
||||||
|
动态管理聊天流的生命周期
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_streams: int = 500,
|
||||||
|
stream_timeout: int = 600,
|
||||||
|
stream_queue_size: int = 100,
|
||||||
|
cleanup_interval: int = 60,
|
||||||
|
):
|
||||||
|
self.streams: Dict[str, StreamConsumer] = {}
|
||||||
|
self.lock = asyncio.Lock()
|
||||||
|
self.max_streams = max_streams
|
||||||
|
self.stream_timeout = stream_timeout
|
||||||
|
self.stream_queue_size = stream_queue_size
|
||||||
|
self.cleanup_interval = cleanup_interval
|
||||||
|
self.cleanup_task: Optional[asyncio.Task] = None
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""启动路由器"""
|
||||||
|
if not self.is_running:
|
||||||
|
self.is_running = True
|
||||||
|
self.cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||||
|
logger.info(
|
||||||
|
f"StreamRouter 已启动 - "
|
||||||
|
f"最大流数: {self.max_streams}, "
|
||||||
|
f"超时: {self.stream_timeout}秒, "
|
||||||
|
f"队列大小: {self.stream_queue_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
"""停止路由器"""
|
||||||
|
self.is_running = False
|
||||||
|
|
||||||
|
if self.cleanup_task:
|
||||||
|
self.cleanup_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.cleanup_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 停止所有流消费者
|
||||||
|
logger.info(f"正在停止 {len(self.streams)} 个流消费者...")
|
||||||
|
for consumer in self.streams.values():
|
||||||
|
await consumer.stop()
|
||||||
|
|
||||||
|
self.streams.clear()
|
||||||
|
logger.info("StreamRouter 已停止")
|
||||||
|
|
||||||
|
async def route_message(self, message: dict) -> None:
|
||||||
|
"""路由消息到对应的流"""
|
||||||
|
stream_id = self._extract_stream_id(message)
|
||||||
|
|
||||||
|
# 快速路径:流已存在
|
||||||
|
if stream_id in self.streams:
|
||||||
|
await self.streams[stream_id].enqueue(message)
|
||||||
|
return
|
||||||
|
|
||||||
|
# 慢路径:需要创建新流
|
||||||
|
async with self.lock:
|
||||||
|
# 双重检查
|
||||||
|
if stream_id not in self.streams:
|
||||||
|
# 检查流数量限制
|
||||||
|
if len(self.streams) >= self.max_streams:
|
||||||
|
logger.warning(
|
||||||
|
f"达到最大流数量限制 ({self.max_streams}),"
|
||||||
|
f"尝试清理不活跃的流..."
|
||||||
|
)
|
||||||
|
await self._cleanup_inactive_streams()
|
||||||
|
|
||||||
|
# 清理后仍然超限,记录警告但继续创建
|
||||||
|
if len(self.streams) >= self.max_streams:
|
||||||
|
logger.error(
|
||||||
|
f"清理后仍达到最大流数量 ({len(self.streams)}/{self.max_streams})!"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建新流
|
||||||
|
consumer = StreamConsumer(stream_id, self.stream_queue_size)
|
||||||
|
self.streams[stream_id] = consumer
|
||||||
|
await consumer.start()
|
||||||
|
logger.info(f"创建新的 Stream Consumer: {stream_id} (总流数: {len(self.streams)})")
|
||||||
|
|
||||||
|
await self.streams[stream_id].enqueue(message)
|
||||||
|
|
||||||
|
def _extract_stream_id(self, message: dict) -> str:
|
||||||
|
"""从消息中提取 stream_id
|
||||||
|
|
||||||
|
返回格式: platform:id:type
|
||||||
|
例如: qq:123456:group 或 qq:789012:private
|
||||||
|
"""
|
||||||
|
post_type = message.get("post_type")
|
||||||
|
|
||||||
|
# 非消息类型,使用默认流(避免创建过多流)
|
||||||
|
if post_type not in ["message", "notice"]:
|
||||||
|
return "system:meta_event"
|
||||||
|
|
||||||
|
# 消息类型
|
||||||
|
if post_type == "message":
|
||||||
|
message_type = message.get("message_type")
|
||||||
|
if message_type == "group":
|
||||||
|
group_id = message.get("group_id")
|
||||||
|
return f"qq:{group_id}:group"
|
||||||
|
elif message_type == "private":
|
||||||
|
user_id = message.get("user_id")
|
||||||
|
return f"qq:{user_id}:private"
|
||||||
|
|
||||||
|
# notice 类型
|
||||||
|
elif post_type == "notice":
|
||||||
|
group_id = message.get("group_id")
|
||||||
|
if group_id:
|
||||||
|
return f"qq:{group_id}:group"
|
||||||
|
user_id = message.get("user_id")
|
||||||
|
if user_id:
|
||||||
|
return f"qq:{user_id}:private"
|
||||||
|
|
||||||
|
# 未知类型,使用通用流
|
||||||
|
return "unknown:unknown"
|
||||||
|
|
||||||
|
async def _cleanup_inactive_streams(self) -> None:
|
||||||
|
"""清理不活跃的流"""
|
||||||
|
current_time = time.time()
|
||||||
|
to_remove = []
|
||||||
|
|
||||||
|
for stream_id, consumer in self.streams.items():
|
||||||
|
if current_time - consumer.last_active_time > self.stream_timeout:
|
||||||
|
to_remove.append(stream_id)
|
||||||
|
|
||||||
|
for stream_id in to_remove:
|
||||||
|
await self.streams[stream_id].stop()
|
||||||
|
del self.streams[stream_id]
|
||||||
|
logger.debug(f"清理不活跃的流: {stream_id}")
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
logger.info(
|
||||||
|
f"清理了 {len(to_remove)} 个不活跃的流 "
|
||||||
|
f"(当前活跃流: {len(self.streams)}/{self.max_streams})"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _cleanup_loop(self) -> None:
|
||||||
|
"""定期清理循环"""
|
||||||
|
logger.info(f"清理循环已启动,间隔: {self.cleanup_interval}秒")
|
||||||
|
try:
|
||||||
|
while self.is_running:
|
||||||
|
await asyncio.sleep(self.cleanup_interval)
|
||||||
|
await self._cleanup_inactive_streams()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("清理循环已停止")
|
||||||
|
|
||||||
|
def get_all_stats(self) -> list[dict]:
|
||||||
|
"""获取所有流的统计信息"""
|
||||||
|
return [consumer.get_stats() for consumer in self.streams.values()]
|
||||||
|
|
||||||
|
def get_summary(self) -> dict:
|
||||||
|
"""获取路由器摘要"""
|
||||||
|
total_messages = sum(c.stats["total_messages"] for c in self.streams.values())
|
||||||
|
total_queue_size = sum(c.queue.qsize() for c in self.streams.values())
|
||||||
|
total_overflows = sum(c.stats["queue_overflow_count"] for c in self.streams.values())
|
||||||
|
|
||||||
|
# 计算平均队列长度
|
||||||
|
avg_queue_size = total_queue_size / len(self.streams) if self.streams else 0
|
||||||
|
|
||||||
|
# 找出最繁忙的流
|
||||||
|
busiest_stream = None
|
||||||
|
if self.streams:
|
||||||
|
busiest_stream = max(
|
||||||
|
self.streams.values(),
|
||||||
|
key=lambda c: c.stats["total_messages"]
|
||||||
|
).stream_id
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_streams": len(self.streams),
|
||||||
|
"max_streams": self.max_streams,
|
||||||
|
"total_messages_processed": total_messages,
|
||||||
|
"total_queue_size": total_queue_size,
|
||||||
|
"avg_queue_size": avg_queue_size,
|
||||||
|
"total_queue_overflows": total_overflows,
|
||||||
|
"busiest_stream": busiest_stream,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局路由器实例
|
||||||
|
stream_router = StreamRouter()
|
||||||
Reference in New Issue
Block a user