feat:一对多的新模式
This commit is contained in:
58
src/audio/mock_audio.py
Normal file
58
src/audio/mock_audio.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("MockAudio")
|
||||
|
||||
class MockAudioPlayer:
|
||||
"""
|
||||
一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。
|
||||
"""
|
||||
def __init__(self, audio_data: bytes):
|
||||
self._audio_data = audio_data
|
||||
# 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频
|
||||
self._duration = (len(audio_data) / 1024.0) * 0.5
|
||||
|
||||
async def play(self):
|
||||
"""模拟播放音频。该过程可以被中断。"""
|
||||
if self._duration <= 0:
|
||||
return
|
||||
logger.info(f"开始播放模拟音频,预计时长: {self._duration:.2f} 秒...")
|
||||
try:
|
||||
await asyncio.sleep(self._duration)
|
||||
logger.info("模拟音频播放完毕。")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("音频播放被中断。")
|
||||
raise # 重新抛出异常,以便上层逻辑可以捕获它
|
||||
|
||||
class MockAudioGenerator:
|
||||
"""
|
||||
一个模拟的文本到语音(TTS)生成器。
|
||||
"""
|
||||
def __init__(self):
|
||||
# 模拟生成速度:每秒生成的字符数
|
||||
self.chars_per_second = 25.0
|
||||
|
||||
async def generate(self, text: str) -> bytes:
|
||||
"""
|
||||
模拟从文本生成音频数据。该过程可以被中断。
|
||||
|
||||
Args:
|
||||
text: 需要转换为音频的文本。
|
||||
|
||||
Returns:
|
||||
模拟的音频数据(bytes)。
|
||||
"""
|
||||
if not text:
|
||||
return b''
|
||||
|
||||
generation_time = len(text) / self.chars_per_second
|
||||
logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...")
|
||||
try:
|
||||
await asyncio.sleep(generation_time)
|
||||
# 生成虚拟的音频数据,其长度与文本长度成正比
|
||||
mock_audio_data = b'\x01\x02\x03' * (len(text) * 40)
|
||||
logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。")
|
||||
return mock_audio_data
|
||||
except asyncio.CancelledError:
|
||||
logger.info("音频生成被中断。")
|
||||
raise # 重新抛出异常
|
||||
@@ -13,8 +13,11 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
|
||||
from src.plugin_system.base.base_command import BaseCommand
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
# 定义日志配置
|
||||
|
||||
ENABLE_S4U_CHAT = True
|
||||
# 仅内部开启
|
||||
|
||||
# 配置主程序日志格式
|
||||
logger = get_logger("chat")
|
||||
@@ -30,6 +33,7 @@ class ChatBot:
|
||||
# 创建初始化PFC管理器的任务,会在_ensure_started时执行
|
||||
self.only_process_chat = MessageProcessor()
|
||||
self.pfc_manager = PFCManager.get_instance()
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
async def _ensure_started(self):
|
||||
"""确保所有任务已启动"""
|
||||
@@ -176,6 +180,14 @@ class ChatBot:
|
||||
# 如果在私聊中
|
||||
if group_info is None:
|
||||
logger.debug("检测到私聊消息")
|
||||
|
||||
if ENABLE_S4U_CHAT:
|
||||
logger.debug("进入S4U私聊处理流程")
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
return
|
||||
|
||||
|
||||
|
||||
if global_config.experimental.pfc_chatting:
|
||||
logger.debug("进入PFC私聊处理流程")
|
||||
# 创建聊天流
|
||||
@@ -188,6 +200,13 @@ class ChatBot:
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
# 群聊默认进入心流消息处理逻辑
|
||||
else:
|
||||
|
||||
if ENABLE_S4U_CHAT:
|
||||
logger.debug("进入S4U私聊处理流程")
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
return
|
||||
|
||||
|
||||
logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}")
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
|
||||
@@ -174,6 +174,7 @@ def _build_readable_messages_internal(
|
||||
truncate: bool = False,
|
||||
pic_id_mapping: Dict[str, str] = None,
|
||||
pic_counter: int = 1,
|
||||
show_pic: bool = True,
|
||||
) -> Tuple[str, List[Tuple[float, str, str]], Dict[str, str], int]:
|
||||
"""
|
||||
内部辅助函数,构建可读消息字符串和原始消息详情列表。
|
||||
@@ -260,8 +261,10 @@ def _build_readable_messages_internal(
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
if show_pic:
|
||||
content = process_pic_ids(content)
|
||||
|
||||
|
||||
# 检查必要信息是否存在
|
||||
if not all([platform, user_id, timestamp is not None]):
|
||||
continue
|
||||
@@ -532,6 +535,7 @@ def build_readable_messages(
|
||||
read_mark: float = 0.0,
|
||||
truncate: bool = False,
|
||||
show_actions: bool = False,
|
||||
show_pic: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
将消息列表转换为可读的文本格式。
|
||||
@@ -601,7 +605,7 @@ def build_readable_messages(
|
||||
if read_mark <= 0:
|
||||
# 没有有效的 read_mark,直接格式化所有消息
|
||||
formatted_string, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate
|
||||
copy_messages, replace_bot_name, merge_messages, timestamp_mode, truncate, show_pic=show_pic
|
||||
)
|
||||
|
||||
# 生成图片映射信息并添加到最前面
|
||||
@@ -628,9 +632,10 @@ def build_readable_messages(
|
||||
truncate,
|
||||
pic_id_mapping,
|
||||
pic_counter,
|
||||
show_pic=show_pic
|
||||
)
|
||||
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||
messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter
|
||||
messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter, show_pic=show_pic
|
||||
)
|
||||
|
||||
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"
|
||||
|
||||
302
src/mais4u/mais4u_chat/s4u_chat.py
Normal file
302
src/mais4u/mais4u_chat/s4u_chat.py
Normal file
@@ -0,0 +1,302 @@
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import random
|
||||
from typing import List, Optional, Dict # 导入类型提示
|
||||
import os
|
||||
import pickle
|
||||
from maim_message import UserInfo, Seg
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
from src.manager.mood_manager import mood_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from .s4u_stream_generator import S4UStreamGenerator
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from src.chat.message_receive.message_sender import message_manager
|
||||
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
|
||||
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||
from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner
|
||||
from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionModifier
|
||||
from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor
|
||||
from src.chat.focus_chat.replyer.default_generator import DefaultReplyer
|
||||
from src.person_info.person_info import PersonInfoManager
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
num_new_messages_since,
|
||||
)
|
||||
from src.common.message.api import get_global_api
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.audio.mock_audio import MockAudioGenerator, MockAudioPlayer
|
||||
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
|
||||
|
||||
class MessageSenderContainer:
|
||||
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
|
||||
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
|
||||
self.chat_stream = chat_stream
|
||||
self.original_message = original_message
|
||||
self.queue = asyncio.Queue()
|
||||
self.storage = MessageStorage()
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
self._paused_event = asyncio.Event()
|
||||
self._paused_event.set() # 默认设置为非暂停状态
|
||||
|
||||
async def add_message(self, chunk: str):
|
||||
"""向队列中添加一个消息块。"""
|
||||
await self.queue.put(chunk)
|
||||
|
||||
async def close(self):
|
||||
"""表示没有更多消息了,关闭队列。"""
|
||||
await self.queue.put(None) # Sentinel
|
||||
|
||||
def pause(self):
|
||||
"""暂停发送。"""
|
||||
self._paused_event.clear()
|
||||
|
||||
def resume(self):
|
||||
"""恢复发送。"""
|
||||
self._paused_event.set()
|
||||
|
||||
def _calculate_typing_delay(self, text: str) -> float:
|
||||
"""根据文本长度计算模拟打字延迟。"""
|
||||
chars_per_second = 15.0
|
||||
min_delay = 0.2
|
||||
max_delay = 2.0
|
||||
|
||||
delay = len(text) / chars_per_second
|
||||
return max(min_delay, min(delay, max_delay))
|
||||
|
||||
async def _send_worker(self):
|
||||
"""从队列中取出消息并发送。"""
|
||||
while True:
|
||||
try:
|
||||
# This structure ensures that task_done() is called for every item retrieved,
|
||||
# even if the worker is cancelled while processing the item.
|
||||
chunk = await self.queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if chunk is None:
|
||||
break
|
||||
|
||||
# Check for pause signal *after* getting an item.
|
||||
await self._paused_event.wait()
|
||||
|
||||
delay = self._calculate_typing_delay(chunk)
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
current_time = time.time()
|
||||
msg_id = f"{current_time}_{random.randint(1000, 9999)}"
|
||||
|
||||
text_to_send = chunk
|
||||
if global_config.experimental.debug_show_chat_mode:
|
||||
text_to_send += "ⁿ"
|
||||
|
||||
message_segment = Seg(type="text", data=text_to_send)
|
||||
bot_message = MessageSending(
|
||||
message_id=msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
)
|
||||
|
||||
await bot_message.process()
|
||||
|
||||
await get_global_api().send_message(bot_message)
|
||||
logger.info(f"已将消息 '{text_to_send}' 发往平台 '{bot_message.message_info.platform}'")
|
||||
|
||||
await self.storage.store_message(bot_message, self.chat_stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||
self.queue.task_done()
|
||||
|
||||
def start(self):
|
||||
"""启动发送任务。"""
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._send_worker())
|
||||
|
||||
async def join(self):
|
||||
"""等待所有消息发送完毕。"""
|
||||
if self._task:
|
||||
await self._task
|
||||
|
||||
|
||||
class S4UChatManager:
|
||||
def __init__(self):
|
||||
self.s4u_chats: Dict[str, "S4UChat"] = {}
|
||||
|
||||
def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
|
||||
if chat_stream.stream_id not in self.s4u_chats:
|
||||
stream_name = get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
|
||||
logger.info(f"Creating new S4UChat for stream: {stream_name}")
|
||||
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
|
||||
return self.s4u_chats[chat_stream.stream_id]
|
||||
|
||||
s4u_chat_manager = S4UChatManager()
|
||||
|
||||
def get_s4u_chat_manager() -> S4UChatManager:
|
||||
return s4u_chat_manager
|
||||
|
||||
|
||||
class S4UChat:
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化 S4UChat 实例。"""
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_id = chat_stream.stream_id
|
||||
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||
|
||||
self._message_queue = asyncio.Queue()
|
||||
self._processing_task = asyncio.create_task(self._message_processor())
|
||||
self._current_generation_task: Optional[asyncio.Task] = None
|
||||
|
||||
self._is_replying = False
|
||||
|
||||
# 初始化Normal Chat专用表达器
|
||||
self.expressor = NormalChatExpressor(self.chat_stream)
|
||||
self.replyer = DefaultReplyer(self.chat_stream)
|
||||
|
||||
self.gpt = S4UStreamGenerator()
|
||||
self.audio_generator = MockAudioGenerator()
|
||||
self.start_time = time.time()
|
||||
|
||||
# 记录最近的回复内容,每项包含: {time, user_message, response, is_mentioned, is_reference_reply}
|
||||
self.recent_replies = []
|
||||
self.max_replies_history = 20 # 最多保存最近20条回复记录
|
||||
|
||||
self.storage = MessageStorage()
|
||||
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat")
|
||||
|
||||
|
||||
# 改为实例方法, 移除 chat 参数
|
||||
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
|
||||
"""将消息放入队列并中断当前处理(如果正在处理)。"""
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
self._current_generation_task.cancel()
|
||||
logger.info(f"[{self.stream_name}] 请求中断当前回复生成任务。")
|
||||
|
||||
await self._message_queue.put(message)
|
||||
|
||||
async def _message_processor(self):
|
||||
"""从队列中处理消息,支持中断。"""
|
||||
while True:
|
||||
try:
|
||||
# 等待第一条消息
|
||||
message = await self._message_queue.get()
|
||||
|
||||
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
|
||||
while not self._message_queue.empty():
|
||||
drained_msg = self._message_queue.get_nowait()
|
||||
self._message_queue.task_done() # 为取出的旧消息调用 task_done
|
||||
message = drained_msg # 始终处理最新消息
|
||||
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
|
||||
|
||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||
|
||||
try:
|
||||
await self._current_generation_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 回复生成被外部中断。")
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] _generate_and_send 任务出现错误: {e}", exc_info=True)
|
||||
finally:
|
||||
self._current_generation_task = None
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True)
|
||||
await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转
|
||||
finally:
|
||||
# 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成
|
||||
if 'message' in locals():
|
||||
self._message_queue.task_done()
|
||||
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本和音频回复。整个过程可以被中断。"""
|
||||
self._is_replying = True
|
||||
sender_container = MessageSenderContainer(self.chat_stream, message)
|
||||
sender_container.start()
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"[S4U] 开始为消息生成文本和音频流: "
|
||||
f"'{message.processed_plain_text[:30]}...'"
|
||||
)
|
||||
|
||||
# 1. 逐句生成文本、发送并播放音频
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
# 如果任务被取消,await 会在此处引发 CancelledError
|
||||
|
||||
# a. 发送文本块
|
||||
await sender_container.add_message(chunk)
|
||||
|
||||
# b. 为该文本块生成并播放音频
|
||||
if chunk.strip():
|
||||
audio_data = await self.audio_generator.generate(chunk)
|
||||
player = MockAudioPlayer(audio_data)
|
||||
await player.play()
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
logger.info(f"[{self.stream_name}] 所有文本和音频块处理完毕。")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 回复流程(文本或音频)被中断。")
|
||||
raise # 将取消异常向上传播
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
|
||||
finally:
|
||||
self._is_replying = False
|
||||
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
|
||||
sender_container.resume()
|
||||
if not sender_container._task.done():
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
|
||||
|
||||
|
||||
async def shutdown(self):
|
||||
"""平滑关闭处理任务。"""
|
||||
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
|
||||
|
||||
# 取消正在运行的任务
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
if self._processing_task and not self._processing_task.done():
|
||||
self._processing_task.cancel()
|
||||
|
||||
# 等待任务响应取消
|
||||
try:
|
||||
await self._processing_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||
70
src/mais4u/mais4u_chat/s4u_msg_processor.py
Normal file
70
src/mais4u/mais4u_chat/s4u_msg_processor.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
from .s4u_chat import get_s4u_chat_manager
|
||||
|
||||
import math
|
||||
import re
|
||||
import traceback
|
||||
from typing import Optional, Tuple
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
|
||||
# from ..message_receive.message_buffer import message_buffer
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
class S4UMessageProcessor:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecv) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
|
||||
target_user_id = "1026294844"
|
||||
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
messageinfo = message.message_info
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
is_mentioned = is_mentioned_bot_in_message(message)
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
|
||||
if userinfo.user_id == target_user_id:
|
||||
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0)
|
||||
|
||||
|
||||
# 7. 日志记录
|
||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
230
src/mais4u/mais4u_chat/s4u_prompt.py
Normal file
230
src/mais4u/mais4u_chat/s4u_prompt.py
Normal file
@@ -0,0 +1,230 @@
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
import time
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
import random
|
||||
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
|
||||
logger = get_logger("prompt")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("你正在qq群里聊天,下面是群里在聊的内容:", "chat_target_group1")
|
||||
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
|
||||
Prompt("在群里聊天", "chat_target_group2")
|
||||
Prompt("和{sender_name}私聊", "chat_target_private2")
|
||||
|
||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。
|
||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{now_time}
|
||||
这是你和{sender_name}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}
|
||||
|
||||
{message_txt}
|
||||
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。""",
|
||||
"s4u_prompt", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
def __init__(self):
|
||||
self.prompt_built = ""
|
||||
self.activate_messages = ""
|
||||
|
||||
async def build_prompt_normal(
|
||||
self,
|
||||
message,
|
||||
chat_stream,
|
||||
message_txt: str,
|
||||
sender_name: str = "某人",
|
||||
) -> str:
|
||||
prompt_personality = get_individuality().get_prompt(x_person=2, level=2)
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
who_chat_in_group = []
|
||||
if is_group_chat:
|
||||
who_chat_in_group = get_recent_group_speaker(
|
||||
chat_stream.stream_id,
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
|
||||
limit=global_config.normal_chat.max_context_size,
|
||||
)
|
||||
elif chat_stream.user_info:
|
||||
who_chat_in_group.append(
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||
)
|
||||
|
||||
relation_prompt = ""
|
||||
if global_config.relationship.enable_relationship:
|
||||
for person in who_chat_in_group:
|
||||
relationship_manager = get_relationship_manager()
|
||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
||||
|
||||
|
||||
memory_prompt = ""
|
||||
related_memory = await hippocampus_manager.get_memory_from_text(
|
||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||
)
|
||||
|
||||
related_memory_info = ""
|
||||
if related_memory:
|
||||
for memory in related_memory:
|
||||
related_memory_info += memory[1]
|
||||
memory_prompt = await global_prompt_manager.format_prompt(
|
||||
"memory_prompt", related_memory_info=related_memory_info
|
||||
)
|
||||
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=100,
|
||||
)
|
||||
|
||||
|
||||
# 分别筛选核心对话和背景对话
|
||||
core_dialogue_list = []
|
||||
background_dialogue_list = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
# 直接通过字典访问
|
||||
msg_user_id = str(msg_dict.get('user_id'))
|
||||
|
||||
if msg_user_id == bot_id or msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
else:
|
||||
background_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||
|
||||
if background_dialogue_list:
|
||||
latest_25_msgs = background_dialogue_list[-25:]
|
||||
background_dialogue_prompt = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
merge_messages=True,
|
||||
timestamp_mode = "normal_no_YMD",
|
||||
show_pic = False,
|
||||
)
|
||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}"
|
||||
else:
|
||||
background_dialogue_prompt = ""
|
||||
|
||||
# 分别获取最新50条和最新25条(从message_list_before_now截取)
|
||||
core_dialogue_list = core_dialogue_list[-50:]
|
||||
|
||||
first_msg = core_dialogue_list[0]
|
||||
start_speaking_user_id = first_msg.get('user_id')
|
||||
if start_speaking_user_id == bot_id:
|
||||
last_speaking_user_id = bot_id
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
start_speaking_user_id = target_user_id
|
||||
last_speaking_user_id = start_speaking_user_id
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{first_msg.get('processed_plain_text')}\n"
|
||||
|
||||
all_msg_seg_list = []
|
||||
for msg in core_dialogue_list[1:]:
|
||||
speaker = msg.get('user_id')
|
||||
if speaker == last_speaking_user_id:
|
||||
#还是同一个人讲话
|
||||
msg_seg_str += f"{msg.get('processed_plain_text')}\n"
|
||||
else:
|
||||
#换人了
|
||||
msg_seg_str = f"{msg_seg_str}\n"
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
|
||||
if speaker == bot_id:
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{msg.get('processed_plain_text')}\n"
|
||||
last_speaking_user_id = speaker
|
||||
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
|
||||
|
||||
core_msg_str = ""
|
||||
for msg in all_msg_seg_list:
|
||||
# print(f"msg: {msg}")
|
||||
core_msg_str += msg
|
||||
|
||||
now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
now_time = f"现在的时间是:{now_time}"
|
||||
|
||||
template_name = "s4u_prompt"
|
||||
effective_sender_name = sender_name
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
relation_prompt=relation_prompt,
|
||||
sender_name=effective_sender_name,
|
||||
memory_prompt=memory_prompt,
|
||||
core_dialogue_prompt=core_msg_str,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
message_txt=message_txt,
|
||||
bot_name=global_config.bot.nickname,
|
||||
bot_other_names="/".join(global_config.bot.alias_names),
|
||||
prompt_personality=prompt_personality,
|
||||
now_time=now_time,
|
||||
)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
加权且不放回地随机抽取k个元素。
|
||||
|
||||
参数:
|
||||
items: 待抽取的元素列表
|
||||
weights: 每个元素对应的权重(与items等长,且为正数)
|
||||
k: 需要抽取的元素个数
|
||||
返回:
|
||||
selected: 按权重加权且不重复抽取的k个元素组成的列表
|
||||
|
||||
如果 items 中的元素不足 k 个,就只会返回所有可用的元素
|
||||
|
||||
实现思路:
|
||||
每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
|
||||
这样保证了:
|
||||
1. count越大被选中概率越高
|
||||
2. 不会重复选中同一个元素
|
||||
"""
|
||||
selected = []
|
||||
pool = list(zip(items, weights))
|
||||
for _ in range(min(k, len(pool))):
|
||||
total = sum(w for _, w in pool)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for idx, (item, weight) in enumerate(pool):
|
||||
upto += weight
|
||||
if upto >= r:
|
||||
selected.append(item)
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
init_prompt()
|
||||
prompt_builder = PromptBuilder()
|
||||
140
src/mais4u/mais4u_chat/s4u_stream_generator.py
Normal file
140
src/mais4u/mais4u_chat/s4u_stream_generator.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.openai_client import AsyncOpenAIClient
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||
from src.common.logger import get_logger
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
|
||||
logger = get_logger("s4u_stream_generator")
|
||||
|
||||
|
||||
class S4UStreamGenerator:
|
||||
def __init__(self):
|
||||
replyer_1_config = global_config.model.replyer_1
|
||||
provider = replyer_1_config.get("provider")
|
||||
if not provider:
|
||||
logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段")
|
||||
raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段")
|
||||
|
||||
api_key = os.environ.get(f"{provider.upper()}_KEY")
|
||||
base_url = os.environ.get(f"{provider.upper()}_BASE_URL")
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"环境变量 {provider.upper()}_KEY 未设置")
|
||||
raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置")
|
||||
|
||||
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
|
||||
self.model_1_name = replyer_1_config.get("name")
|
||||
if not self.model_1_name:
|
||||
logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段")
|
||||
raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段")
|
||||
self.replyer_1_config = replyer_1_config
|
||||
|
||||
self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
|
||||
self.current_model_name = "unknown model"
|
||||
|
||||
# 正则表达式用于按句子切分,同时处理各种标点和边缘情况
|
||||
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
||||
self.sentence_split_pattern = re.compile(
|
||||
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
|
||||
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))' # 匹配直到句子结束符
|
||||
, re.UNICODE | re.DOTALL
|
||||
)
|
||||
|
||||
async def generate_response(
|
||||
self, message: MessageRecv, previous_reply_context: str = ""
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
# 从global_config中获取模型概率值并选择模型
|
||||
current_client = self.client_1
|
||||
self.current_model_name = self.model_1_name
|
||||
|
||||
person_id = PersonInfoManager.get_person_id(
|
||||
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
)
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
if message.chat_stream.user_info.user_nickname:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
else:
|
||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
# 构建prompt
|
||||
if previous_reply_context:
|
||||
message_txt = f"""
|
||||
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
|
||||
[你已经对上一条消息说的话]: {previous_reply_context}
|
||||
---
|
||||
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
|
||||
{message.processed_plain_text}
|
||||
"""
|
||||
else:
|
||||
message_txt = message.processed_plain_text
|
||||
|
||||
|
||||
prompt = await prompt_builder.build_prompt_normal(
|
||||
message = message,
|
||||
message_txt=message_txt,
|
||||
sender_name=sender_name,
|
||||
chat_stream=message.chat_stream,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
|
||||
) # noqa: E501
|
||||
|
||||
extra_kwargs = {}
|
||||
if self.replyer_1_config.get("enable_thinking") is not None:
|
||||
extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking")
|
||||
if self.replyer_1_config.get("thinking_budget") is not None:
|
||||
extra_kwargs["thinking_budget"] = self.replyer_1_config.get("thinking_budget")
|
||||
|
||||
async for chunk in self._generate_response_with_model(
|
||||
prompt, current_client, self.current_model_name, **extra_kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _generate_response_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
client: AsyncOpenAIClient,
|
||||
model_name: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
print(prompt)
|
||||
|
||||
buffer = ""
|
||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||
|
||||
async for content in client.get_stream_content(
|
||||
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
||||
):
|
||||
buffer += content
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence:
|
||||
# 如果句子看起来完整(即不只是等待更多内容),则发送
|
||||
if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
|
||||
yield sentence
|
||||
await asyncio.sleep(0) # 允许其他任务运行
|
||||
last_match_end = match.end(0)
|
||||
|
||||
# 从缓冲区移除已发送的部分
|
||||
if last_match_end > 0:
|
||||
buffer = buffer[last_match_end:]
|
||||
|
||||
# 发送缓冲区中剩余的任何内容
|
||||
if buffer.strip():
|
||||
yield buffer.strip()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
312
src/mais4u/openai_client.py
Normal file
312
src/mais4u/openai_client.py
Normal file
@@ -0,0 +1,312 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
|
||||
from dataclasses import dataclass
|
||||
import aiohttp
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息数据类"""
|
||||
role: str
|
||||
content: str
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
class AsyncOpenAIClient:
|
||||
"""异步OpenAI客户端,支持流式传输"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API密钥
|
||||
base_url: 可选的API基础URL,用于自定义端点
|
||||
"""
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=10.0, # 设置60秒的全局超时
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> ChatCompletion:
|
||||
"""
|
||||
非流式聊天完成
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
完整的聊天回复
|
||||
"""
|
||||
# 转换消息格式
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
formatted_messages.append(msg.to_dict())
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
extra_body = {}
|
||||
if kwargs.get("enable_thinking") is not None:
|
||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||
if kwargs.get("thinking_budget") is not None:
|
||||
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
|
||||
|
||||
response = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=formatted_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False,
|
||||
extra_body=extra_body if extra_body else None,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||
"""
|
||||
流式聊天完成
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
ChatCompletionChunk: 流式响应块
|
||||
"""
|
||||
# 转换消息格式
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
formatted_messages.append(msg.to_dict())
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
extra_body = {}
|
||||
if kwargs.get("enable_thinking") is not None:
|
||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||
if kwargs.get("thinking_budget") is not None:
|
||||
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=formatted_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
extra_body=extra_body if extra_body else None,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
async def get_stream_content(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
获取流式内容(只返回文本内容)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
str: 文本内容片段
|
||||
"""
|
||||
async for chunk in self.chat_completion_stream(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
):
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
async def collect_stream_response(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
收集完整的流式响应
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
str: 完整的响应文本
|
||||
"""
|
||||
full_response = ""
|
||||
async for content in self.get_stream_content(
|
||||
messages=messages,
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
**kwargs
|
||||
):
|
||||
full_response += content
|
||||
|
||||
return full_response
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
await self.client.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器退出"""
|
||||
await self.close()
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""对话管理器,用于管理对话历史"""
|
||||
|
||||
def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None):
|
||||
"""
|
||||
初始化对话管理器
|
||||
|
||||
Args:
|
||||
client: OpenAI客户端实例
|
||||
system_prompt: 系统提示词
|
||||
"""
|
||||
self.client = client
|
||||
self.messages: List[ChatMessage] = []
|
||||
|
||||
if system_prompt:
|
||||
self.messages.append(ChatMessage(role="system", content=system_prompt))
|
||||
|
||||
def add_user_message(self, content: str):
|
||||
"""添加用户消息"""
|
||||
self.messages.append(ChatMessage(role="user", content=content))
|
||||
|
||||
def add_assistant_message(self, content: str):
|
||||
"""添加助手消息"""
|
||||
self.messages.append(ChatMessage(role="assistant", content=content))
|
||||
|
||||
async def send_message_stream(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
**kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
发送消息并获取流式响应
|
||||
|
||||
Args:
|
||||
content: 用户消息内容
|
||||
model: 模型名称
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
str: 响应内容片段
|
||||
"""
|
||||
self.add_user_message(content)
|
||||
|
||||
response_content = ""
|
||||
async for chunk in self.client.get_stream_content(
|
||||
messages=self.messages,
|
||||
model=model,
|
||||
**kwargs
|
||||
):
|
||||
response_content += chunk
|
||||
yield chunk
|
||||
|
||||
self.add_assistant_message(response_content)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
content: str,
|
||||
model: str = "gpt-3.5-turbo",
|
||||
**kwargs
|
||||
) -> str:
|
||||
"""
|
||||
发送消息并获取完整响应
|
||||
|
||||
Args:
|
||||
content: 用户消息内容
|
||||
model: 模型名称
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
str: 完整响应
|
||||
"""
|
||||
self.add_user_message(content)
|
||||
|
||||
response = await self.client.chat_completion(
|
||||
messages=self.messages,
|
||||
model=model,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
response_content = response.choices[0].message.content
|
||||
self.add_assistant_message(response_content)
|
||||
|
||||
return response_content
|
||||
|
||||
def clear_history(self, keep_system: bool = True):
|
||||
"""
|
||||
清除对话历史
|
||||
|
||||
Args:
|
||||
keep_system: 是否保留系统消息
|
||||
"""
|
||||
if keep_system and self.messages and self.messages[0].role == "system":
|
||||
self.messages = [self.messages[0]]
|
||||
else:
|
||||
self.messages = []
|
||||
|
||||
def get_message_count(self) -> int:
|
||||
"""获取消息数量"""
|
||||
return len(self.messages)
|
||||
|
||||
def get_conversation_history(self) -> List[Dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return [msg.to_dict() for msg in self.messages]
|
||||
Reference in New Issue
Block a user