diff --git a/src/audio/mock_audio.py b/src/audio/mock_audio.py new file mode 100644 index 000000000..73d7176af --- /dev/null +++ b/src/audio/mock_audio.py @@ -0,0 +1,58 @@ +import asyncio +from src.common.logger import get_logger + +logger = get_logger("MockAudio") + +class MockAudioPlayer: + """ + 一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。 + """ + def __init__(self, audio_data: bytes): + self._audio_data = audio_data + # 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频 + self._duration = (len(audio_data) / 1024.0) * 0.5 + + async def play(self): + """模拟播放音频。该过程可以被中断。""" + if self._duration <= 0: + return + logger.info(f"开始播放模拟音频,预计时长: {self._duration:.2f} 秒...") + try: + await asyncio.sleep(self._duration) + logger.info("模拟音频播放完毕。") + except asyncio.CancelledError: + logger.info("音频播放被中断。") + raise # 重新抛出异常,以便上层逻辑可以捕获它 + +class MockAudioGenerator: + """ + 一个模拟的文本到语音(TTS)生成器。 + """ + def __init__(self): + # 模拟生成速度:每秒生成的字符数 + self.chars_per_second = 25.0 + + async def generate(self, text: str) -> bytes: + """ + 模拟从文本生成音频数据。该过程可以被中断。 + + Args: + text: 需要转换为音频的文本。 + + Returns: + 模拟的音频数据(bytes)。 + """ + if not text: + return b'' + + generation_time = len(text) / self.chars_per_second + logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...") + try: + await asyncio.sleep(generation_time) + # 生成虚拟的音频数据,其长度与文本长度成正比 + mock_audio_data = b'\x01\x02\x03' * (len(text) * 40) + logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。") + return mock_audio_data + except asyncio.CancelledError: + logger.info("音频生成被中断。") + raise # 重新抛出异常 \ No newline at end of file diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 8b8d6f255..099f3c062 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -13,8 +13,11 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.config.config import global_config from src.plugin_system.core.component_registry import component_registry # 导入新插件系统 from src.plugin_system.base.base_command import BaseCommand +from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor # 定义日志配置 +ENABLE_S4U_CHAT = True +# 仅内部开启 # 配置主程序日志格式 logger = get_logger("chat") @@ -30,6 +33,7 @@ class ChatBot: # 创建初始化PFC管理器的任务,会在_ensure_started时执行 self.only_process_chat = MessageProcessor() self.pfc_manager = PFCManager.get_instance() + self.s4u_message_processor = S4UMessageProcessor() async def _ensure_started(self): """确保所有任务已启动""" @@ -176,6 +180,14 @@ class ChatBot: # 如果在私聊中 if group_info is None: logger.debug("检测到私聊消息") + + if ENABLE_S4U_CHAT: + logger.debug("进入S4U私聊处理流程") + await self.s4u_message_processor.process_message(message) + return + + + if global_config.experimental.pfc_chatting: logger.debug("进入PFC私聊处理流程") # 创建聊天流 @@ -188,6 +200,13 @@ class ChatBot: await self.heartflow_message_receiver.process_message(message) # 群聊默认进入心流消息处理逻辑 else: + + if ENABLE_S4U_CHAT: + logger.debug("进入S4U私聊处理流程") + await self.s4u_message_processor.process_message(message) + return + + logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}") await self.heartflow_message_receiver.process_message(message) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 84593bcff..580939f47 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -174,6 +174,7 @@ def _build_readable_messages_internal( truncate: bool = False, pic_id_mapping: Dict[str, str] = None, pic_counter: int = 1, + show_pic: bool = True, ) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]: """ 内部辅助函数,构建可读消息字符串和原始消息详情列表。 @@ -260,7 +261,9 @@ def _build_readable_messages_internal( content = content.replace("ⁿ", "") # 处理图片ID - content = process_pic_ids(content) + if show_pic: + content = process_pic_ids(content) + # 检查必要信息是否存在 if not all([platform, user_id, timestamp is not None]): @@ -532,6 +535,7 @@ def build_readable_messages( read_mark: float = 0.0, truncate: bool = False, show_actions: bool = False, + show_pic: bool = True, ) -> str: """ 将消息列表转换为可读的文本格式。 @@ -601,7 +605,7 @@ def build_readable_messages( if read_mark <= 0: # 没有有效的 read_mark,直接格式化所有消息 formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal( - copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate + copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate, show_pic=show_pic ) # 生成图片映射信息并添加到最前面 @@ -628,9 +632,10 @@ def build_readable_messages( truncate, pic_id_mapping, pic_counter, + show_pic=show_pic ) formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( - messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter + messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter, show_pic=show_pic ) read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n" diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py new file mode 100644 index 000000000..fbf4c29df --- /dev/null +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -0,0 +1,302 @@ +import asyncio +import time +import traceback +import random +from typing import List, Optional, Dict # 导入类型提示 +import os +import pickle +from maim_message import UserInfo, Seg +from src.common.logger import get_logger +from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info +from src.manager.mood_manager import mood_manager +from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.utils.timer_calculator import Timer +from src.chat.utils.prompt_builder import global_prompt_manager +from .s4u_stream_generator import S4UStreamGenerator +from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet +from src.chat.message_receive.message_sender import message_manager +from src.chat.normal_chat.willing.willing_manager import get_willing_manager +from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats +from src.config.config import global_config +from src.chat.focus_chat.planners.action_manager import ActionManager +from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner +from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionModifier +from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor +from src.chat.focus_chat.replyer.default_generator import DefaultReplyer +from src.person_info.person_info import PersonInfoManager +from src.person_info.relationship_manager import get_relationship_manager +from src.chat.utils.chat_message_builder import ( + get_raw_msg_by_timestamp_with_chat, + get_raw_msg_by_timestamp_with_chat_inclusive, + get_raw_msg_before_timestamp_with_chat, + num_new_messages_since, +) +from src.common.message.api import get_global_api +from src.chat.message_receive.storage import MessageStorage +from src.audio.mock_audio import MockAudioGenerator, MockAudioPlayer + + +logger = get_logger("S4U_chat") + + +class MessageSenderContainer: + """一个简单的容器,用于按顺序发送消息并模拟打字效果。""" + def __init__(self, chat_stream: ChatStream, original_message: MessageRecv): + self.chat_stream = chat_stream + self.original_message = original_message + self.queue = asyncio.Queue() + self.storage = MessageStorage() + self._task: Optional[asyncio.Task] = None + self._paused_event = asyncio.Event() + self._paused_event.set() # 默认设置为非暂停状态 + + async def add_message(self, chunk: str): + """向队列中添加一个消息块。""" + await self.queue.put(chunk) + + async def close(self): + """表示没有更多消息了,关闭队列。""" + await self.queue.put(None) # Sentinel + + def pause(self): + """暂停发送。""" + self._paused_event.clear() + + def resume(self): + """恢复发送。""" + self._paused_event.set() + + def _calculate_typing_delay(self, text: str) -> float: + """根据文本长度计算模拟打字延迟。""" + chars_per_second = 15.0 + min_delay = 0.2 + max_delay = 2.0 + + delay = len(text) / chars_per_second + return max(min_delay, min(delay, max_delay)) + + async def _send_worker(self): + """从队列中取出消息并发送。""" + while True: + try: + # This structure ensures that task_done() is called for every item retrieved, + # even if the worker is cancelled while processing the item. + chunk = await self.queue.get() + except asyncio.CancelledError: + break + + try: + if chunk is None: + break + + # Check for pause signal *after* getting an item. + await self._paused_event.wait() + + delay = self._calculate_typing_delay(chunk) + await asyncio.sleep(delay) + + current_time = time.time() + msg_id = f"{current_time}_{random.randint(1000, 9999)}" + + text_to_send = chunk + if global_config.experimental.debug_show_chat_mode: + text_to_send += "ⁿ" + + message_segment = Seg(type="text", data=text_to_send) + bot_message = MessageSending( + message_id=msg_id, + chat_stream=self.chat_stream, + bot_user_info=UserInfo( + user_id=global_config.bot.qq_account, + user_nickname=global_config.bot.nickname, + platform=self.original_message.message_info.platform, + ), + sender_info=self.original_message.message_info.user_info, + message_segment=message_segment, + reply=self.original_message, + is_emoji=False, + apply_set_reply_logic=True, + ) + + await bot_message.process() + + await get_global_api().send_message(bot_message) + logger.info(f"已将消息 '{text_to_send}' 发往平台 '{bot_message.message_info.platform}'") + + await self.storage.store_message(bot_message, self.chat_stream) + + except Exception as e: + logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True) + + finally: + # CRUCIAL: Always call task_done() for any item that was successfully retrieved. + self.queue.task_done() + + def start(self): + """启动发送任务。""" + if self._task is None: + self._task = asyncio.create_task(self._send_worker()) + + async def join(self): + """等待所有消息发送完毕。""" + if self._task: + await self._task + + +class S4UChatManager: + def __init__(self): + self.s4u_chats: Dict[str, "S4UChat"] = {} + + def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat": + if chat_stream.stream_id not in self.s4u_chats: + stream_name = get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id + logger.info(f"Creating new S4UChat for stream: {stream_name}") + self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream) + return self.s4u_chats[chat_stream.stream_id] + +s4u_chat_manager = S4UChatManager() + +def get_s4u_chat_manager() -> S4UChatManager: + return s4u_chat_manager + + +class S4UChat: + def __init__(self, chat_stream: ChatStream): + """初始化 S4UChat 实例。""" + + self.chat_stream = chat_stream + self.stream_id = chat_stream.stream_id + self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id + + self._message_queue = asyncio.Queue() + self._processing_task = asyncio.create_task(self._message_processor()) + self._current_generation_task: Optional[asyncio.Task] = None + + self._is_replying = False + + # 初始化Normal Chat专用表达器 + self.expressor = NormalChatExpressor(self.chat_stream) + self.replyer = DefaultReplyer(self.chat_stream) + + self.gpt = S4UStreamGenerator() + self.audio_generator = MockAudioGenerator() + self.start_time = time.time() + + # 记录最近的回复内容,每项包含: {time, user_message, response, is_mentioned, is_reference_reply} + self.recent_replies = [] + self.max_replies_history = 20 # 最多保存最近20条回复记录 + + self.storage = MessageStorage() + + + logger.info(f"[{self.stream_name}] S4UChat") + + + # 改为实例方法, 移除 chat 参数 + async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None: + """将消息放入队列并中断当前处理(如果正在处理)。""" + if self._current_generation_task and not self._current_generation_task.done(): + self._current_generation_task.cancel() + logger.info(f"[{self.stream_name}] 请求中断当前回复生成任务。") + + await self._message_queue.put(message) + + async def _message_processor(self): + """从队列中处理消息,支持中断。""" + while True: + try: + # 等待第一条消息 + message = await self._message_queue.get() + + # 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条 + while not self._message_queue.empty(): + drained_msg = self._message_queue.get_nowait() + self._message_queue.task_done() # 为取出的旧消息调用 task_done + message = drained_msg # 始终处理最新消息 + logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}") + + self._current_generation_task = asyncio.create_task(self._generate_and_send(message)) + + try: + await self._current_generation_task + except asyncio.CancelledError: + logger.info(f"[{self.stream_name}] 回复生成被外部中断。") + except Exception as e: + logger.error(f"[{self.stream_name}] _generate_and_send 任务出现错误: {e}", exc_info=True) + finally: + self._current_generation_task = None + + except asyncio.CancelledError: + logger.info(f"[{self.stream_name}] 消息处理器正在关闭。") + break + except Exception as e: + logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True) + await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转 + finally: + # 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成 + if 'message' in locals(): + self._message_queue.task_done() + + + async def _generate_and_send(self, message: MessageRecv): + """为单个消息生成文本和音频回复。整个过程可以被中断。""" + self._is_replying = True + sender_container = MessageSenderContainer(self.chat_stream, message) + sender_container.start() + + try: + logger.info( + f"[S4U] 开始为消息生成文本和音频流: " + f"'{message.processed_plain_text[:30]}...'" + ) + + # 1. 逐句生成文本、发送并播放音频 + gen = self.gpt.generate_response(message, "") + async for chunk in gen: + # 如果任务被取消,await 会在此处引发 CancelledError + + # a. 发送文本块 + await sender_container.add_message(chunk) + + # b. 为该文本块生成并播放音频 + if chunk.strip(): + audio_data = await self.audio_generator.generate(chunk) + player = MockAudioPlayer(audio_data) + await player.play() + + # 等待所有文本消息发送完成 + await sender_container.close() + await sender_container.join() + logger.info(f"[{self.stream_name}] 所有文本和音频块处理完毕。") + + except asyncio.CancelledError: + logger.info(f"[{self.stream_name}] 回复流程(文本或音频)被中断。") + raise # 将取消异常向上传播 + except Exception as e: + logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True) + finally: + self._is_replying = False + # 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的) + sender_container.resume() + if not sender_container._task.done(): + await sender_container.close() + await sender_container.join() + logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。") + + + async def shutdown(self): + """平滑关闭处理任务。""" + logger.info(f"正在关闭 S4UChat: {self.stream_name}") + + # 取消正在运行的任务 + if self._current_generation_task and not self._current_generation_task.done(): + self._current_generation_task.cancel() + + if self._processing_task and not self._processing_task.done(): + self._processing_task.cancel() + + # 等待任务响应取消 + try: + await self._processing_task + except asyncio.CancelledError: + logger.info(f"处理任务已成功取消: {self.stream_name}") \ No newline at end of file diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py new file mode 100644 index 000000000..8525b6a93 --- /dev/null +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -0,0 +1,70 @@ +from src.chat.memory_system.Hippocampus import hippocampus_manager +from src.config.config import global_config +from src.chat.message_receive.message import MessageRecv +from src.chat.message_receive.storage import MessageStorage +from src.chat.heart_flow.heartflow import heartflow +from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream +from src.chat.utils.utils import is_mentioned_bot_in_message +from src.chat.utils.timer_calculator import Timer +from src.common.logger import get_logger +from .s4u_chat import get_s4u_chat_manager + +import math +import re +import traceback +from typing import Optional, Tuple +from maim_message import UserInfo + +from src.person_info.relationship_manager import get_relationship_manager + +# from ..message_receive.message_buffer import message_buffer + +logger = get_logger("chat") + + +class S4UMessageProcessor: + """心流处理器,负责处理接收到的消息并计算兴趣度""" + + def __init__(self): + """初始化心流处理器,创建消息存储实例""" + self.storage = MessageStorage() + + async def process_message(self, message: MessageRecv) -> None: + """处理接收到的原始消息数据 + + 主要流程: + 1. 消息解析与初始化 + 2. 消息缓冲处理 + 3. 过滤检查 + 4. 兴趣度计算 + 5. 关系处理 + + Args: + message_data: 原始消息字符串 + """ + + target_user_id = "1026294844" + + # 1. 消息解析与初始化 + groupinfo = message.message_info.group_info + userinfo = message.message_info.user_info + messageinfo = message.message_info + + chat = await get_chat_manager().get_or_create_stream( + platform=messageinfo.platform, + user_info=userinfo, + group_info=groupinfo, + ) + + await self.storage.store_message(message, chat) + + is_mentioned = is_mentioned_bot_in_message(message) + s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) + + if userinfo.user_id == target_user_id: + await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0) + + + # 7. 日志记录 + logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") + diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py new file mode 100644 index 000000000..b62d93552 --- /dev/null +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -0,0 +1,230 @@ + +from src.config.config import global_config +from src.common.logger import get_logger +from src.individuality.individuality import get_individuality +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat +from src.chat.message_receive.message import MessageRecv +import time +from src.chat.utils.utils import get_recent_group_speaker +from src.chat.memory_system.Hippocampus import hippocampus_manager +import random + +from src.person_info.relationship_manager import get_relationship_manager + +logger = get_logger("prompt") + + +def init_prompt(): + Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1") + Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1") + Prompt("在群里聊天", "chat_target_group2") + Prompt("和{sender_name}私聊", "chat_target_private2") + + Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") + + + Prompt( + """ +你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。 +你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,但是你主要还是关注你和{sender_name}的聊天内容。 + +{background_dialogue_prompt} +-------------------------------- +{now_time} +这是你和{sender_name}的对话,你们正在交流中: +{core_dialogue_prompt} + +{message_txt} +回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。 +不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。 +你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。""", + "s4u_prompt", # New template for private CHAT chat + ) + + +class PromptBuilder: + def __init__(self): + self.prompt_built = "" + self.activate_messages = "" + + async def build_prompt_normal( + self, + message, + chat_stream, + message_txt: str, + sender_name: str = "某人", + ) -> str: + prompt_personality = get_individuality().get_prompt(x_person=2, level=2) + is_group_chat = bool(chat_stream.group_info) + + who_chat_in_group = [] + if is_group_chat: + who_chat_in_group = get_recent_group_speaker( + chat_stream.stream_id, + (chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None, + limit=global_config.normal_chat.max_context_size, + ) + elif chat_stream.user_info: + who_chat_in_group.append( + (chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname) + ) + + relation_prompt = "" + if global_config.relationship.enable_relationship: + for person in who_chat_in_group: + relationship_manager = get_relationship_manager() + relation_prompt += await relationship_manager.build_relationship_info(person) + + + memory_prompt = "" + related_memory = await hippocampus_manager.get_memory_from_text( + text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False + ) + + related_memory_info = "" + if related_memory: + for memory in related_memory: + related_memory_info += memory[1] + memory_prompt = await global_prompt_manager.format_prompt( + "memory_prompt", related_memory_info=related_memory_info + ) + + message_list_before_now = get_raw_msg_before_timestamp_with_chat( + chat_id=chat_stream.stream_id, + timestamp=time.time(), + limit=100, + ) + + + # 分别筛选核心对话和背景对话 + core_dialogue_list = [] + background_dialogue_list = [] + bot_id = str(global_config.bot.qq_account) + target_user_id = str(message.chat_stream.user_info.user_id) + + for msg_dict in message_list_before_now: + try: + # 直接通过字典访问 + msg_user_id = str(msg_dict.get('user_id')) + + if msg_user_id == bot_id or msg_user_id == target_user_id: + core_dialogue_list.append(msg_dict) + else: + background_dialogue_list.append(msg_dict) + except Exception as e: + logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}") + + if background_dialogue_list: + latest_25_msgs = background_dialogue_list[-25:] + background_dialogue_prompt = build_readable_messages( + latest_25_msgs, + merge_messages=True, + timestamp_mode = "normal_no_YMD", + show_pic = False, + ) + background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}" + else: + background_dialogue_prompt = "" + + # 分别获取最新50条和最新25条(从message_list_before_now截取) + core_dialogue_list = core_dialogue_list[-50:] + + first_msg = core_dialogue_list[0] + start_speaking_user_id = first_msg.get('user_id') + if start_speaking_user_id == bot_id: + last_speaking_user_id = bot_id + msg_seg_str = "你的发言:\n" + else: + start_speaking_user_id = target_user_id + last_speaking_user_id = start_speaking_user_id + msg_seg_str = "对方的发言:\n" + + msg_seg_str += f"{first_msg.get('processed_plain_text')}\n" + + all_msg_seg_list = [] + for msg in core_dialogue_list[1:]: + speaker = msg.get('user_id') + if speaker == last_speaking_user_id: + #还是同一个人讲话 + msg_seg_str += f"{msg.get('processed_plain_text')}\n" + else: + #换人了 + msg_seg_str = f"{msg_seg_str}\n" + all_msg_seg_list.append(msg_seg_str) + + if speaker == bot_id: + msg_seg_str = "你的发言:\n" + else: + msg_seg_str = "对方的发言:\n" + + msg_seg_str += f"{msg.get('processed_plain_text')}\n" + last_speaking_user_id = speaker + + all_msg_seg_list.append(msg_seg_str) + + + core_msg_str = "" + for msg in all_msg_seg_list: + # print(f"msg: {msg}") + core_msg_str += msg + + now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + now_time = f"现在的时间是:{now_time}" + + template_name = "s4u_prompt" + effective_sender_name = sender_name + + prompt = await global_prompt_manager.format_prompt( + template_name, + relation_prompt=relation_prompt, + sender_name=effective_sender_name, + memory_prompt=memory_prompt, + core_dialogue_prompt=core_msg_str, + background_dialogue_prompt=background_dialogue_prompt, + message_txt=message_txt, + bot_name=global_config.bot.nickname, + bot_other_names="/".join(global_config.bot.alias_names), + prompt_personality=prompt_personality, + now_time=now_time, + ) + + return prompt + + +def weighted_sample_no_replacement(items, weights, k) -> list: + """ + 加权且不放回地随机抽取k个元素。 + + 参数: + items: 待抽取的元素列表 + weights: 每个元素对应的权重(与items等长,且为正数) + k: 需要抽取的元素个数 + 返回: + selected: 按权重加权且不重复抽取的k个元素组成的列表 + + 如果 items 中的元素不足 k 个,就只会返回所有可用的元素 + + 实现思路: + 每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。 + 这样保证了: + 1. count越大被选中概率越高 + 2. 不会重复选中同一个元素 + """ + selected = [] + pool = list(zip(items, weights)) + for _ in range(min(k, len(pool))): + total = sum(w for _, w in pool) + r = random.uniform(0, total) + upto = 0 + for idx, (item, weight) in enumerate(pool): + upto += weight + if upto >= r: + selected.append(item) + pool.pop(idx) + break + return selected + + +init_prompt() +prompt_builder = PromptBuilder() diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py new file mode 100644 index 000000000..54df5aece --- /dev/null +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -0,0 +1,140 @@ +import os +from typing import AsyncGenerator +from src.llm_models.utils_model import LLMRequest +from src.mais4u.openai_client import AsyncOpenAIClient +from src.config.config import global_config +from src.chat.message_receive.message import MessageRecv +from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder +from src.common.logger import get_logger +from src.person_info.person_info import PersonInfoManager, get_person_info_manager +import asyncio +import re + + +logger = get_logger("s4u_stream_generator") + + +class S4UStreamGenerator: + def __init__(self): + replyer_1_config = global_config.model.replyer_1 + provider = replyer_1_config.get("provider") + if not provider: + logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段") + raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段") + + api_key = os.environ.get(f"{provider.upper()}_KEY") + base_url = os.environ.get(f"{provider.upper()}_BASE_URL") + + if not api_key: + logger.error(f"环境变量 {provider.upper()}_KEY 未设置") + raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置") + + self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url) + self.model_1_name = replyer_1_config.get("name") + if not self.model_1_name: + logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段") + raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段") + self.replyer_1_config = replyer_1_config + + self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation") + self.current_model_name = "unknown model" + + # 正则表达式用于按句子切分,同时处理各种标点和边缘情况 + # 匹配常见的句子结束符,但会忽略引号内和数字中的标点 + self.sentence_split_pattern = re.compile( + r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容 + r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))' # 匹配直到句子结束符 + , re.UNICODE | re.DOTALL + ) + + async def generate_response( + self, message: MessageRecv, previous_reply_context: str = "" + ) -> AsyncGenerator[str, None]: + """根据当前模型类型选择对应的生成函数""" + # 从global_config中获取模型概率值并选择模型 + current_client = self.client_1 + self.current_model_name = self.model_1_name + + person_id = PersonInfoManager.get_person_id( + message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id + ) + person_info_manager = get_person_info_manager() + person_name = await person_info_manager.get_value(person_id, "person_name") + + if message.chat_stream.user_info.user_nickname: + sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})" + else: + sender_name = f"用户({message.chat_stream.user_info.user_id})" + + # 构建prompt + if previous_reply_context: + message_txt = f""" + 你正在回复用户的消息,但中途被打断了。这是已有的对话上下文: + [你已经对上一条消息说的话]: {previous_reply_context} + --- + [这是用户发来的新消息, 你需要结合上下文,对此进行回复]: + {message.processed_plain_text} + """ + else: + message_txt = message.processed_plain_text + + + prompt = await prompt_builder.build_prompt_normal( + message = message, + message_txt=message_txt, + sender_name=sender_name, + chat_stream=message.chat_stream, + ) + + logger.info( + f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}" + ) # noqa: E501 + + extra_kwargs = {} + if self.replyer_1_config.get("enable_thinking") is not None: + extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking") + if self.replyer_1_config.get("thinking_budget") is not None: + extra_kwargs["thinking_budget"] = self.replyer_1_config.get("thinking_budget") + + async for chunk in self._generate_response_with_model( + prompt, current_client, self.current_model_name, **extra_kwargs + ): + yield chunk + + async def _generate_response_with_model( + self, + prompt: str, + client: AsyncOpenAIClient, + model_name: str, + **kwargs, + ) -> AsyncGenerator[str, None]: + print(prompt) + + buffer = "" + delimiters = ",。!?,.!?\n\r" # For final trimming + + async for content in client.get_stream_content( + messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs + ): + buffer += content + + # 使用正则表达式匹配句子 + last_match_end = 0 + for match in self.sentence_split_pattern.finditer(buffer): + sentence = match.group(0).strip() + if sentence: + # 如果句子看起来完整(即不只是等待更多内容),则发送 + if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)): + yield sentence + await asyncio.sleep(0) # 允许其他任务运行 + last_match_end = match.end(0) + + # 从缓冲区移除已发送的部分 + if last_match_end > 0: + buffer = buffer[last_match_end:] + + # 发送缓冲区中剩余的任何内容 + if buffer.strip(): + yield buffer.strip() + await asyncio.sleep(0) + diff --git a/src/mais4u/openai_client.py b/src/mais4u/openai_client.py new file mode 100644 index 000000000..90d605a0c --- /dev/null +++ b/src/mais4u/openai_client.py @@ -0,0 +1,312 @@ +import asyncio +import json +from typing import AsyncGenerator, Dict, List, Optional, Union, Any +from dataclasses import dataclass +import aiohttp +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletion, ChatCompletionChunk + + +@dataclass +class ChatMessage: + """聊天消息数据类""" + role: str + content: str + + def to_dict(self) -> Dict[str, str]: + return {"role": self.role, "content": self.content} + + +class AsyncOpenAIClient: + """异步OpenAI客户端,支持流式传输""" + + def __init__(self, api_key: str, base_url: Optional[str] = None): + """ + 初始化客户端 + + Args: + api_key: OpenAI API密钥 + base_url: 可选的API基础URL,用于自定义端点 + """ + self.client = AsyncOpenAI( + api_key=api_key, + base_url=base_url, + timeout=10.0, # 设置60秒的全局超时 + ) + + async def chat_completion( + self, + messages: List[Union[ChatMessage, Dict[str, str]]], + model: str = "gpt-3.5-turbo", + temperature: float = 0.7, + max_tokens: Optional[int] = None, + **kwargs + ) -> ChatCompletion: + """ + 非流式聊天完成 + + Args: + messages: 消息列表 + model: 模型名称 + temperature: 温度参数 + max_tokens: 最大token数 + **kwargs: 其他参数 + + Returns: + 完整的聊天回复 + """ + # 转换消息格式 + formatted_messages = [] + for msg in messages: + if isinstance(msg, ChatMessage): + formatted_messages.append(msg.to_dict()) + else: + formatted_messages.append(msg) + + extra_body = {} + if kwargs.get("enable_thinking") is not None: + extra_body["enable_thinking"] = kwargs.pop("enable_thinking") + if kwargs.get("thinking_budget") is not None: + extra_body["thinking_budget"] = kwargs.pop("thinking_budget") + + response = await self.client.chat.completions.create( + model=model, + messages=formatted_messages, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + extra_body=extra_body if extra_body else None, + **kwargs + ) + + return response + + async def chat_completion_stream( + self, + messages: List[Union[ChatMessage, Dict[str, str]]], + model: str = "gpt-3.5-turbo", + temperature: float = 0.7, + max_tokens: Optional[int] = None, + **kwargs + ) -> AsyncGenerator[ChatCompletionChunk, None]: + """ + 流式聊天完成 + + Args: + messages: 消息列表 + model: 模型名称 + temperature: 温度参数 + max_tokens: 最大token数 + **kwargs: 其他参数 + + Yields: + ChatCompletionChunk: 流式响应块 + """ + # 转换消息格式 + formatted_messages = [] + for msg in messages: + if isinstance(msg, ChatMessage): + formatted_messages.append(msg.to_dict()) + else: + formatted_messages.append(msg) + + extra_body = {} + if kwargs.get("enable_thinking") is not None: + extra_body["enable_thinking"] = kwargs.pop("enable_thinking") + if kwargs.get("thinking_budget") is not None: + extra_body["thinking_budget"] = kwargs.pop("thinking_budget") + + stream = await self.client.chat.completions.create( + model=model, + messages=formatted_messages, + temperature=temperature, + max_tokens=max_tokens, + stream=True, + extra_body=extra_body if extra_body else None, + **kwargs + ) + + async for chunk in stream: + yield chunk + + async def get_stream_content( + self, + messages: List[Union[ChatMessage, Dict[str, str]]], + model: str = "gpt-3.5-turbo", + temperature: float = 0.7, + max_tokens: Optional[int] = None, + **kwargs + ) -> AsyncGenerator[str, None]: + """ + 获取流式内容(只返回文本内容) + + Args: + messages: 消息列表 + model: 模型名称 + temperature: 温度参数 + max_tokens: 最大token数 + **kwargs: 其他参数 + + Yields: + str: 文本内容片段 + """ + async for chunk in self.chat_completion_stream( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + **kwargs + ): + if chunk.choices and chunk.choices[0].delta.content: + yield chunk.choices[0].delta.content + + async def collect_stream_response( + self, + messages: List[Union[ChatMessage, Dict[str, str]]], + model: str = "gpt-3.5-turbo", + temperature: float = 0.7, + max_tokens: Optional[int] = None, + **kwargs + ) -> str: + """ + 收集完整的流式响应 + + Args: + messages: 消息列表 + model: 模型名称 + temperature: 温度参数 + max_tokens: 最大token数 + **kwargs: 其他参数 + + Returns: + str: 完整的响应文本 + """ + full_response = "" + async for content in self.get_stream_content( + messages=messages, + model=model, + temperature=temperature, + max_tokens=max_tokens, + **kwargs + ): + full_response += content + + return full_response + + async def close(self): + """关闭客户端""" + await self.client.close() + + async def __aenter__(self): + """异步上下文管理器入口""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """异步上下文管理器退出""" + await self.close() + + +class ConversationManager: + """对话管理器,用于管理对话历史""" + + def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None): + """ + 初始化对话管理器 + + Args: + client: OpenAI客户端实例 + system_prompt: 系统提示词 + """ + self.client = client + self.messages: List[ChatMessage] = [] + + if system_prompt: + self.messages.append(ChatMessage(role="system", content=system_prompt)) + + def add_user_message(self, content: str): + """添加用户消息""" + self.messages.append(ChatMessage(role="user", content=content)) + + def add_assistant_message(self, content: str): + """添加助手消息""" + self.messages.append(ChatMessage(role="assistant", content=content)) + + async def send_message_stream( + self, + content: str, + model: str = "gpt-3.5-turbo", + **kwargs + ) -> AsyncGenerator[str, None]: + """ + 发送消息并获取流式响应 + + Args: + content: 用户消息内容 + model: 模型名称 + **kwargs: 其他参数 + + Yields: + str: 响应内容片段 + """ + self.add_user_message(content) + + response_content = "" + async for chunk in self.client.get_stream_content( + messages=self.messages, + model=model, + **kwargs + ): + response_content += chunk + yield chunk + + self.add_assistant_message(response_content) + + async def send_message( + self, + content: str, + model: str = "gpt-3.5-turbo", + **kwargs + ) -> str: + """ + 发送消息并获取完整响应 + + Args: + content: 用户消息内容 + model: 模型名称 + **kwargs: 其他参数 + + Returns: + str: 完整响应 + """ + self.add_user_message(content) + + response = await self.client.chat_completion( + messages=self.messages, + model=model, + **kwargs + ) + + response_content = response.choices[0].message.content + self.add_assistant_message(response_content) + + return response_content + + def clear_history(self, keep_system: bool = True): + """ + 清除对话历史 + + Args: + keep_system: 是否保留系统消息 + """ + if keep_system and self.messages and self.messages[0].role == "system": + self.messages = [self.messages[0]] + else: + self.messages = [] + + def get_message_count(self) -> int: + """获取消息数量""" + return len(self.messages) + + def get_conversation_history(self) -> List[Dict[str, str]]: + """获取对话历史""" + return [msg.to_dict() for msg in self.messages] \ No newline at end of file