🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-07-01 11:33:16 +00:00
parent 3ef3923a8b
commit 324b294b5f
12 changed files with 157 additions and 225 deletions

1
bot.py
View File

@@ -326,7 +326,6 @@ if __name__ == "__main__":
# Wait for all tasks to complete (which they won't, normally) # Wait for all tasks to complete (which they won't, normally)
loop.run_until_complete(main_tasks) loop.run_until_complete(main_tasks)
except KeyboardInterrupt: except KeyboardInterrupt:
# loop.run_until_complete(get_global_api().stop()) # loop.run_until_complete(get_global_api().stop())
logger.warning("收到中断信号,正在优雅关闭...") logger.warning("收到中断信号,正在优雅关闭...")

View File

@@ -3,10 +3,12 @@ from src.common.logger import get_logger
logger = get_logger("MockAudio") logger = get_logger("MockAudio")
class MockAudioPlayer: class MockAudioPlayer:
""" """
一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。 一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。
""" """
def __init__(self, audio_data: bytes): def __init__(self, audio_data: bytes):
self._audio_data = audio_data self._audio_data = audio_data
# 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频 # 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频
@@ -24,10 +26,12 @@ class MockAudioPlayer:
logger.info("音频播放被中断。") logger.info("音频播放被中断。")
raise # 重新抛出异常,以便上层逻辑可以捕获它 raise # 重新抛出异常,以便上层逻辑可以捕获它
class MockAudioGenerator: class MockAudioGenerator:
""" """
一个模拟的文本到语音TTS生成器。 一个模拟的文本到语音TTS生成器。
""" """
def __init__(self): def __init__(self):
# 模拟生成速度:每秒生成的字符数 # 模拟生成速度:每秒生成的字符数
self.chars_per_second = 25.0 self.chars_per_second = 25.0
@@ -43,14 +47,14 @@ class MockAudioGenerator:
模拟的音频数据bytes 模拟的音频数据bytes
""" """
if not text: if not text:
return b'' return b""
generation_time = len(text) / self.chars_per_second generation_time = len(text) / self.chars_per_second
logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...") logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...")
try: try:
await asyncio.sleep(generation_time) await asyncio.sleep(generation_time)
# 生成虚拟的音频数据,其长度与文本长度成正比 # 生成虚拟的音频数据,其长度与文本长度成正比
mock_audio_data = b'\x01\x02\x03' * (len(text) * 40) mock_audio_data = b"\x01\x02\x03" * (len(text) * 40)
logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。") logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。")
return mock_audio_data return mock_audio_data
except asyncio.CancelledError: except asyncio.CancelledError:

View File

@@ -186,8 +186,6 @@ class ChatBot:
await self.s4u_message_processor.process_message(message) await self.s4u_message_processor.process_message(message)
return return
if global_config.experimental.pfc_chatting: if global_config.experimental.pfc_chatting:
logger.debug("进入PFC私聊处理流程") logger.debug("进入PFC私聊处理流程")
# 创建聊天流 # 创建聊天流
@@ -200,13 +198,11 @@ class ChatBot:
await self.heartflow_message_receiver.process_message(message) await self.heartflow_message_receiver.process_message(message)
# 群聊默认进入心流消息处理逻辑 # 群聊默认进入心流消息处理逻辑
else: else:
if ENABLE_S4U_CHAT: if ENABLE_S4U_CHAT:
logger.debug("进入S4U私聊处理流程") logger.debug("进入S4U私聊处理流程")
await self.s4u_message_processor.process_message(message) await self.s4u_message_processor.process_message(message)
return return
logger.debug(f"检测到群聊消息群ID: {group_info.group_id}") logger.debug(f"检测到群聊消息群ID: {group_info.group_id}")
await self.heartflow_message_receiver.process_message(message) await self.heartflow_message_receiver.process_message(message)

View File

@@ -264,7 +264,6 @@ def _build_readable_messages_internal(
if show_pic: if show_pic:
content = process_pic_ids(content) content = process_pic_ids(content)
# 检查必要信息是否存在 # 检查必要信息是否存在
if not all([platform, user_id, timestamp is not None]): if not all([platform, user_id, timestamp is not None]):
continue continue
@@ -632,10 +631,17 @@ def build_readable_messages(
truncate, truncate,
pic_id_mapping, pic_id_mapping,
pic_counter, pic_counter,
show_pic=show_pic show_pic=show_pic,
) )
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter, show_pic=show_pic messages_after_mark,
replace_bot_name,
merge_messages,
timestamp_mode,
False,
pic_id_mapping,
pic_counter,
show_pic=show_pic,
) )
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n" read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"

View File

