From 9b3251f8ecbe77a84ab6406b6a3e036730734793 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Tue, 24 Jun 2025 21:51:46 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E5=A2=9E=E5=8A=A0=E4=BA=86reply?= =?UTF-8?q?=5Fto=E6=96=B0message=E5=B1=9E=E6=80=A7=EF=BC=8C=E4=BC=98?= =?UTF-8?q?=E5=8C=96prompt=EF=BC=8C=E5=88=87=E5=89=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/message_receive/message.py | 3 ++ src/chat/message_receive/storage.py | 5 ++ src/common/database/database_model.py | 2 + src/mais4u/mais4u_chat/s4u_chat.py | 53 ++++++++++++------- src/mais4u/mais4u_chat/s4u_msg_processor.py | 7 +-- src/mais4u/mais4u_chat/s4u_prompt.py | 29 +++++++--- .../mais4u_chat/s4u_stream_generator.py | 25 +++++++-- 7 files changed, 90 insertions(+), 34 deletions(-) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 5798eb512..2ba50d7ec 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -283,6 +283,7 @@ class MessageSending(MessageProcessBase): is_emoji: bool = False, thinking_start_time: float = 0, apply_set_reply_logic: bool = False, + reply_to: str = None, ): # 调用父类初始化 super().__init__( @@ -300,6 +301,8 @@ class MessageSending(MessageProcessBase): self.is_head = is_head self.is_emoji = is_emoji self.apply_set_reply_logic = apply_set_reply_logic + + self.reply_to = reply_to # 用于显示发送内容与显示不一致的情况 self.display_message = display_message diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index c4ef047de..58835a921 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -35,8 +35,12 @@ class MessageStorage: filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) else: filtered_display_message = "" + + reply_to = message.reply_to else: filtered_display_message = "" + + reply_to = "" chat_info_dict = chat_stream.to_dict() user_info_dict = message.message_info.user_info.to_dict() @@ -54,6 +58,7 @@ class MessageStorage: time=float(message.message_info.time), chat_id=chat_stream.stream_id, # Flattened chat_info + reply_to=reply_to, chat_info_stream_id=chat_info_dict.get("stream_id"), chat_info_platform=chat_info_dict.get("platform"), chat_info_user_platform=user_info_from_chat.get("platform"), diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 5e3a08313..82bf28122 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -126,6 +126,8 @@ class Messages(BaseModel): time = DoubleField() # 消息时间戳 chat_id = TextField(index=True) # 对应的 ChatStreams stream_id + + reply_to = TextField(null=True) # 从 chat_info 扁平化而来的字段 chat_info_stream_id = TextField() diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index fbf4c29df..c63f2bc9c 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -92,7 +92,8 @@ class MessageSenderContainer: # Check for pause signal *after* getting an item. await self._paused_event.wait() - delay = self._calculate_typing_delay(chunk) + # delay = self._calculate_typing_delay(chunk) + delay = 0.1 await asyncio.sleep(delay) current_time = time.time() @@ -116,6 +117,7 @@ class MessageSenderContainer: reply=self.original_message, is_emoji=False, apply_set_reply_logic=True, + reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}" ) await bot_message.process() @@ -171,22 +173,13 @@ class S4UChat: self._message_queue = asyncio.Queue() self._processing_task = asyncio.create_task(self._message_processor()) self._current_generation_task: Optional[asyncio.Task] = None + self._current_message_being_replied: Optional[MessageRecv] = None self._is_replying = False - # 初始化Normal Chat专用表达器 - self.expressor = NormalChatExpressor(self.chat_stream) - self.replyer = DefaultReplyer(self.chat_stream) - self.gpt = S4UStreamGenerator() - self.audio_generator = MockAudioGenerator() - self.start_time = time.time() + # self.audio_generator = MockAudioGenerator() - # 记录最近的回复内容,每项包含: {time, user_message, response, is_mentioned, is_reference_reply} - self.recent_replies = [] - self.max_replies_history = 20 # 最多保存最近20条回复记录 - - self.storage = MessageStorage() logger.info(f"[{self.stream_name}] S4UChat") @@ -194,11 +187,32 @@ class S4UChat: # 改为实例方法, 移除 chat 参数 async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None: - """将消息放入队列并中断当前处理(如果正在处理)。""" + """将消息放入队列并根据发信人决定是否中断当前处理。""" + should_interrupt = False if self._current_generation_task and not self._current_generation_task.done(): + if self._current_message_being_replied: + # 检查新消息发送者和正在回复的消息发送者是否为同一人 + new_sender_id = message.message_info.user_info.user_id + original_sender_id = self._current_message_being_replied.message_info.user_info.user_id + + if new_sender_id == original_sender_id: + should_interrupt = True + logger.info(f"[{self.stream_name}] 来自同一用户的消息,中断当前回复。") + else: + if random.random() < 0.2: + should_interrupt = True + logger.info(f"[{self.stream_name}] 来自不同用户的消息,随机中断(20%)。") + else: + logger.info(f"[{self.stream_name}] 来自不同用户的消息,不中断。") + else: + # Fallback: if we don't know who we are replying to, interrupt. + should_interrupt = True + logger.warning(f"[{self.stream_name}] 正在生成回复,但无法获取原始消息发送者信息,将默认中断。") + + if should_interrupt: self._current_generation_task.cancel() logger.info(f"[{self.stream_name}] 请求中断当前回复生成任务。") - + await self._message_queue.put(message) async def _message_processor(self): @@ -207,12 +221,14 @@ class S4UChat: try: # 等待第一条消息 message = await self._message_queue.get() + self._current_message_being_replied = message # 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条 while not self._message_queue.empty(): drained_msg = self._message_queue.get_nowait() self._message_queue.task_done() # 为取出的旧消息调用 task_done message = drained_msg # 始终处理最新消息 + self._current_message_being_replied = message logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}") self._current_generation_task = asyncio.create_task(self._generate_and_send(message)) @@ -225,6 +241,7 @@ class S4UChat: logger.error(f"[{self.stream_name}] _generate_and_send 任务出现错误: {e}", exc_info=True) finally: self._current_generation_task = None + self._current_message_being_replied = None except asyncio.CancelledError: logger.info(f"[{self.stream_name}] 消息处理器正在关闭。") @@ -259,10 +276,10 @@ class S4UChat: await sender_container.add_message(chunk) # b. 为该文本块生成并播放音频 - if chunk.strip(): - audio_data = await self.audio_generator.generate(chunk) - player = MockAudioPlayer(audio_data) - await player.play() + # if chunk.strip(): + # audio_data = await self.audio_generator.generate(chunk) + # player = MockAudioPlayer(audio_data) + # await player.play() # 等待所有文本消息发送完成 await sender_container.close() diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index 8525b6a93..4a3737a70 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -43,7 +43,7 @@ class S4UMessageProcessor: message_data: 原始消息字符串 """ - target_user_id = "1026294844" + target_user_id_list = ["1026294844", "964959351"] # 1. 消息解析与初始化 groupinfo = message.message_info.group_info @@ -61,9 +61,10 @@ class S4UMessageProcessor: is_mentioned = is_mentioned_bot_in_message(message) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) - if userinfo.user_id == target_user_id: + if userinfo.user_id in target_user_id_list: await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0) - + else: + await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=0.0) # 7. 日志记录 logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index b62d93552..831058567 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -27,7 +27,7 @@ def init_prompt(): Prompt( """ 你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。 -你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,但是你主要还是关注你和{sender_name}的聊天内容。 +你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。 {background_dialogue_prompt} -------------------------------- @@ -35,10 +35,13 @@ def init_prompt(): 这是你和{sender_name}的对话,你们正在交流中: {core_dialogue_prompt} -{message_txt} +对方最新发送的内容:{message_txt} 回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。 不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。 -你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。""", +你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。 +你的发言: + +""", "s4u_prompt", # New template for private CHAT chat ) @@ -96,19 +99,29 @@ class PromptBuilder: limit=100, ) + + talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id + print(f"talk_type: {talk_type}") + # 分别筛选核心对话和背景对话 core_dialogue_list = [] background_dialogue_list = [] bot_id = str(global_config.bot.qq_account) target_user_id = str(message.chat_stream.user_info.user_id) + for msg_dict in message_list_before_now: try: # 直接通过字典访问 msg_user_id = str(msg_dict.get('user_id')) - - if msg_user_id == bot_id or msg_user_id == target_user_id: + if msg_user_id == bot_id: + if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"): + print(f"reply: {msg_dict.get('reply_to')}") + core_dialogue_list.append(msg_dict) + else: + background_dialogue_list.append(msg_dict) + elif msg_user_id == target_user_id: core_dialogue_list.append(msg_dict) else: background_dialogue_list.append(msg_dict) @@ -140,14 +153,14 @@ class PromptBuilder: last_speaking_user_id = start_speaking_user_id msg_seg_str = "对方的发言:\n" - msg_seg_str += f"{first_msg.get('processed_plain_text')}\n" + msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n" all_msg_seg_list = [] for msg in core_dialogue_list[1:]: speaker = msg.get('user_id') if speaker == last_speaking_user_id: #还是同一个人讲话 - msg_seg_str += f"{msg.get('processed_plain_text')}\n" + msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" else: #换人了 msg_seg_str = f"{msg_seg_str}\n" @@ -158,7 +171,7 @@ class PromptBuilder: else: msg_seg_str = "对方的发言:\n" - msg_seg_str += f"{msg.get('processed_plain_text')}\n" + msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" last_speaking_user_id = speaker all_msg_seg_list.append(msg_seg_str) diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index 54df5aece..0b27df958 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -112,6 +112,7 @@ class S4UStreamGenerator: buffer = "" delimiters = ",。!?,.!?\n\r" # For final trimming + punctuation_buffer = "" async for content in client.get_stream_content( messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs @@ -125,8 +126,19 @@ class S4UStreamGenerator: if sentence: # 如果句子看起来完整(即不只是等待更多内容),则发送 if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)): - yield sentence - await asyncio.sleep(0) # 允许其他任务运行 + # 检查是否只是一个标点符号 + if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]: + punctuation_buffer += sentence + else: + # 发送之前累积的标点和当前句子 + to_yield = punctuation_buffer + sentence + if to_yield.endswith((',', ',')): + to_yield = to_yield.rstrip(',,') + + yield to_yield + punctuation_buffer = "" # 清空标点符号缓冲区 + await asyncio.sleep(0) # 允许其他任务运行 + last_match_end = match.end(0) # 从缓冲区移除已发送的部分 @@ -134,7 +146,10 @@ class S4UStreamGenerator: buffer = buffer[last_match_end:] # 发送缓冲区中剩余的任何内容 - if buffer.strip(): - yield buffer.strip() - await asyncio.sleep(0) + to_yield = (punctuation_buffer + buffer).strip() + if to_yield: + if to_yield.endswith((',', ',')): + to_yield = to_yield.rstrip(',,') + if to_yield: + yield to_yield