feat:为s4u添加了优先队列和普通队列
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import time
|
||||
import random
|
||||
from typing import Optional, Dict # 导入类型提示
|
||||
from typing import Optional, Dict, Tuple # 导入类型提示
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
@@ -142,6 +142,8 @@ def get_s4u_chat_manager() -> S4UChatManager:
|
||||
|
||||
|
||||
class S4UChat:
|
||||
_MESSAGE_TIMEOUT_SECONDS = 60 # 普通消息存活时间(秒)
|
||||
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化 S4UChat 实例。"""
|
||||
|
||||
@@ -149,86 +151,141 @@ class S4UChat:
|
||||
self.stream_id = chat_stream.stream_id
|
||||
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||
|
||||
self._message_queue = asyncio.Queue()
|
||||
# 两个消息队列
|
||||
self._vip_queue = asyncio.PriorityQueue()
|
||||
self._normal_queue = asyncio.PriorityQueue()
|
||||
|
||||
self._entry_counter = 0 # 保证FIFO的全局计数器
|
||||
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
||||
|
||||
self._processing_task = asyncio.create_task(self._message_processor())
|
||||
self._current_generation_task: Optional[asyncio.Task] = None
|
||||
self._current_message_being_replied: Optional[MessageRecv] = None
|
||||
# 当前消息的元数据:(队列类型, 优先级, 计数器, 消息对象)
|
||||
self._current_message_being_replied: Optional[Tuple[str, int, int, MessageRecv]] = None
|
||||
|
||||
self._is_replying = False
|
||||
|
||||
self.gpt = S4UStreamGenerator()
|
||||
# self.audio_generator = MockAudioGenerator()
|
||||
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat")
|
||||
def _is_vip(self, message: MessageRecv) -> bool:
|
||||
"""检查消息是否来自VIP用户。"""
|
||||
# 您需要修改此处或在配置文件中定义VIP用户
|
||||
vip_user_ids = ["1026294844"]
|
||||
vip_user_ids = [""]
|
||||
return message.message_info.user_info.user_id in vip_user_ids
|
||||
|
||||
# 改为实例方法, 移除 chat 参数
|
||||
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
|
||||
"""将消息放入队列并根据发信人决定是否中断当前处理。"""
|
||||
def _get_message_priority(self, message: MessageRecv) -> int:
|
||||
"""为消息分配优先级。数字越小,优先级越高。"""
|
||||
if f"@{global_config.bot.nickname}" in message.processed_plain_text or any(
|
||||
f"@{alias}" in message.processed_plain_text for alias in global_config.bot.alias_names
|
||||
):
|
||||
return 0
|
||||
return 1
|
||||
|
||||
async def add_message(self, message: MessageRecv) -> None:
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
is_vip = self._is_vip(message)
|
||||
new_priority = self._get_message_priority(message)
|
||||
|
||||
should_interrupt = False
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
if self._current_message_being_replied:
|
||||
# 检查新消息发送者和正在回复的消息发送者是否为同一人
|
||||
new_sender_id = message.message_info.user_info.user_id
|
||||
original_sender_id = self._current_message_being_replied.message_info.user_info.user_id
|
||||
|
||||
if new_sender_id == original_sender_id:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] 来自同一用户的消息,中断当前回复。")
|
||||
else:
|
||||
if random.random() < 0.2:
|
||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||
|
||||
# 规则:VIP从不被打断
|
||||
if current_queue == "vip":
|
||||
pass # Do nothing
|
||||
|
||||
# 规则:普通消息可以被打断
|
||||
elif current_queue == "normal":
|
||||
# VIP消息可以打断普通消息
|
||||
if is_vip:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] 来自不同用户的消息,随机中断(20%)。")
|
||||
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
|
||||
# 普通消息的内部打断逻辑
|
||||
else:
|
||||
logger.info(f"[{self.stream_name}] 来自不同用户的消息,不中断。")
|
||||
else:
|
||||
# Fallback: if we don't know who we are replying to, interrupt.
|
||||
should_interrupt = True
|
||||
logger.warning(f"[{self.stream_name}] 正在生成回复,但无法获取原始消息发送者信息,将默认中断。")
|
||||
|
||||
new_sender_id = message.message_info.user_info.user_id
|
||||
current_sender_id = current_msg.message_info.user_info.user_id
|
||||
# 新消息优先级更高
|
||||
if new_priority < current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
|
||||
# 同用户,同级或更高级
|
||||
elif new_sender_id == current_sender_id and new_priority <= current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
||||
|
||||
if should_interrupt:
|
||||
if self.gpt.partial_response:
|
||||
logger.warning(f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'")
|
||||
self._current_generation_task.cancel()
|
||||
logger.info(f"[{self.stream_name}] 请求中断当前回复生成任务。")
|
||||
|
||||
await self._message_queue.put(message)
|
||||
# 将消息放入对应的队列
|
||||
item = (new_priority, self._entry_counter, time.time(), message)
|
||||
if is_vip:
|
||||
await self._vip_queue.put(item)
|
||||
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||
else:
|
||||
await self._normal_queue.put(item)
|
||||
|
||||
self._entry_counter += 1
|
||||
self._new_message_event.set() # 唤醒处理器
|
||||
|
||||
async def _message_processor(self):
|
||||
"""从队列中处理消息,支持中断。"""
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
while True:
|
||||
try:
|
||||
# 等待第一条消息
|
||||
message = await self._message_queue.get()
|
||||
self._current_message_being_replied = message
|
||||
# 等待有新消息的信号,避免空转
|
||||
await self._new_message_event.wait()
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
|
||||
while not self._message_queue.empty():
|
||||
drained_msg = self._message_queue.get_nowait()
|
||||
self._message_queue.task_done() # 为取出的旧消息调用 task_done
|
||||
message = drained_msg # 始终处理最新消息
|
||||
self._current_message_being_replied = message
|
||||
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
|
||||
# 优先处理VIP队列
|
||||
if not self._vip_queue.empty():
|
||||
priority, entry_count, _, message = self._vip_queue.get_nowait()
|
||||
queue_name = "vip"
|
||||
# 其次处理普通队列
|
||||
elif not self._normal_queue.empty():
|
||||
priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||
# 检查普通消息是否超时
|
||||
if time.time() - timestamp > self._MESSAGE_TIMEOUT_SECONDS:
|
||||
logger.info(f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}...")
|
||||
self._normal_queue.task_done()
|
||||
continue # 处理下一条
|
||||
queue_name = "normal"
|
||||
else:
|
||||
continue # 没有消息了,回去等事件
|
||||
|
||||
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||
|
||||
try:
|
||||
await self._current_generation_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 回复生成被外部中断。")
|
||||
logger.info(f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded.")
|
||||
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
||||
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] _generate_and_send 任务出现错误: {e}", exc_info=True)
|
||||
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
|
||||
finally:
|
||||
self._current_generation_task = None
|
||||
self._current_message_being_replied = None
|
||||
# 标记任务完成
|
||||
if queue_name == 'vip':
|
||||
self._vip_queue.task_done()
|
||||
else:
|
||||
self._normal_queue.task_done()
|
||||
|
||||
# 检查是否还有任务,有则立即再次触发事件
|
||||
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
||||
self._new_message_event.set()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
|
||||
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True)
|
||||
await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转
|
||||
finally:
|
||||
# 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成
|
||||
if "message" in locals():
|
||||
self._message_queue.task_done()
|
||||
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本和音频回复。整个过程可以被中断。"""
|
||||
|
||||
Reference in New Issue
Block a user