@@ -1,39 +1,15 @@
import asyncio import asyncio
import time import time
import traceback
import random import random
from typing import List, Optional, Dict # 导入类型提示 from typing import Optional, Dict # 导入类型提示
import os
import pickle
from maim_message import UserInfo, Seg from maim_message import UserInfo, Seg
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.manager.mood_manager import mood_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.prompt_builder import global_prompt_manager
from .s4u_stream_generator import S4UStreamGenerator from .s4u_stream_generator import S4UStreamGenerator
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from src.chat.message_receive.message import MessageSending, MessageRecv
from src.chat.message_receive.message_sender import message_manager
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
from src.config.config import global_config from src.config.config import global_config
from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner
from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionModifier
from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor
from src.chat.focus_chat.replyer.default_generator import DefaultReplyer
from src.person_info.person_info import PersonInfoManager
from src.person_info.relationship_manager import get_relationship_manager
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_before_timestamp_with_chat,
num_new_messages_since,
)
from src.common.message.api import get_global_api from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.audio.mock_audio import MockAudioGenerator, MockAudioPlayer
logger = get_logger("S4U_chat") logger = get_logger("S4U_chat")
@@ -41,6 +17,7 @@ logger = get_logger("S4U_chat")
class MessageSenderContainer: class MessageSenderContainer:
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。""" """一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv): def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.original_message = original_message self.original_message = original_message
@@ -117,7 +94,7 @@ class MessageSenderContainer:
reply=self.original_message, reply=self.original_message,
is_emoji=False, is_emoji=False,
apply_set_reply_logic=True, apply_set_reply_logic=True,
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}" reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
) )
await bot_message.process() await bot_message.process()
@@ -156,8 +133,10 @@ class S4UChatManager:
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream) self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
return self.s4u_chats[chat_stream.stream_id] return self.s4u_chats[chat_stream.stream_id]
s4u_chat_manager = S4UChatManager() s4u_chat_manager = S4UChatManager()
def get_s4u_chat_manager() -> S4UChatManager: def get_s4u_chat_manager() -> S4UChatManager:
return s4u_chat_manager return s4u_chat_manager
@@ -180,11 +159,8 @@ class S4UChat:
self.gpt = S4UStreamGenerator() self.gpt = S4UStreamGenerator()
# self.audio_generator = MockAudioGenerator() # self.audio_generator = MockAudioGenerator()
logger.info(f"[{self.stream_name}] S4UChat") logger.info(f"[{self.stream_name}] S4UChat")
# 改为实例方法, 移除 chat 参数 # 改为实例方法, 移除 chat 参数
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None: async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
"""将消息放入队列并根据发信人决定是否中断当前处理。""" """将消息放入队列并根据发信人决定是否中断当前处理。"""
@@ -251,10 +227,9 @@ class S4UChat:
await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转 await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转
finally: finally:
# 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成 # 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成
if 'message' in locals(): if "message" in locals():
self._message_queue.task_done() self._message_queue.task_done()
async def _generate_and_send(self, message: MessageRecv): async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本和音频回复。整个过程可以被中断。""" """为单个消息生成文本和音频回复。整个过程可以被中断。"""
self._is_replying = True self._is_replying = True
@@ -262,10 +237,7 @@ class S4UChat:
sender_container.start() sender_container.start()
try: try:
logger.info( logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
f"[S4U] 开始为消息生成文本和音频流: "
f"'{message.processed_plain_text[:30]}...'"
)
# 1. 逐句生成文本、发送并播放音频 # 1. 逐句生成文本、发送并播放音频
gen = self.gpt.generate_response(message, "") gen = self.gpt.generate_response(message, "")
@@ -300,7 +272,6 @@ class S4UChat:
await sender_container.join() await sender_container.join()
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。") logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
async def shutdown(self): async def shutdown(self):
"""平滑关闭处理任务。""" """平滑关闭处理任务。"""
logger.info(f"正在关闭 S4UChat: {self.stream_name}") logger.info(f"正在关闭 S4UChat: {self.stream_name}")

View File

@@ -1,21 +1,10 @@
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger from src.common.logger import get_logger
from .s4u_chat import get_s4u_chat_manager from .s4u_chat import get_s4u_chat_manager
import math
import re
import traceback
from typing import Optional, Tuple
from maim_message import UserInfo
from src.person_info.relationship_manager import get_relationship_manager
# from ..message_receive.message_buffer import message_buffer # from ..message_receive.message_buffer import message_buffer
@@ -68,4 +57,3 @@ class S4UMessageProcessor:
# 7. 日志记录 # 7. 日志记录
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")

View File

@@ -1,10 +1,8 @@
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.individuality.individuality import get_individuality from src.individuality.individuality import get_individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.message_receive.message import MessageRecv
import time import time
from src.chat.utils.utils import get_recent_group_speaker from src.chat.utils.utils import get_recent_group_speaker
from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.memory_system.Hippocampus import hippocampus_manager
@@ -23,7 +21,6 @@ def init_prompt():
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
Prompt( Prompt(
""" """
你的名字叫{bot_name},昵称是:{bot_other_names}{prompt_personality} 你的名字叫{bot_name},昵称是:{bot_other_names}{prompt_personality}
@@ -79,7 +76,6 @@ class PromptBuilder:
relationship_manager = get_relationship_manager() relationship_manager = get_relationship_manager()
relation_prompt += await relationship_manager.build_relationship_info(person) relation_prompt += await relationship_manager.build_relationship_info(person)
memory_prompt = "" memory_prompt = ""
related_memory = await hippocampus_manager.get_memory_from_text( related_memory = await hippocampus_manager.get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
@@ -99,22 +95,19 @@ class PromptBuilder:
limit=100, limit=100,
) )
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
print(f"talk_type: {talk_type}") print(f"talk_type: {talk_type}")
# 分别筛选核心对话和背景对话 # 分别筛选核心对话和背景对话
core_dialogue_list = [] core_dialogue_list = []
background_dialogue_list = [] background_dialogue_list = []
bot_id = str(global_config.bot.qq_account) bot_id = str(global_config.bot.qq_account)
target_user_id = str(message.chat_stream.user_info.user_id) target_user_id = str(message.chat_stream.user_info.user_id)
for msg_dict in message_list_before_now: for msg_dict in message_list_before_now:
try: try:
# 直接通过字典访问 # 直接通过字典访问
msg_user_id = str(msg_dict.get('user_id')) msg_user_id = str(msg_dict.get("user_id"))
if msg_user_id == bot_id: if msg_user_id == bot_id:
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"): if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
print(f"reply: {msg_dict.get('reply_to')}") print(f"reply: {msg_dict.get('reply_to')}")
@@ -144,7 +137,7 @@ class PromptBuilder:
core_dialogue_list = core_dialogue_list[-50:] core_dialogue_list = core_dialogue_list[-50:]
first_msg = core_dialogue_list[0] first_msg = core_dialogue_list[0]
start_speaking_user_id = first_msg.get('user_id') start_speaking_user_id = first_msg.get("user_id")
if start_speaking_user_id == bot_id: if start_speaking_user_id == bot_id:
last_speaking_user_id = bot_id last_speaking_user_id = bot_id
msg_seg_str = "你的发言:\n" msg_seg_str = "你的发言:\n"
@@ -157,10 +150,12 @@ class PromptBuilder:
all_msg_seg_list = [] all_msg_seg_list = []
for msg in core_dialogue_list[1:]: for msg in core_dialogue_list[1:]:
speaker = msg.get('user_id') speaker = msg.get("user_id")
if speaker == last_speaking_user_id: if speaker == last_speaking_user_id:
# 还是同一个人讲话 # 还是同一个人讲话
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
)
else: else:
# 换人了 # 换人了
msg_seg_str = f"{msg_seg_str}\n" msg_seg_str = f"{msg_seg_str}\n"
@@ -171,12 +166,13 @@ class PromptBuilder:
else: else:
msg_seg_str = "对方的发言:\n" msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
)
last_speaking_user_id = speaker last_speaking_user_id = speaker
all_msg_seg_list.append(msg_seg_str) all_msg_seg_list.append(msg_seg_str)
core_msg_str = "" core_msg_str = ""
for msg in all_msg_seg_list: for msg in all_msg_seg_list:
# print(f"msg: {msg}") # print(f"msg: {msg}")

View File

@@ -43,8 +43,8 @@ class S4UStreamGenerator:
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点 # 匹配常见的句子结束符,但会忽略引号内和数字中的标点
self.sentence_split_pattern = re.compile( self.sentence_split_pattern = re.compile(
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容 r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))' # 匹配直到句子结束符 r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))', # 匹配直到句子结束符
, re.UNICODE | re.DOTALL re.UNICODE | re.DOTALL,
) )
async def generate_response( async def generate_response(
@@ -78,7 +78,6 @@ class S4UStreamGenerator:
else: else:
message_txt = message.processed_plain_text message_txt = message.processed_plain_text
prompt = await prompt_builder.build_prompt_normal( prompt = await prompt_builder.build_prompt_normal(
message=message, message=message,
message_txt=message_txt, message_txt=message_txt,
@@ -132,8 +131,8 @@ class S4UStreamGenerator:
else: else:
# 发送之前累积的标点和当前句子 # 发送之前累积的标点和当前句子
to_yield = punctuation_buffer + sentence to_yield = punctuation_buffer + sentence
if to_yield.endswith((',', '')): if to_yield.endswith((",", "")):
to_yield = to_yield.rstrip(',') to_yield = to_yield.rstrip(",")
yield to_yield yield to_yield
punctuation_buffer = "" # 清空标点符号缓冲区 punctuation_buffer = "" # 清空标点符号缓冲区
@@ -148,8 +147,7 @@ class S4UStreamGenerator:
# 发送缓冲区中剩余的任何内容 # 发送缓冲区中剩余的任何内容
to_yield = (punctuation_buffer + buffer).strip() to_yield = (punctuation_buffer + buffer).strip()
if to_yield: if to_yield:
if to_yield.endswith(('', ',')): if to_yield.endswith(("", ",")):
to_yield = to_yield.rstrip(',') to_yield = to_yield.rstrip(",")
if to_yield: if to_yield:
yield to_yield yield to_yield

View File

@@ -1,8 +1,5 @@
import asyncio from typing import AsyncGenerator, Dict, List, Optional, Union
import json
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
from dataclasses import dataclass from dataclasses import dataclass
import aiohttp
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
@@ -10,6 +7,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
@dataclass @dataclass
class ChatMessage: class ChatMessage:
"""聊天消息数据类""" """聊天消息数据类"""
role: str role: str
content: str content: str
@@ -40,7 +38,7 @@ class AsyncOpenAIClient:
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs **kwargs,
) -> ChatCompletion: ) -> ChatCompletion:
""" """
非流式聊天完成 非流式聊天完成
@@ -76,7 +74,7 @@ class AsyncOpenAIClient:
max_tokens=max_tokens, max_tokens=max_tokens,
stream=False, stream=False,
extra_body=extra_body if extra_body else None, extra_body=extra_body if extra_body else None,
**kwargs **kwargs,
) )
return response return response
@@ -87,7 +85,7 @@ class AsyncOpenAIClient:
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs **kwargs,
) -> AsyncGenerator[ChatCompletionChunk, None]: ) -> AsyncGenerator[ChatCompletionChunk, None]:
""" """
流式聊天完成 流式聊天完成
@@ -123,7 +121,7 @@ class AsyncOpenAIClient:
max_tokens=max_tokens, max_tokens=max_tokens,
stream=True, stream=True,
extra_body=extra_body if extra_body else None, extra_body=extra_body if extra_body else None,
**kwargs **kwargs,
) )
async for chunk in stream: async for chunk in stream:
@@ -135,7 +133,7 @@ class AsyncOpenAIClient:
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs **kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """
获取流式内容(只返回文本内容) 获取流式内容(只返回文本内容)
@@ -151,11 +149,7 @@ class AsyncOpenAIClient:
str: 文本内容片段 str: 文本内容片段
""" """
async for chunk in self.chat_completion_stream( async for chunk in self.chat_completion_stream(
messages=messages, messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
): ):
if chunk.choices and chunk.choices[0].delta.content: if chunk.choices and chunk.choices[0].delta.content:
yield chunk.choices[0].delta.content yield chunk.choices[0].delta.content
@@ -166,7 +160,7 @@ class AsyncOpenAIClient:
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs **kwargs,
) -> str: ) -> str:
""" """
收集完整的流式响应 收集完整的流式响应
@@ -183,11 +177,7 @@ class AsyncOpenAIClient:
""" """
full_response = "" full_response = ""
async for content in self.get_stream_content( async for content in self.get_stream_content(
messages=messages, messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs
): ):
full_response += content full_response += content
@@ -232,10 +222,7 @@ class ConversationManager:
self.messages.append(ChatMessage(role="assistant", content=content)) self.messages.append(ChatMessage(role="assistant", content=content))
async def send_message_stream( async def send_message_stream(
self, self, content: str, model: str = "gpt-3.5-turbo", **kwargs
content: str,
model: str = "gpt-3.5-turbo",
**kwargs
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
""" """
发送消息并获取流式响应 发送消息并获取流式响应
@@ -251,22 +238,13 @@ class ConversationManager:
self.add_user_message(content) self.add_user_message(content)
response_content = "" response_content = ""
async for chunk in self.client.get_stream_content( async for chunk in self.client.get_stream_content(messages=self.messages, model=model, **kwargs):
messages=self.messages,
model=model,
**kwargs
):
response_content += chunk response_content += chunk
yield chunk yield chunk
self.add_assistant_message(response_content) self.add_assistant_message(response_content)
async def send_message( async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str:
self,
content: str,
model: str = "gpt-3.5-turbo",
**kwargs
) -> str:
""" """
发送消息并获取完整响应 发送消息并获取完整响应
@@ -280,11 +258,7 @@ class ConversationManager:
""" """
self.add_user_message(content) self.add_user_message(content)
response = await self.client.chat_completion( response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs)
messages=self.messages,
model=model,
**kwargs
)
response_content = response.choices[0].message.content response_content = response.choices[0].message.content
self.add_assistant_message(response_content) self.add_assistant_message(response_content)