🤖 自动格式化代码 [skip ci]
This commit is contained in:
1
bot.py
1
bot.py
@@ -326,7 +326,6 @@ if __name__ == "__main__":
|
|||||||
# Wait for all tasks to complete (which they won't, normally)
|
# Wait for all tasks to complete (which they won't, normally)
|
||||||
loop.run_until_complete(main_tasks)
|
loop.run_until_complete(main_tasks)
|
||||||
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# loop.run_until_complete(get_global_api().stop())
|
# loop.run_until_complete(get_global_api().stop())
|
||||||
logger.warning("收到中断信号,正在优雅关闭...")
|
logger.warning("收到中断信号,正在优雅关闭...")
|
||||||
|
|||||||
@@ -3,10 +3,12 @@ from src.common.logger import get_logger
|
|||||||
|
|
||||||
logger = get_logger("MockAudio")
|
logger = get_logger("MockAudio")
|
||||||
|
|
||||||
|
|
||||||
class MockAudioPlayer:
|
class MockAudioPlayer:
|
||||||
"""
|
"""
|
||||||
一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。
|
一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, audio_data: bytes):
|
def __init__(self, audio_data: bytes):
|
||||||
self._audio_data = audio_data
|
self._audio_data = audio_data
|
||||||
# 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频
|
# 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频
|
||||||
@@ -22,12 +24,14 @@ class MockAudioPlayer:
|
|||||||
logger.info("模拟音频播放完毕。")
|
logger.info("模拟音频播放完毕。")
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("音频播放被中断。")
|
logger.info("音频播放被中断。")
|
||||||
raise # 重新抛出异常,以便上层逻辑可以捕获它
|
raise # 重新抛出异常,以便上层逻辑可以捕获它
|
||||||
|
|
||||||
|
|
||||||
class MockAudioGenerator:
|
class MockAudioGenerator:
|
||||||
"""
|
"""
|
||||||
一个模拟的文本到语音(TTS)生成器。
|
一个模拟的文本到语音(TTS)生成器。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 模拟生成速度:每秒生成的字符数
|
# 模拟生成速度:每秒生成的字符数
|
||||||
self.chars_per_second = 25.0
|
self.chars_per_second = 25.0
|
||||||
@@ -43,16 +47,16 @@ class MockAudioGenerator:
|
|||||||
模拟的音频数据(bytes)。
|
模拟的音频数据(bytes)。
|
||||||
"""
|
"""
|
||||||
if not text:
|
if not text:
|
||||||
return b''
|
return b""
|
||||||
|
|
||||||
generation_time = len(text) / self.chars_per_second
|
generation_time = len(text) / self.chars_per_second
|
||||||
logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...")
|
logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...")
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(generation_time)
|
await asyncio.sleep(generation_time)
|
||||||
# 生成虚拟的音频数据,其长度与文本长度成正比
|
# 生成虚拟的音频数据,其长度与文本长度成正比
|
||||||
mock_audio_data = b'\x01\x02\x03' * (len(text) * 40)
|
mock_audio_data = b"\x01\x02\x03" * (len(text) * 40)
|
||||||
logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。")
|
logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。")
|
||||||
return mock_audio_data
|
return mock_audio_data
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("音频生成被中断。")
|
logger.info("音频生成被中断。")
|
||||||
raise # 重新抛出异常
|
raise # 重新抛出异常
|
||||||
|
|||||||
@@ -180,14 +180,12 @@ class ChatBot:
|
|||||||
# 如果在私聊中
|
# 如果在私聊中
|
||||||
if group_info is None:
|
if group_info is None:
|
||||||
logger.debug("检测到私聊消息")
|
logger.debug("检测到私聊消息")
|
||||||
|
|
||||||
if ENABLE_S4U_CHAT:
|
if ENABLE_S4U_CHAT:
|
||||||
logger.debug("进入S4U私聊处理流程")
|
logger.debug("进入S4U私聊处理流程")
|
||||||
await self.s4u_message_processor.process_message(message)
|
await self.s4u_message_processor.process_message(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if global_config.experimental.pfc_chatting:
|
if global_config.experimental.pfc_chatting:
|
||||||
logger.debug("进入PFC私聊处理流程")
|
logger.debug("进入PFC私聊处理流程")
|
||||||
# 创建聊天流
|
# 创建聊天流
|
||||||
@@ -200,13 +198,11 @@ class ChatBot:
|
|||||||
await self.heartflow_message_receiver.process_message(message)
|
await self.heartflow_message_receiver.process_message(message)
|
||||||
# 群聊默认进入心流消息处理逻辑
|
# 群聊默认进入心流消息处理逻辑
|
||||||
else:
|
else:
|
||||||
|
|
||||||
if ENABLE_S4U_CHAT:
|
if ENABLE_S4U_CHAT:
|
||||||
logger.debug("进入S4U私聊处理流程")
|
logger.debug("进入S4U私聊处理流程")
|
||||||
await self.s4u_message_processor.process_message(message)
|
await self.s4u_message_processor.process_message(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}")
|
logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}")
|
||||||
await self.heartflow_message_receiver.process_message(message)
|
await self.heartflow_message_receiver.process_message(message)
|
||||||
|
|
||||||
|
|||||||
@@ -323,7 +323,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.is_head = is_head
|
self.is_head = is_head
|
||||||
self.is_emoji = is_emoji
|
self.is_emoji = is_emoji
|
||||||
self.apply_set_reply_logic = apply_set_reply_logic
|
self.apply_set_reply_logic = apply_set_reply_logic
|
||||||
|
|
||||||
self.reply_to = reply_to
|
self.reply_to = reply_to
|
||||||
|
|
||||||
# 用于显示发送内容与显示不一致的情况
|
# 用于显示发送内容与显示不一致的情况
|
||||||
|
|||||||
@@ -35,11 +35,11 @@ class MessageStorage:
|
|||||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||||
else:
|
else:
|
||||||
filtered_display_message = ""
|
filtered_display_message = ""
|
||||||
|
|
||||||
reply_to = message.reply_to
|
reply_to = message.reply_to
|
||||||
else:
|
else:
|
||||||
filtered_display_message = ""
|
filtered_display_message = ""
|
||||||
|
|
||||||
reply_to = ""
|
reply_to = ""
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
chat_info_dict = chat_stream.to_dict()
|
||||||
|
|||||||
@@ -263,7 +263,6 @@ def _build_readable_messages_internal(
|
|||||||
# 处理图片ID
|
# 处理图片ID
|
||||||
if show_pic:
|
if show_pic:
|
||||||
content = process_pic_ids(content)
|
content = process_pic_ids(content)
|
||||||
|
|
||||||
|
|
||||||
# 检查必要信息是否存在
|
# 检查必要信息是否存在
|
||||||
if not all([platform, user_id, timestamp is not None]):
|
if not all([platform, user_id, timestamp is not None]):
|
||||||
@@ -632,10 +631,17 @@ def build_readable_messages(
|
|||||||
truncate,
|
truncate,
|
||||||
pic_id_mapping,
|
pic_id_mapping,
|
||||||
pic_counter,
|
pic_counter,
|
||||||
show_pic=show_pic
|
show_pic=show_pic,
|
||||||
)
|
)
|
||||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||||
messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter, show_pic=show_pic
|
messages_after_mark,
|
||||||
|
replace_bot_name,
|
||||||
|
merge_messages,
|
||||||
|
timestamp_mode,
|
||||||
|
False,
|
||||||
|
pic_id_mapping,
|
||||||
|
pic_counter,
|
||||||
|
show_pic=show_pic,
|
||||||
)
|
)
|
||||||
|
|
||||||
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"
|
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"
|
||||||
|
|||||||
@@ -126,7 +126,7 @@ class Messages(BaseModel):
|
|||||||
time = DoubleField() # 消息时间戳
|
time = DoubleField() # 消息时间戳
|
||||||
|
|
||||||
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
|
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
|
||||||
|
|
||||||
reply_to = TextField(null=True)
|
reply_to = TextField(null=True)
|
||||||
|
|
||||||
# 从 chat_info 扁平化而来的字段
|
# 从 chat_info 扁平化而来的字段
|
||||||
|
|||||||
@@ -1,39 +1,15 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import traceback
|
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Dict # 导入类型提示
|
from typing import Optional, Dict # 导入类型提示
|
||||||
import os
|
|
||||||
import pickle
|
|
||||||
from maim_message import UserInfo, Seg
|
from maim_message import UserInfo, Seg
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
|
||||||
from src.manager.mood_manager import mood_manager
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.chat.utils.timer_calculator import Timer
|
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
|
||||||
from .s4u_stream_generator import S4UStreamGenerator
|
from .s4u_stream_generator import S4UStreamGenerator
|
||||||
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||||
from src.chat.message_receive.message_sender import message_manager
|
|
||||||
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
|
|
||||||
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
|
||||||
from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner
|
|
||||||
from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionModifier
|
|
||||||
from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor
|
|
||||||
from src.chat.focus_chat.replyer.default_generator import DefaultReplyer
|
|
||||||
from src.person_info.person_info import PersonInfoManager
|
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
|
||||||
get_raw_msg_by_timestamp_with_chat,
|
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
|
||||||
num_new_messages_since,
|
|
||||||
)
|
|
||||||
from src.common.message.api import get_global_api
|
from src.common.message.api import get_global_api
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.audio.mock_audio import MockAudioGenerator, MockAudioPlayer
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("S4U_chat")
|
logger = get_logger("S4U_chat")
|
||||||
@@ -41,6 +17,7 @@ logger = get_logger("S4U_chat")
|
|||||||
|
|
||||||
class MessageSenderContainer:
|
class MessageSenderContainer:
|
||||||
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
|
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
|
||||||
|
|
||||||
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
|
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
self.original_message = original_message
|
self.original_message = original_message
|
||||||
@@ -71,7 +48,7 @@ class MessageSenderContainer:
|
|||||||
chars_per_second = 15.0
|
chars_per_second = 15.0
|
||||||
min_delay = 0.2
|
min_delay = 0.2
|
||||||
max_delay = 2.0
|
max_delay = 2.0
|
||||||
|
|
||||||
delay = len(text) / chars_per_second
|
delay = len(text) / chars_per_second
|
||||||
return max(min_delay, min(delay, max_delay))
|
return max(min_delay, min(delay, max_delay))
|
||||||
|
|
||||||
@@ -98,7 +75,7 @@ class MessageSenderContainer:
|
|||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
msg_id = f"{current_time}_{random.randint(1000, 9999)}"
|
msg_id = f"{current_time}_{random.randint(1000, 9999)}"
|
||||||
|
|
||||||
text_to_send = chunk
|
text_to_send = chunk
|
||||||
if global_config.experimental.debug_show_chat_mode:
|
if global_config.experimental.debug_show_chat_mode:
|
||||||
text_to_send += "ⁿ"
|
text_to_send += "ⁿ"
|
||||||
@@ -117,19 +94,19 @@ class MessageSenderContainer:
|
|||||||
reply=self.original_message,
|
reply=self.original_message,
|
||||||
is_emoji=False,
|
is_emoji=False,
|
||||||
apply_set_reply_logic=True,
|
apply_set_reply_logic=True,
|
||||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}"
|
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
await bot_message.process()
|
await bot_message.process()
|
||||||
|
|
||||||
await get_global_api().send_message(bot_message)
|
await get_global_api().send_message(bot_message)
|
||||||
logger.info(f"已将消息 '{text_to_send}' 发往平台 '{bot_message.message_info.platform}'")
|
logger.info(f"已将消息 '{text_to_send}' 发往平台 '{bot_message.message_info.platform}'")
|
||||||
|
|
||||||
await self.storage.store_message(bot_message, self.chat_stream)
|
await self.storage.store_message(bot_message, self.chat_stream)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||||
self.queue.task_done()
|
self.queue.task_done()
|
||||||
@@ -138,7 +115,7 @@ class MessageSenderContainer:
|
|||||||
"""启动发送任务。"""
|
"""启动发送任务。"""
|
||||||
if self._task is None:
|
if self._task is None:
|
||||||
self._task = asyncio.create_task(self._send_worker())
|
self._task = asyncio.create_task(self._send_worker())
|
||||||
|
|
||||||
async def join(self):
|
async def join(self):
|
||||||
"""等待所有消息发送完毕。"""
|
"""等待所有消息发送完毕。"""
|
||||||
if self._task:
|
if self._task:
|
||||||
@@ -156,8 +133,10 @@ class S4UChatManager:
|
|||||||
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
|
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
|
||||||
return self.s4u_chats[chat_stream.stream_id]
|
return self.s4u_chats[chat_stream.stream_id]
|
||||||
|
|
||||||
|
|
||||||
s4u_chat_manager = S4UChatManager()
|
s4u_chat_manager = S4UChatManager()
|
||||||
|
|
||||||
|
|
||||||
def get_s4u_chat_manager() -> S4UChatManager:
|
def get_s4u_chat_manager() -> S4UChatManager:
|
||||||
return s4u_chat_manager
|
return s4u_chat_manager
|
||||||
|
|
||||||
@@ -169,22 +148,19 @@ class S4UChat:
|
|||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
self.stream_id = chat_stream.stream_id
|
self.stream_id = chat_stream.stream_id
|
||||||
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||||
|
|
||||||
self._message_queue = asyncio.Queue()
|
self._message_queue = asyncio.Queue()
|
||||||
self._processing_task = asyncio.create_task(self._message_processor())
|
self._processing_task = asyncio.create_task(self._message_processor())
|
||||||
self._current_generation_task: Optional[asyncio.Task] = None
|
self._current_generation_task: Optional[asyncio.Task] = None
|
||||||
self._current_message_being_replied: Optional[MessageRecv] = None
|
self._current_message_being_replied: Optional[MessageRecv] = None
|
||||||
|
|
||||||
self._is_replying = False
|
self._is_replying = False
|
||||||
|
|
||||||
self.gpt = S4UStreamGenerator()
|
self.gpt = S4UStreamGenerator()
|
||||||
# self.audio_generator = MockAudioGenerator()
|
# self.audio_generator = MockAudioGenerator()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"[{self.stream_name}] S4UChat")
|
logger.info(f"[{self.stream_name}] S4UChat")
|
||||||
|
|
||||||
|
|
||||||
# 改为实例方法, 移除 chat 参数
|
# 改为实例方法, 移除 chat 参数
|
||||||
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
|
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
|
||||||
"""将消息放入队列并根据发信人决定是否中断当前处理。"""
|
"""将消息放入队列并根据发信人决定是否中断当前处理。"""
|
||||||
@@ -226,8 +202,8 @@ class S4UChat:
|
|||||||
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
|
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
|
||||||
while not self._message_queue.empty():
|
while not self._message_queue.empty():
|
||||||
drained_msg = self._message_queue.get_nowait()
|
drained_msg = self._message_queue.get_nowait()
|
||||||
self._message_queue.task_done() # 为取出的旧消息调用 task_done
|
self._message_queue.task_done() # 为取出的旧消息调用 task_done
|
||||||
message = drained_msg # 始终处理最新消息
|
message = drained_msg # 始终处理最新消息
|
||||||
self._current_message_being_replied = message
|
self._current_message_being_replied = message
|
||||||
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
|
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
|
||||||
|
|
||||||
@@ -242,44 +218,40 @@ class S4UChat:
|
|||||||
finally:
|
finally:
|
||||||
self._current_generation_task = None
|
self._current_generation_task = None
|
||||||
self._current_message_being_replied = None
|
self._current_message_being_replied = None
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
|
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True)
|
logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True)
|
||||||
await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转
|
await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转
|
||||||
finally:
|
finally:
|
||||||
# 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成
|
# 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成
|
||||||
if 'message' in locals():
|
if "message" in locals():
|
||||||
self._message_queue.task_done()
|
self._message_queue.task_done()
|
||||||
|
|
||||||
|
|
||||||
async def _generate_and_send(self, message: MessageRecv):
|
async def _generate_and_send(self, message: MessageRecv):
|
||||||
"""为单个消息生成文本和音频回复。整个过程可以被中断。"""
|
"""为单个消息生成文本和音频回复。整个过程可以被中断。"""
|
||||||
self._is_replying = True
|
self._is_replying = True
|
||||||
sender_container = MessageSenderContainer(self.chat_stream, message)
|
sender_container = MessageSenderContainer(self.chat_stream, message)
|
||||||
sender_container.start()
|
sender_container.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
||||||
f"[S4U] 开始为消息生成文本和音频流: "
|
|
||||||
f"'{message.processed_plain_text[:30]}...'"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1. 逐句生成文本、发送并播放音频
|
# 1. 逐句生成文本、发送并播放音频
|
||||||
gen = self.gpt.generate_response(message, "")
|
gen = self.gpt.generate_response(message, "")
|
||||||
async for chunk in gen:
|
async for chunk in gen:
|
||||||
# 如果任务被取消,await 会在此处引发 CancelledError
|
# 如果任务被取消,await 会在此处引发 CancelledError
|
||||||
|
|
||||||
# a. 发送文本块
|
# a. 发送文本块
|
||||||
await sender_container.add_message(chunk)
|
await sender_container.add_message(chunk)
|
||||||
|
|
||||||
# b. 为该文本块生成并播放音频
|
# b. 为该文本块生成并播放音频
|
||||||
# if chunk.strip():
|
# if chunk.strip():
|
||||||
# audio_data = await self.audio_generator.generate(chunk)
|
# audio_data = await self.audio_generator.generate(chunk)
|
||||||
# player = MockAudioPlayer(audio_data)
|
# player = MockAudioPlayer(audio_data)
|
||||||
# await player.play()
|
# await player.play()
|
||||||
|
|
||||||
# 等待所有文本消息发送完成
|
# 等待所有文本消息发送完成
|
||||||
await sender_container.close()
|
await sender_container.close()
|
||||||
@@ -300,20 +272,19 @@ class S4UChat:
|
|||||||
await sender_container.join()
|
await sender_container.join()
|
||||||
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
|
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
|
||||||
|
|
||||||
|
|
||||||
async def shutdown(self):
|
async def shutdown(self):
|
||||||
"""平滑关闭处理任务。"""
|
"""平滑关闭处理任务。"""
|
||||||
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
|
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
|
||||||
|
|
||||||
# 取消正在运行的任务
|
# 取消正在运行的任务
|
||||||
if self._current_generation_task and not self._current_generation_task.done():
|
if self._current_generation_task and not self._current_generation_task.done():
|
||||||
self._current_generation_task.cancel()
|
self._current_generation_task.cancel()
|
||||||
|
|
||||||
if self._processing_task and not self._processing_task.done():
|
if self._processing_task and not self._processing_task.done():
|
||||||
self._processing_task.cancel()
|
self._processing_task.cancel()
|
||||||
|
|
||||||
# 等待任务响应取消
|
# 等待任务响应取消
|
||||||
try:
|
try:
|
||||||
await self._processing_task
|
await self._processing_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||||
|
|||||||
@@ -1,21 +1,10 @@
|
|||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.heart_flow.heartflow import heartflow
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
from src.chat.utils.timer_calculator import Timer
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .s4u_chat import get_s4u_chat_manager
|
from .s4u_chat import get_s4u_chat_manager
|
||||||
|
|
||||||
import math
|
|
||||||
import re
|
|
||||||
import traceback
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
from maim_message import UserInfo
|
|
||||||
|
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
|
||||||
|
|
||||||
# from ..message_receive.message_buffer import message_buffer
|
# from ..message_receive.message_buffer import message_buffer
|
||||||
|
|
||||||
@@ -44,7 +33,7 @@ class S4UMessageProcessor:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
target_user_id_list = ["1026294844", "964959351"]
|
target_user_id_list = ["1026294844", "964959351"]
|
||||||
|
|
||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
groupinfo = message.message_info.group_info
|
groupinfo = message.message_info.group_info
|
||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
@@ -60,7 +49,7 @@ class S4UMessageProcessor:
|
|||||||
|
|
||||||
is_mentioned = is_mentioned_bot_in_message(message)
|
is_mentioned = is_mentioned_bot_in_message(message)
|
||||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||||
|
|
||||||
if userinfo.user_id in target_user_id_list:
|
if userinfo.user_id in target_user_id_list:
|
||||||
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0)
|
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0)
|
||||||
else:
|
else:
|
||||||
@@ -68,4 +57,3 @@ class S4UMessageProcessor:
|
|||||||
|
|
||||||
# 7. 日志记录
|
# 7. 日志记录
|
||||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.individuality.individuality import get_individuality
|
from src.individuality.individuality import get_individuality
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||||
from src.chat.message_receive.message import MessageRecv
|
|
||||||
import time
|
import time
|
||||||
from src.chat.utils.utils import get_recent_group_speaker
|
from src.chat.utils.utils import get_recent_group_speaker
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
@@ -23,7 +21,6 @@ def init_prompt():
|
|||||||
|
|
||||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||||
|
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。
|
你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。
|
||||||
@@ -79,7 +76,6 @@ class PromptBuilder:
|
|||||||
relationship_manager = get_relationship_manager()
|
relationship_manager = get_relationship_manager()
|
||||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||||
|
|
||||||
|
|
||||||
memory_prompt = ""
|
memory_prompt = ""
|
||||||
related_memory = await hippocampus_manager.get_memory_from_text(
|
related_memory = await hippocampus_manager.get_memory_from_text(
|
||||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
@@ -98,23 +94,20 @@ class PromptBuilder:
|
|||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
limit=100,
|
limit=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
||||||
print(f"talk_type: {talk_type}")
|
print(f"talk_type: {talk_type}")
|
||||||
|
|
||||||
|
|
||||||
# 分别筛选核心对话和背景对话
|
# 分别筛选核心对话和背景对话
|
||||||
core_dialogue_list = []
|
core_dialogue_list = []
|
||||||
background_dialogue_list = []
|
background_dialogue_list = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||||
|
|
||||||
|
|
||||||
for msg_dict in message_list_before_now:
|
for msg_dict in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
# 直接通过字典访问
|
# 直接通过字典访问
|
||||||
msg_user_id = str(msg_dict.get('user_id'))
|
msg_user_id = str(msg_dict.get("user_id"))
|
||||||
if msg_user_id == bot_id:
|
if msg_user_id == bot_id:
|
||||||
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
||||||
print(f"reply: {msg_dict.get('reply_to')}")
|
print(f"reply: {msg_dict.get('reply_to')}")
|
||||||
@@ -127,24 +120,24 @@ class PromptBuilder:
|
|||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||||
|
|
||||||
if background_dialogue_list:
|
if background_dialogue_list:
|
||||||
latest_25_msgs = background_dialogue_list[-25:]
|
latest_25_msgs = background_dialogue_list[-25:]
|
||||||
background_dialogue_prompt = build_readable_messages(
|
background_dialogue_prompt = build_readable_messages(
|
||||||
latest_25_msgs,
|
latest_25_msgs,
|
||||||
merge_messages=True,
|
merge_messages=True,
|
||||||
timestamp_mode = "normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
show_pic = False,
|
show_pic=False,
|
||||||
)
|
)
|
||||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}"
|
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}"
|
||||||
else:
|
else:
|
||||||
background_dialogue_prompt = ""
|
background_dialogue_prompt = ""
|
||||||
|
|
||||||
# 分别获取最新50条和最新25条(从message_list_before_now截取)
|
# 分别获取最新50条和最新25条(从message_list_before_now截取)
|
||||||
core_dialogue_list = core_dialogue_list[-50:]
|
core_dialogue_list = core_dialogue_list[-50:]
|
||||||
|
|
||||||
first_msg = core_dialogue_list[0]
|
first_msg = core_dialogue_list[0]
|
||||||
start_speaking_user_id = first_msg.get('user_id')
|
start_speaking_user_id = first_msg.get("user_id")
|
||||||
if start_speaking_user_id == bot_id:
|
if start_speaking_user_id == bot_id:
|
||||||
last_speaking_user_id = bot_id
|
last_speaking_user_id = bot_id
|
||||||
msg_seg_str = "你的发言:\n"
|
msg_seg_str = "你的发言:\n"
|
||||||
@@ -152,30 +145,33 @@ class PromptBuilder:
|
|||||||
start_speaking_user_id = target_user_id
|
start_speaking_user_id = target_user_id
|
||||||
last_speaking_user_id = start_speaking_user_id
|
last_speaking_user_id = start_speaking_user_id
|
||||||
msg_seg_str = "对方的发言:\n"
|
msg_seg_str = "对方的发言:\n"
|
||||||
|
|
||||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
|
||||||
|
|
||||||
all_msg_seg_list = []
|
all_msg_seg_list = []
|
||||||
for msg in core_dialogue_list[1:]:
|
for msg in core_dialogue_list[1:]:
|
||||||
speaker = msg.get('user_id')
|
speaker = msg.get("user_id")
|
||||||
if speaker == last_speaking_user_id:
|
if speaker == last_speaking_user_id:
|
||||||
#还是同一个人讲话
|
# 还是同一个人讲话
|
||||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
msg_seg_str += (
|
||||||
|
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
#换人了
|
# 换人了
|
||||||
msg_seg_str = f"{msg_seg_str}\n"
|
msg_seg_str = f"{msg_seg_str}\n"
|
||||||
all_msg_seg_list.append(msg_seg_str)
|
all_msg_seg_list.append(msg_seg_str)
|
||||||
|
|
||||||
if speaker == bot_id:
|
if speaker == bot_id:
|
||||||
msg_seg_str = "你的发言:\n"
|
msg_seg_str = "你的发言:\n"
|
||||||
else:
|
else:
|
||||||
msg_seg_str = "对方的发言:\n"
|
msg_seg_str = "对方的发言:\n"
|
||||||
|
|
||||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
|
||||||
last_speaking_user_id = speaker
|
|
||||||
|
|
||||||
all_msg_seg_list.append(msg_seg_str)
|
|
||||||
|
|
||||||
|
msg_seg_str += (
|
||||||
|
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||||
|
)
|
||||||
|
last_speaking_user_id = speaker
|
||||||
|
|
||||||
|
all_msg_seg_list.append(msg_seg_str)
|
||||||
|
|
||||||
core_msg_str = ""
|
core_msg_str = ""
|
||||||
for msg in all_msg_seg_list:
|
for msg in all_msg_seg_list:
|
||||||
|
|||||||
@@ -43,8 +43,8 @@ class S4UStreamGenerator:
|
|||||||
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
||||||
self.sentence_split_pattern = re.compile(
|
self.sentence_split_pattern = re.compile(
|
||||||
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
|
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
|
||||||
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))' # 匹配直到句子结束符
|
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符
|
||||||
, re.UNICODE | re.DOTALL
|
re.UNICODE | re.DOTALL,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def generate_response(
|
async def generate_response(
|
||||||
@@ -68,7 +68,7 @@ class S4UStreamGenerator:
|
|||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
if previous_reply_context:
|
if previous_reply_context:
|
||||||
message_txt = f"""
|
message_txt = f"""
|
||||||
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
|
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
|
||||||
[你已经对上一条消息说的话]: {previous_reply_context}
|
[你已经对上一条消息说的话]: {previous_reply_context}
|
||||||
---
|
---
|
||||||
@@ -78,9 +78,8 @@ class S4UStreamGenerator:
|
|||||||
else:
|
else:
|
||||||
message_txt = message.processed_plain_text
|
message_txt = message.processed_plain_text
|
||||||
|
|
||||||
|
|
||||||
prompt = await prompt_builder.build_prompt_normal(
|
prompt = await prompt_builder.build_prompt_normal(
|
||||||
message = message,
|
message=message,
|
||||||
message_txt=message_txt,
|
message_txt=message_txt,
|
||||||
sender_name=sender_name,
|
sender_name=sender_name,
|
||||||
chat_stream=message.chat_stream,
|
chat_stream=message.chat_stream,
|
||||||
@@ -109,16 +108,16 @@ class S4UStreamGenerator:
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
print(prompt)
|
print(prompt)
|
||||||
|
|
||||||
buffer = ""
|
buffer = ""
|
||||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||||
punctuation_buffer = ""
|
punctuation_buffer = ""
|
||||||
|
|
||||||
async for content in client.get_stream_content(
|
async for content in client.get_stream_content(
|
||||||
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
||||||
):
|
):
|
||||||
buffer += content
|
buffer += content
|
||||||
|
|
||||||
# 使用正则表达式匹配句子
|
# 使用正则表达式匹配句子
|
||||||
last_match_end = 0
|
last_match_end = 0
|
||||||
for match in self.sentence_split_pattern.finditer(buffer):
|
for match in self.sentence_split_pattern.finditer(buffer):
|
||||||
@@ -132,24 +131,23 @@ class S4UStreamGenerator:
|
|||||||
else:
|
else:
|
||||||
# 发送之前累积的标点和当前句子
|
# 发送之前累积的标点和当前句子
|
||||||
to_yield = punctuation_buffer + sentence
|
to_yield = punctuation_buffer + sentence
|
||||||
if to_yield.endswith((',', ',')):
|
if to_yield.endswith((",", ",")):
|
||||||
to_yield = to_yield.rstrip(',,')
|
to_yield = to_yield.rstrip(",,")
|
||||||
|
|
||||||
yield to_yield
|
yield to_yield
|
||||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||||
await asyncio.sleep(0) # 允许其他任务运行
|
await asyncio.sleep(0) # 允许其他任务运行
|
||||||
|
|
||||||
last_match_end = match.end(0)
|
last_match_end = match.end(0)
|
||||||
|
|
||||||
# 从缓冲区移除已发送的部分
|
# 从缓冲区移除已发送的部分
|
||||||
if last_match_end > 0:
|
if last_match_end > 0:
|
||||||
buffer = buffer[last_match_end:]
|
buffer = buffer[last_match_end:]
|
||||||
|
|
||||||
# 发送缓冲区中剩余的任何内容
|
# 发送缓冲区中剩余的任何内容
|
||||||
to_yield = (punctuation_buffer + buffer).strip()
|
to_yield = (punctuation_buffer + buffer).strip()
|
||||||
if to_yield:
|
if to_yield:
|
||||||
if to_yield.endswith((',', ',')):
|
if to_yield.endswith((",", ",")):
|
||||||
to_yield = to_yield.rstrip(',,')
|
to_yield = to_yield.rstrip(",,")
|
||||||
if to_yield:
|
if to_yield:
|
||||||
yield to_yield
|
yield to_yield
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
import asyncio
|
from typing import AsyncGenerator, Dict, List, Optional, Union
|
||||||
import json
|
|
||||||
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import aiohttp
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||||
|
|
||||||
@@ -10,20 +7,21 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessage:
|
class ChatMessage:
|
||||||
"""聊天消息数据类"""
|
"""聊天消息数据类"""
|
||||||
|
|
||||||
role: str
|
role: str
|
||||||
content: str
|
content: str
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, str]:
|
def to_dict(self) -> Dict[str, str]:
|
||||||
return {"role": self.role, "content": self.content}
|
return {"role": self.role, "content": self.content}
|
||||||
|
|
||||||
|
|
||||||
class AsyncOpenAIClient:
|
class AsyncOpenAIClient:
|
||||||
"""异步OpenAI客户端,支持流式传输"""
|
"""异步OpenAI客户端,支持流式传输"""
|
||||||
|
|
||||||
def __init__(self, api_key: str, base_url: Optional[str] = None):
|
def __init__(self, api_key: str, base_url: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
初始化客户端
|
初始化客户端
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
api_key: OpenAI API密钥
|
api_key: OpenAI API密钥
|
||||||
base_url: 可选的API基础URL,用于自定义端点
|
base_url: 可选的API基础URL,用于自定义端点
|
||||||
@@ -33,25 +31,25 @@ class AsyncOpenAIClient:
|
|||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
timeout=10.0, # 设置60秒的全局超时
|
timeout=10.0, # 设置60秒的全局超时
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-3.5-turbo",
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> ChatCompletion:
|
) -> ChatCompletion:
|
||||||
"""
|
"""
|
||||||
非流式聊天完成
|
非流式聊天完成
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
temperature: 温度参数
|
temperature: 温度参数
|
||||||
max_tokens: 最大token数
|
max_tokens: 最大token数
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
完整的聊天回复
|
完整的聊天回复
|
||||||
"""
|
"""
|
||||||
@@ -62,7 +60,7 @@ class AsyncOpenAIClient:
|
|||||||
formatted_messages.append(msg.to_dict())
|
formatted_messages.append(msg.to_dict())
|
||||||
else:
|
else:
|
||||||
formatted_messages.append(msg)
|
formatted_messages.append(msg)
|
||||||
|
|
||||||
extra_body = {}
|
extra_body = {}
|
||||||
if kwargs.get("enable_thinking") is not None:
|
if kwargs.get("enable_thinking") is not None:
|
||||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||||
@@ -76,29 +74,29 @@ class AsyncOpenAIClient:
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
stream=False,
|
stream=False,
|
||||||
extra_body=extra_body if extra_body else None,
|
extra_body=extra_body if extra_body else None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def chat_completion_stream(
|
async def chat_completion_stream(
|
||||||
self,
|
self,
|
||||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-3.5-turbo",
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||||
"""
|
"""
|
||||||
流式聊天完成
|
流式聊天完成
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
temperature: 温度参数
|
temperature: 温度参数
|
||||||
max_tokens: 最大token数
|
max_tokens: 最大token数
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
ChatCompletionChunk: 流式响应块
|
ChatCompletionChunk: 流式响应块
|
||||||
"""
|
"""
|
||||||
@@ -109,7 +107,7 @@ class AsyncOpenAIClient:
|
|||||||
formatted_messages.append(msg.to_dict())
|
formatted_messages.append(msg.to_dict())
|
||||||
else:
|
else:
|
||||||
formatted_messages.append(msg)
|
formatted_messages.append(msg)
|
||||||
|
|
||||||
extra_body = {}
|
extra_body = {}
|
||||||
if kwargs.get("enable_thinking") is not None:
|
if kwargs.get("enable_thinking") is not None:
|
||||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||||
@@ -123,84 +121,76 @@ class AsyncOpenAIClient:
|
|||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
stream=True,
|
stream=True,
|
||||||
extra_body=extra_body if extra_body else None,
|
extra_body=extra_body if extra_body else None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def get_stream_content(
|
async def get_stream_content(
|
||||||
self,
|
self,
|
||||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-3.5-turbo",
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
获取流式内容(只返回文本内容)
|
获取流式内容(只返回文本内容)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
temperature: 温度参数
|
temperature: 温度参数
|
||||||
max_tokens: 最大token数
|
max_tokens: 最大token数
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
str: 文本内容片段
|
str: 文本内容片段
|
||||||
"""
|
"""
|
||||||
async for chunk in self.chat_completion_stream(
|
async for chunk in self.chat_completion_stream(
|
||||||
messages=messages,
|
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
|
||||||
model=model,
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
if chunk.choices and chunk.choices[0].delta.content:
|
if chunk.choices and chunk.choices[0].delta.content:
|
||||||
yield chunk.choices[0].delta.content
|
yield chunk.choices[0].delta.content
|
||||||
|
|
||||||
async def collect_stream_response(
|
async def collect_stream_response(
|
||||||
self,
|
self,
|
||||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||||
model: str = "gpt-3.5-turbo",
|
model: str = "gpt-3.5-turbo",
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: Optional[int] = None,
|
max_tokens: Optional[int] = None,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
收集完整的流式响应
|
收集完整的流式响应
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: 消息列表
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
temperature: 温度参数
|
temperature: 温度参数
|
||||||
max_tokens: 最大token数
|
max_tokens: 最大token数
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 完整的响应文本
|
str: 完整的响应文本
|
||||||
"""
|
"""
|
||||||
full_response = ""
|
full_response = ""
|
||||||
async for content in self.get_stream_content(
|
async for content in self.get_stream_content(
|
||||||
messages=messages,
|
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
|
||||||
model=model,
|
|
||||||
temperature=temperature,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
full_response += content
|
full_response += content
|
||||||
|
|
||||||
return full_response
|
return full_response
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""关闭客户端"""
|
"""关闭客户端"""
|
||||||
await self.client.close()
|
await self.client.close()
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""异步上下文管理器入口"""
|
"""异步上下文管理器入口"""
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""异步上下文管理器退出"""
|
"""异步上下文管理器退出"""
|
||||||
await self.close()
|
await self.close()
|
||||||
@@ -208,93 +198,77 @@ class AsyncOpenAIClient:
|
|||||||
|
|
||||||
class ConversationManager:
|
class ConversationManager:
|
||||||
"""对话管理器,用于管理对话历史"""
|
"""对话管理器,用于管理对话历史"""
|
||||||
|
|
||||||
def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None):
|
def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
初始化对话管理器
|
初始化对话管理器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
client: OpenAI客户端实例
|
client: OpenAI客户端实例
|
||||||
system_prompt: 系统提示词
|
system_prompt: 系统提示词
|
||||||
"""
|
"""
|
||||||
self.client = client
|
self.client = client
|
||||||
self.messages: List[ChatMessage] = []
|
self.messages: List[ChatMessage] = []
|
||||||
|
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
self.messages.append(ChatMessage(role="system", content=system_prompt))
|
self.messages.append(ChatMessage(role="system", content=system_prompt))
|
||||||
|
|
||||||
def add_user_message(self, content: str):
|
def add_user_message(self, content: str):
|
||||||
"""添加用户消息"""
|
"""添加用户消息"""
|
||||||
self.messages.append(ChatMessage(role="user", content=content))
|
self.messages.append(ChatMessage(role="user", content=content))
|
||||||
|
|
||||||
def add_assistant_message(self, content: str):
|
def add_assistant_message(self, content: str):
|
||||||
"""添加助手消息"""
|
"""添加助手消息"""
|
||||||
self.messages.append(ChatMessage(role="assistant", content=content))
|
self.messages.append(ChatMessage(role="assistant", content=content))
|
||||||
|
|
||||||
async def send_message_stream(
|
async def send_message_stream(
|
||||||
self,
|
self, content: str, model: str = "gpt-3.5-turbo", **kwargs
|
||||||
content: str,
|
|
||||||
model: str = "gpt-3.5-turbo",
|
|
||||||
**kwargs
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
发送消息并获取流式响应
|
发送消息并获取流式响应
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 用户消息内容
|
content: 用户消息内容
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
str: 响应内容片段
|
str: 响应内容片段
|
||||||
"""
|
"""
|
||||||
self.add_user_message(content)
|
self.add_user_message(content)
|
||||||
|
|
||||||
response_content = ""
|
response_content = ""
|
||||||
async for chunk in self.client.get_stream_content(
|
async for chunk in self.client.get_stream_content(messages=self.messages, model=model, **kwargs):
|
||||||
messages=self.messages,
|
|
||||||
model=model,
|
|
||||||
**kwargs
|
|
||||||
):
|
|
||||||
response_content += chunk
|
response_content += chunk
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
self.add_assistant_message(response_content)
|
self.add_assistant_message(response_content)
|
||||||
|
|
||||||
async def send_message(
|
async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str:
|
||||||
self,
|
|
||||||
content: str,
|
|
||||||
model: str = "gpt-3.5-turbo",
|
|
||||||
**kwargs
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
发送消息并获取完整响应
|
发送消息并获取完整响应
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 用户消息内容
|
content: 用户消息内容
|
||||||
model: 模型名称
|
model: 模型名称
|
||||||
**kwargs: 其他参数
|
**kwargs: 其他参数
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 完整响应
|
str: 完整响应
|
||||||
"""
|
"""
|
||||||
self.add_user_message(content)
|
self.add_user_message(content)
|
||||||
|
|
||||||
response = await self.client.chat_completion(
|
response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs)
|
||||||
messages=self.messages,
|
|
||||||
model=model,
|
|
||||||
**kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
response_content = response.choices[0].message.content
|
response_content = response.choices[0].message.content
|
||||||
self.add_assistant_message(response_content)
|
self.add_assistant_message(response_content)
|
||||||
|
|
||||||
return response_content
|
return response_content
|
||||||
|
|
||||||
def clear_history(self, keep_system: bool = True):
|
def clear_history(self, keep_system: bool = True):
|
||||||
"""
|
"""
|
||||||
清除对话历史
|
清除对话历史
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
keep_system: 是否保留系统消息
|
keep_system: 是否保留系统消息
|
||||||
"""
|
"""
|
||||||
@@ -302,11 +276,11 @@ class ConversationManager:
|
|||||||
self.messages = [self.messages[0]]
|
self.messages = [self.messages[0]]
|
||||||
else:
|
else:
|
||||||
self.messages = []
|
self.messages = []
|
||||||
|
|
||||||
def get_message_count(self) -> int:
|
def get_message_count(self) -> int:
|
||||||
"""获取消息数量"""
|
"""获取消息数量"""
|
||||||
return len(self.messages)
|
return len(self.messages)
|
||||||
|
|
||||||
def get_conversation_history(self) -> List[Dict[str, str]]:
|
def get_conversation_history(self) -> List[Dict[str, str]]:
|
||||||
"""获取对话历史"""
|
"""获取对话历史"""
|
||||||
return [msg.to_dict() for msg in self.messages]
|
return [msg.to_dict() for msg in self.messages]
|
||||||
|
|||||||
Reference in New Issue
Block a user