diff --git a/bot.py b/bot.py index 0cc8faeca..a3e49fceb 100644 --- a/bot.py +++ b/bot.py @@ -326,7 +326,6 @@ if __name__ == "__main__": # Wait for all tasks to complete (which they won't, normally) loop.run_until_complete(main_tasks) - except KeyboardInterrupt: # loop.run_until_complete(get_global_api().stop()) logger.warning("收到中断信号,正在优雅关闭...") diff --git a/src/audio/mock_audio.py b/src/audio/mock_audio.py index 73d7176af..9772fdad9 100644 --- a/src/audio/mock_audio.py +++ b/src/audio/mock_audio.py @@ -3,10 +3,12 @@ from src.common.logger import get_logger logger = get_logger("MockAudio") + class MockAudioPlayer: """ 一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。 """ + def __init__(self, audio_data: bytes): self._audio_data = audio_data # 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频 @@ -22,12 +24,14 @@ class MockAudioPlayer: logger.info("模拟音频播放完毕。") except asyncio.CancelledError: logger.info("音频播放被中断。") - raise # 重新抛出异常,以便上层逻辑可以捕获它 + raise # 重新抛出异常,以便上层逻辑可以捕获它 + class MockAudioGenerator: """ 一个模拟的文本到语音(TTS)生成器。 """ + def __init__(self): # 模拟生成速度:每秒生成的字符数 self.chars_per_second = 25.0 @@ -43,16 +47,16 @@ class MockAudioGenerator: 模拟的音频数据(bytes)。 """ if not text: - return b'' + return b"" generation_time = len(text) / self.chars_per_second logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...") try: await asyncio.sleep(generation_time) # 生成虚拟的音频数据,其长度与文本长度成正比 - mock_audio_data = b'\x01\x02\x03' * (len(text) * 40) + mock_audio_data = b"\x01\x02\x03" * (len(text) * 40) logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。") return mock_audio_data except asyncio.CancelledError: logger.info("音频生成被中断。") - raise # 重新抛出异常 \ No newline at end of file + raise # 重新抛出异常 diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index 099f3c062..601b00390 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -180,14 +180,12 @@ class ChatBot: # 如果在私聊中 if group_info is None: logger.debug("检测到私聊消息") - + if ENABLE_S4U_CHAT: logger.debug("进入S4U私聊处理流程") await self.s4u_message_processor.process_message(message) return - - - + if global_config.experimental.pfc_chatting: logger.debug("进入PFC私聊处理流程") # 创建聊天流 @@ -200,13 +198,11 @@ class ChatBot: await self.heartflow_message_receiver.process_message(message) # 群聊默认进入心流消息处理逻辑 else: - if ENABLE_S4U_CHAT: logger.debug("进入S4U私聊处理流程") await self.s4u_message_processor.process_message(message) return - - + logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}") await self.heartflow_message_receiver.process_message(message) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 84291dbf6..ef68d7852 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -323,7 +323,7 @@ class MessageSending(MessageProcessBase): self.is_head = is_head self.is_emoji = is_emoji self.apply_set_reply_logic = apply_set_reply_logic - + self.reply_to = reply_to # 用于显示发送内容与显示不一致的情况 diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 58835a921..862354db7 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -35,11 +35,11 @@ class MessageStorage: filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) else: filtered_display_message = "" - + reply_to = message.reply_to else: filtered_display_message = "" - + reply_to = "" chat_info_dict = chat_stream.to_dict() diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 580939f47..2359abf30 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -263,7 +263,6 @@ def _build_readable_messages_internal( # 处理图片ID if show_pic: content = process_pic_ids(content) - # 检查必要信息是否存在 if not all([platform, user_id, timestamp is not None]): @@ -632,10 +631,17 @@ def build_readable_messages( truncate, pic_id_mapping, pic_counter, - show_pic=show_pic + show_pic=show_pic, ) formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( - messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter, show_pic=show_pic + messages_after_mark, + replace_bot_name, + merge_messages, + timestamp_mode, + False, + pic_id_mapping, + pic_counter, + show_pic=show_pic, ) read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n" diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 82bf28122..500852d00 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -126,7 +126,7 @@ class Messages(BaseModel): time = DoubleField() # 消息时间戳 chat_id = TextField(index=True) # 对应的 ChatStreams stream_id - + reply_to = TextField(null=True) # 从 chat_info 扁平化而来的字段 diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index c63f2bc9c..94ae9458e 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -1,39 +1,15 @@ import asyncio import time -import traceback import random -from typing import List, Optional, Dict # 导入类型提示 -import os -import pickle +from typing import Optional, Dict # 导入类型提示 from maim_message import UserInfo, Seg from src.common.logger import get_logger -from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info -from src.manager.mood_manager import mood_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager -from src.chat.utils.timer_calculator import Timer -from src.chat.utils.prompt_builder import global_prompt_manager from .s4u_stream_generator import S4UStreamGenerator -from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet -from src.chat.message_receive.message_sender import message_manager -from src.chat.normal_chat.willing.willing_manager import get_willing_manager -from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats +from src.chat.message_receive.message import MessageSending, MessageRecv from src.config.config import global_config -from src.chat.focus_chat.planners.action_manager import ActionManager -from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner -from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionModifier -from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor -from src.chat.focus_chat.replyer.default_generator import DefaultReplyer -from src.person_info.person_info import PersonInfoManager -from src.person_info.relationship_manager import get_relationship_manager -from src.chat.utils.chat_message_builder import ( - get_raw_msg_by_timestamp_with_chat, - get_raw_msg_by_timestamp_with_chat_inclusive, - get_raw_msg_before_timestamp_with_chat, - num_new_messages_since, -) from src.common.message.api import get_global_api from src.chat.message_receive.storage import MessageStorage -from src.audio.mock_audio import MockAudioGenerator, MockAudioPlayer logger = get_logger("S4U_chat") @@ -41,6 +17,7 @@ logger = get_logger("S4U_chat") class MessageSenderContainer: """一个简单的容器,用于按顺序发送消息并模拟打字效果。""" + def __init__(self, chat_stream: ChatStream, original_message: MessageRecv): self.chat_stream = chat_stream self.original_message = original_message @@ -71,7 +48,7 @@ class MessageSenderContainer: chars_per_second = 15.0 min_delay = 0.2 max_delay = 2.0 - + delay = len(text) / chars_per_second return max(min_delay, min(delay, max_delay)) @@ -98,7 +75,7 @@ class MessageSenderContainer: current_time = time.time() msg_id = f"{current_time}_{random.randint(1000, 9999)}" - + text_to_send = chunk if global_config.experimental.debug_show_chat_mode: text_to_send += "ⁿ" @@ -117,19 +94,19 @@ class MessageSenderContainer: reply=self.original_message, is_emoji=False, apply_set_reply_logic=True, - reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}" + reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}", ) - + await bot_message.process() - + await get_global_api().send_message(bot_message) logger.info(f"已将消息 '{text_to_send}' 发往平台 '{bot_message.message_info.platform}'") - + await self.storage.store_message(bot_message, self.chat_stream) - + except Exception as e: logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True) - + finally: # CRUCIAL: Always call task_done() for any item that was successfully retrieved. self.queue.task_done() @@ -138,7 +115,7 @@ class MessageSenderContainer: """启动发送任务。""" if self._task is None: self._task = asyncio.create_task(self._send_worker()) - + async def join(self): """等待所有消息发送完毕。""" if self._task: @@ -156,8 +133,10 @@ class S4UChatManager: self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream) return self.s4u_chats[chat_stream.stream_id] + s4u_chat_manager = S4UChatManager() + def get_s4u_chat_manager() -> S4UChatManager: return s4u_chat_manager @@ -169,22 +148,19 @@ class S4UChat: self.chat_stream = chat_stream self.stream_id = chat_stream.stream_id self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id - + self._message_queue = asyncio.Queue() self._processing_task = asyncio.create_task(self._message_processor()) self._current_generation_task: Optional[asyncio.Task] = None self._current_message_being_replied: Optional[MessageRecv] = None - + self._is_replying = False self.gpt = S4UStreamGenerator() # self.audio_generator = MockAudioGenerator() - - logger.info(f"[{self.stream_name}] S4UChat") - # 改为实例方法, 移除 chat 参数 async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None: """将消息放入队列并根据发信人决定是否中断当前处理。""" @@ -226,8 +202,8 @@ class S4UChat: # 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条 while not self._message_queue.empty(): drained_msg = self._message_queue.get_nowait() - self._message_queue.task_done() # 为取出的旧消息调用 task_done - message = drained_msg # 始终处理最新消息 + self._message_queue.task_done() # 为取出的旧消息调用 task_done + message = drained_msg # 始终处理最新消息 self._current_message_being_replied = message logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}") @@ -242,44 +218,40 @@ class S4UChat: finally: self._current_generation_task = None self._current_message_being_replied = None - + except asyncio.CancelledError: logger.info(f"[{self.stream_name}] 消息处理器正在关闭。") break except Exception as e: logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True) - await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转 + await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转 finally: # 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成 - if 'message' in locals(): + if "message" in locals(): self._message_queue.task_done() - async def _generate_and_send(self, message: MessageRecv): """为单个消息生成文本和音频回复。整个过程可以被中断。""" self._is_replying = True sender_container = MessageSenderContainer(self.chat_stream, message) sender_container.start() - + try: - logger.info( - f"[S4U] 开始为消息生成文本和音频流: " - f"'{message.processed_plain_text[:30]}...'" - ) - + logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'") + # 1. 逐句生成文本、发送并播放音频 gen = self.gpt.generate_response(message, "") async for chunk in gen: # 如果任务被取消,await 会在此处引发 CancelledError - + # a. 发送文本块 await sender_container.add_message(chunk) - + # b. 为该文本块生成并播放音频 # if chunk.strip(): - # audio_data = await self.audio_generator.generate(chunk) - # player = MockAudioPlayer(audio_data) - # await player.play() + # audio_data = await self.audio_generator.generate(chunk) + # player = MockAudioPlayer(audio_data) + # await player.play() # 等待所有文本消息发送完成 await sender_container.close() @@ -300,20 +272,19 @@ class S4UChat: await sender_container.join() logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。") - async def shutdown(self): """平滑关闭处理任务。""" logger.info(f"正在关闭 S4UChat: {self.stream_name}") - + # 取消正在运行的任务 if self._current_generation_task and not self._current_generation_task.done(): self._current_generation_task.cancel() - + if self._processing_task and not self._processing_task.done(): self._processing_task.cancel() - + # 等待任务响应取消 try: await self._processing_task except asyncio.CancelledError: - logger.info(f"处理任务已成功取消: {self.stream_name}") \ No newline at end of file + logger.info(f"处理任务已成功取消: {self.stream_name}") diff --git a/src/mais4u/mais4u_chat/s4u_msg_processor.py b/src/mais4u/mais4u_chat/s4u_msg_processor.py index 4a3737a70..c3a37e7b7 100644 --- a/src/mais4u/mais4u_chat/s4u_msg_processor.py +++ b/src/mais4u/mais4u_chat/s4u_msg_processor.py @@ -1,21 +1,10 @@ -from src.chat.memory_system.Hippocampus import hippocampus_manager -from src.config.config import global_config from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.storage import MessageStorage -from src.chat.heart_flow.heartflow import heartflow -from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.utils import is_mentioned_bot_in_message -from src.chat.utils.timer_calculator import Timer from src.common.logger import get_logger from .s4u_chat import get_s4u_chat_manager -import math -import re -import traceback -from typing import Optional, Tuple -from maim_message import UserInfo - -from src.person_info.relationship_manager import get_relationship_manager # from ..message_receive.message_buffer import message_buffer @@ -44,7 +33,7 @@ class S4UMessageProcessor: """ target_user_id_list = ["1026294844", "964959351"] - + # 1. 消息解析与初始化 groupinfo = message.message_info.group_info userinfo = message.message_info.user_info @@ -60,7 +49,7 @@ class S4UMessageProcessor: is_mentioned = is_mentioned_bot_in_message(message) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) - + if userinfo.user_id in target_user_id_list: await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0) else: @@ -68,4 +57,3 @@ class S4UMessageProcessor: # 7. 日志记录 logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") - diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index 831058567..b9914f582 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -1,10 +1,8 @@ - from src.config.config import global_config from src.common.logger import get_logger from src.individuality.individuality import get_individuality from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat -from src.chat.message_receive.message import MessageRecv import time from src.chat.utils.utils import get_recent_group_speaker from src.chat.memory_system.Hippocampus import hippocampus_manager @@ -23,7 +21,6 @@ def init_prompt(): Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") - Prompt( """ 你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。 @@ -79,7 +76,6 @@ class PromptBuilder: relationship_manager = get_relationship_manager() relation_prompt += await relationship_manager.build_relationship_info(person) - memory_prompt = "" related_memory = await hippocampus_manager.get_memory_from_text( text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False @@ -98,23 +94,20 @@ class PromptBuilder: timestamp=time.time(), limit=100, ) - - + talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id print(f"talk_type: {talk_type}") - # 分别筛选核心对话和背景对话 core_dialogue_list = [] background_dialogue_list = [] bot_id = str(global_config.bot.qq_account) target_user_id = str(message.chat_stream.user_info.user_id) - for msg_dict in message_list_before_now: try: # 直接通过字典访问 - msg_user_id = str(msg_dict.get('user_id')) + msg_user_id = str(msg_dict.get("user_id")) if msg_user_id == bot_id: if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"): print(f"reply: {msg_dict.get('reply_to')}") @@ -127,24 +120,24 @@ class PromptBuilder: background_dialogue_list.append(msg_dict) except Exception as e: logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}") - + if background_dialogue_list: latest_25_msgs = background_dialogue_list[-25:] background_dialogue_prompt = build_readable_messages( latest_25_msgs, merge_messages=True, - timestamp_mode = "normal_no_YMD", - show_pic = False, + timestamp_mode="normal_no_YMD", + show_pic=False, ) background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}" else: background_dialogue_prompt = "" - + # 分别获取最新50条和最新25条(从message_list_before_now截取) core_dialogue_list = core_dialogue_list[-50:] - + first_msg = core_dialogue_list[0] - start_speaking_user_id = first_msg.get('user_id') + start_speaking_user_id = first_msg.get("user_id") if start_speaking_user_id == bot_id: last_speaking_user_id = bot_id msg_seg_str = "你的发言:\n" @@ -152,30 +145,33 @@ class PromptBuilder: start_speaking_user_id = target_user_id last_speaking_user_id = start_speaking_user_id msg_seg_str = "对方的发言:\n" - + msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n" all_msg_seg_list = [] for msg in core_dialogue_list[1:]: - speaker = msg.get('user_id') + speaker = msg.get("user_id") if speaker == last_speaking_user_id: - #还是同一个人讲话 - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" + # 还是同一个人讲话 + msg_seg_str += ( + f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" + ) else: - #换人了 + # 换人了 msg_seg_str = f"{msg_seg_str}\n" all_msg_seg_list.append(msg_seg_str) - + if speaker == bot_id: msg_seg_str = "你的发言:\n" else: msg_seg_str = "对方的发言:\n" - - msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" - last_speaking_user_id = speaker - - all_msg_seg_list.append(msg_seg_str) + msg_seg_str += ( + f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" + ) + last_speaking_user_id = speaker + + all_msg_seg_list.append(msg_seg_str) core_msg_str = "" for msg in all_msg_seg_list: diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index 0b27df958..ec8b48959 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -43,8 +43,8 @@ class S4UStreamGenerator: # 匹配常见的句子结束符,但会忽略引号内和数字中的标点 self.sentence_split_pattern = re.compile( r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容 - r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))' # 匹配直到句子结束符 - , re.UNICODE | re.DOTALL + r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符 + re.UNICODE | re.DOTALL, ) async def generate_response( @@ -68,7 +68,7 @@ class S4UStreamGenerator: # 构建prompt if previous_reply_context: - message_txt = f""" + message_txt = f""" 你正在回复用户的消息,但中途被打断了。这是已有的对话上下文: [你已经对上一条消息说的话]: {previous_reply_context} --- @@ -78,9 +78,8 @@ class S4UStreamGenerator: else: message_txt = message.processed_plain_text - prompt = await prompt_builder.build_prompt_normal( - message = message, + message=message, message_txt=message_txt, sender_name=sender_name, chat_stream=message.chat_stream, @@ -109,16 +108,16 @@ class S4UStreamGenerator: **kwargs, ) -> AsyncGenerator[str, None]: print(prompt) - + buffer = "" delimiters = ",。!?,.!?\n\r" # For final trimming punctuation_buffer = "" - + async for content in client.get_stream_content( messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs ): buffer += content - + # 使用正则表达式匹配句子 last_match_end = 0 for match in self.sentence_split_pattern.finditer(buffer): @@ -132,24 +131,23 @@ class S4UStreamGenerator: else: # 发送之前累积的标点和当前句子 to_yield = punctuation_buffer + sentence - if to_yield.endswith((',', ',')): - to_yield = to_yield.rstrip(',,') - + if to_yield.endswith((",", ",")): + to_yield = to_yield.rstrip(",,") + yield to_yield - punctuation_buffer = "" # 清空标点符号缓冲区 - await asyncio.sleep(0) # 允许其他任务运行 - + punctuation_buffer = "" # 清空标点符号缓冲区 + await asyncio.sleep(0) # 允许其他任务运行 + last_match_end = match.end(0) - + # 从缓冲区移除已发送的部分 if last_match_end > 0: buffer = buffer[last_match_end:] - + # 发送缓冲区中剩余的任何内容 to_yield = (punctuation_buffer + buffer).strip() if to_yield: - if to_yield.endswith((',', ',')): - to_yield = to_yield.rstrip(',,') + if to_yield.endswith((",", ",")): + to_yield = to_yield.rstrip(",,") if to_yield: yield to_yield - diff --git a/src/mais4u/openai_client.py b/src/mais4u/openai_client.py index 90d605a0c..2a5873dec 100644 --- a/src/mais4u/openai_client.py +++ b/src/mais4u/openai_client.py @@ -1,8 +1,5 @@ -import asyncio -import json -from typing import AsyncGenerator, Dict, List, Optional, Union, Any +from typing import AsyncGenerator, Dict, List, Optional, Union from dataclasses import dataclass -import aiohttp from openai import AsyncOpenAI from openai.types.chat import ChatCompletion, ChatCompletionChunk @@ -10,20 +7,21 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk @dataclass class ChatMessage: """聊天消息数据类""" + role: str content: str - + def to_dict(self) -> Dict[str, str]: return {"role": self.role, "content": self.content} class AsyncOpenAIClient: """异步OpenAI客户端,支持流式传输""" - + def __init__(self, api_key: str, base_url: Optional[str] = None): """ 初始化客户端 - + Args: api_key: OpenAI API密钥 base_url: 可选的API基础URL,用于自定义端点 @@ -33,25 +31,25 @@ class AsyncOpenAIClient: base_url=base_url, timeout=10.0, # 设置60秒的全局超时 ) - + async def chat_completion( self, messages: List[Union[ChatMessage, Dict[str, str]]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, max_tokens: Optional[int] = None, - **kwargs + **kwargs, ) -> ChatCompletion: """ 非流式聊天完成 - + Args: messages: 消息列表 model: 模型名称 temperature: 温度参数 max_tokens: 最大token数 **kwargs: 其他参数 - + Returns: 完整的聊天回复 """ @@ -62,7 +60,7 @@ class AsyncOpenAIClient: formatted_messages.append(msg.to_dict()) else: formatted_messages.append(msg) - + extra_body = {} if kwargs.get("enable_thinking") is not None: extra_body["enable_thinking"] = kwargs.pop("enable_thinking") @@ -76,29 +74,29 @@ class AsyncOpenAIClient: max_tokens=max_tokens, stream=False, extra_body=extra_body if extra_body else None, - **kwargs + **kwargs, ) - + return response - + async def chat_completion_stream( self, messages: List[Union[ChatMessage, Dict[str, str]]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, max_tokens: Optional[int] = None, - **kwargs + **kwargs, ) -> AsyncGenerator[ChatCompletionChunk, None]: """ 流式聊天完成 - + Args: messages: 消息列表 model: 模型名称 temperature: 温度参数 max_tokens: 最大token数 **kwargs: 其他参数 - + Yields: ChatCompletionChunk: 流式响应块 """ @@ -109,7 +107,7 @@ class AsyncOpenAIClient: formatted_messages.append(msg.to_dict()) else: formatted_messages.append(msg) - + extra_body = {} if kwargs.get("enable_thinking") is not None: extra_body["enable_thinking"] = kwargs.pop("enable_thinking") @@ -123,84 +121,76 @@ class AsyncOpenAIClient: max_tokens=max_tokens, stream=True, extra_body=extra_body if extra_body else None, - **kwargs + **kwargs, ) - + async for chunk in stream: yield chunk - + async def get_stream_content( self, messages: List[Union[ChatMessage, Dict[str, str]]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, max_tokens: Optional[int] = None, - **kwargs + **kwargs, ) -> AsyncGenerator[str, None]: """ 获取流式内容(只返回文本内容) - + Args: messages: 消息列表 model: 模型名称 temperature: 温度参数 max_tokens: 最大token数 **kwargs: 其他参数 - + Yields: str: 文本内容片段 """ async for chunk in self.chat_completion_stream( - messages=messages, - model=model, - temperature=temperature, - max_tokens=max_tokens, - **kwargs + messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs ): if chunk.choices and chunk.choices[0].delta.content: yield chunk.choices[0].delta.content - + async def collect_stream_response( self, messages: List[Union[ChatMessage, Dict[str, str]]], model: str = "gpt-3.5-turbo", temperature: float = 0.7, max_tokens: Optional[int] = None, - **kwargs + **kwargs, ) -> str: """ 收集完整的流式响应 - + Args: messages: 消息列表 model: 模型名称 temperature: 温度参数 max_tokens: 最大token数 **kwargs: 其他参数 - + Returns: str: 完整的响应文本 """ full_response = "" async for content in self.get_stream_content( - messages=messages, - model=model, - temperature=temperature, - max_tokens=max_tokens, - **kwargs + messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs ): full_response += content - + return full_response - + async def close(self): """关闭客户端""" await self.client.close() - + async def __aenter__(self): """异步上下文管理器入口""" return self - + async def __aexit__(self, exc_type, exc_val, exc_tb): """异步上下文管理器退出""" await self.close() @@ -208,93 +198,77 @@ class AsyncOpenAIClient: class ConversationManager: """对话管理器,用于管理对话历史""" - + def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None): """ 初始化对话管理器 - + Args: client: OpenAI客户端实例 system_prompt: 系统提示词 """ self.client = client self.messages: List[ChatMessage] = [] - + if system_prompt: self.messages.append(ChatMessage(role="system", content=system_prompt)) - + def add_user_message(self, content: str): """添加用户消息""" self.messages.append(ChatMessage(role="user", content=content)) - + def add_assistant_message(self, content: str): """添加助手消息""" self.messages.append(ChatMessage(role="assistant", content=content)) - + async def send_message_stream( - self, - content: str, - model: str = "gpt-3.5-turbo", - **kwargs + self, content: str, model: str = "gpt-3.5-turbo", **kwargs ) -> AsyncGenerator[str, None]: """ 发送消息并获取流式响应 - + Args: content: 用户消息内容 model: 模型名称 **kwargs: 其他参数 - + Yields: str: 响应内容片段 """ self.add_user_message(content) - + response_content = "" - async for chunk in self.client.get_stream_content( - messages=self.messages, - model=model, - **kwargs - ): + async for chunk in self.client.get_stream_content(messages=self.messages, model=model, **kwargs): response_content += chunk yield chunk - + self.add_assistant_message(response_content) - - async def send_message( - self, - content: str, - model: str = "gpt-3.5-turbo", - **kwargs - ) -> str: + + async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str: """ 发送消息并获取完整响应 - + Args: content: 用户消息内容 model: 模型名称 **kwargs: 其他参数 - + Returns: str: 完整响应 """ self.add_user_message(content) - - response = await self.client.chat_completion( - messages=self.messages, - model=model, - **kwargs - ) - + + response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs) + response_content = response.choices[0].message.content self.add_assistant_message(response_content) - + return response_content - + def clear_history(self, keep_system: bool = True): """ 清除对话历史 - + Args: keep_system: 是否保留系统消息 """ @@ -302,11 +276,11 @@ class ConversationManager: self.messages = [self.messages[0]] else: self.messages = [] - + def get_message_count(self) -> int: """获取消息数量""" return len(self.messages) - + def get_conversation_history(self) -> List[Dict[str, str]]: """获取对话历史""" - return [msg.to_dict() for msg in self.messages] \ No newline at end of file + return [msg.to_dict() for msg in self.messages]