feat:增加了reply_to新message属性,优化prompt,切割
This commit is contained in:
@@ -283,6 +283,7 @@ class MessageSending(MessageProcessBase):
|
|||||||
is_emoji: bool = False,
|
is_emoji: bool = False,
|
||||||
thinking_start_time: float = 0,
|
thinking_start_time: float = 0,
|
||||||
apply_set_reply_logic: bool = False,
|
apply_set_reply_logic: bool = False,
|
||||||
|
reply_to: str = None,
|
||||||
):
|
):
|
||||||
# 调用父类初始化
|
# 调用父类初始化
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -300,6 +301,8 @@ class MessageSending(MessageProcessBase):
|
|||||||
self.is_head = is_head
|
self.is_head = is_head
|
||||||
self.is_emoji = is_emoji
|
self.is_emoji = is_emoji
|
||||||
self.apply_set_reply_logic = apply_set_reply_logic
|
self.apply_set_reply_logic = apply_set_reply_logic
|
||||||
|
|
||||||
|
self.reply_to = reply_to
|
||||||
|
|
||||||
# 用于显示发送内容与显示不一致的情况
|
# 用于显示发送内容与显示不一致的情况
|
||||||
self.display_message = display_message
|
self.display_message = display_message
|
||||||
|
|||||||
@@ -35,8 +35,12 @@ class MessageStorage:
|
|||||||
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
|
||||||
else:
|
else:
|
||||||
filtered_display_message = ""
|
filtered_display_message = ""
|
||||||
|
|
||||||
|
reply_to = message.reply_to
|
||||||
else:
|
else:
|
||||||
filtered_display_message = ""
|
filtered_display_message = ""
|
||||||
|
|
||||||
|
reply_to = ""
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
chat_info_dict = chat_stream.to_dict()
|
||||||
user_info_dict = message.message_info.user_info.to_dict()
|
user_info_dict = message.message_info.user_info.to_dict()
|
||||||
@@ -54,6 +58,7 @@ class MessageStorage:
|
|||||||
time=float(message.message_info.time),
|
time=float(message.message_info.time),
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
# Flattened chat_info
|
# Flattened chat_info
|
||||||
|
reply_to=reply_to,
|
||||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
||||||
chat_info_platform=chat_info_dict.get("platform"),
|
chat_info_platform=chat_info_dict.get("platform"),
|
||||||
chat_info_user_platform=user_info_from_chat.get("platform"),
|
chat_info_user_platform=user_info_from_chat.get("platform"),
|
||||||
|
|||||||
@@ -126,6 +126,8 @@ class Messages(BaseModel):
|
|||||||
time = DoubleField() # 消息时间戳
|
time = DoubleField() # 消息时间戳
|
||||||
|
|
||||||
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
|
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
|
||||||
|
|
||||||
|
reply_to = TextField(null=True)
|
||||||
|
|
||||||
# 从 chat_info 扁平化而来的字段
|
# 从 chat_info 扁平化而来的字段
|
||||||
chat_info_stream_id = TextField()
|
chat_info_stream_id = TextField()
|
||||||
|
|||||||
@@ -92,7 +92,8 @@ class MessageSenderContainer:
|
|||||||
# Check for pause signal *after* getting an item.
|
# Check for pause signal *after* getting an item.
|
||||||
await self._paused_event.wait()
|
await self._paused_event.wait()
|
||||||
|
|
||||||
delay = self._calculate_typing_delay(chunk)
|
# delay = self._calculate_typing_delay(chunk)
|
||||||
|
delay = 0.1
|
||||||
await asyncio.sleep(delay)
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
@@ -116,6 +117,7 @@ class MessageSenderContainer:
|
|||||||
reply=self.original_message,
|
reply=self.original_message,
|
||||||
is_emoji=False,
|
is_emoji=False,
|
||||||
apply_set_reply_logic=True,
|
apply_set_reply_logic=True,
|
||||||
|
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
await bot_message.process()
|
await bot_message.process()
|
||||||
@@ -171,22 +173,13 @@ class S4UChat:
|
|||||||
self._message_queue = asyncio.Queue()
|
self._message_queue = asyncio.Queue()
|
||||||
self._processing_task = asyncio.create_task(self._message_processor())
|
self._processing_task = asyncio.create_task(self._message_processor())
|
||||||
self._current_generation_task: Optional[asyncio.Task] = None
|
self._current_generation_task: Optional[asyncio.Task] = None
|
||||||
|
self._current_message_being_replied: Optional[MessageRecv] = None
|
||||||
|
|
||||||
self._is_replying = False
|
self._is_replying = False
|
||||||
|
|
||||||
# 初始化Normal Chat专用表达器
|
|
||||||
self.expressor = NormalChatExpressor(self.chat_stream)
|
|
||||||
self.replyer = DefaultReplyer(self.chat_stream)
|
|
||||||
|
|
||||||
self.gpt = S4UStreamGenerator()
|
self.gpt = S4UStreamGenerator()
|
||||||
self.audio_generator = MockAudioGenerator()
|
# self.audio_generator = MockAudioGenerator()
|
||||||
self.start_time = time.time()
|
|
||||||
|
|
||||||
# 记录最近的回复内容,每项包含: {time, user_message, response, is_mentioned, is_reference_reply}
|
|
||||||
self.recent_replies = []
|
|
||||||
self.max_replies_history = 20 # 最多保存最近20条回复记录
|
|
||||||
|
|
||||||
self.storage = MessageStorage()
|
|
||||||
|
|
||||||
|
|
||||||
logger.info(f"[{self.stream_name}] S4UChat")
|
logger.info(f"[{self.stream_name}] S4UChat")
|
||||||
@@ -194,11 +187,32 @@ class S4UChat:
|
|||||||
|
|
||||||
# 改为实例方法, 移除 chat 参数
|
# 改为实例方法, 移除 chat 参数
|
||||||
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
|
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
|
||||||
"""将消息放入队列并中断当前处理(如果正在处理)。"""
|
"""将消息放入队列并根据发信人决定是否中断当前处理。"""
|
||||||
|
should_interrupt = False
|
||||||
if self._current_generation_task and not self._current_generation_task.done():
|
if self._current_generation_task and not self._current_generation_task.done():
|
||||||
|
if self._current_message_being_replied:
|
||||||
|
# 检查新消息发送者和正在回复的消息发送者是否为同一人
|
||||||
|
new_sender_id = message.message_info.user_info.user_id
|
||||||
|
original_sender_id = self._current_message_being_replied.message_info.user_info.user_id
|
||||||
|
|
||||||
|
if new_sender_id == original_sender_id:
|
||||||
|
should_interrupt = True
|
||||||
|
logger.info(f"[{self.stream_name}] 来自同一用户的消息,中断当前回复。")
|
||||||
|
else:
|
||||||
|
if random.random() < 0.2:
|
||||||
|
should_interrupt = True
|
||||||
|
logger.info(f"[{self.stream_name}] 来自不同用户的消息,随机中断(20%)。")
|
||||||
|
else:
|
||||||
|
logger.info(f"[{self.stream_name}] 来自不同用户的消息,不中断。")
|
||||||
|
else:
|
||||||
|
# Fallback: if we don't know who we are replying to, interrupt.
|
||||||
|
should_interrupt = True
|
||||||
|
logger.warning(f"[{self.stream_name}] 正在生成回复,但无法获取原始消息发送者信息,将默认中断。")
|
||||||
|
|
||||||
|
if should_interrupt:
|
||||||
self._current_generation_task.cancel()
|
self._current_generation_task.cancel()
|
||||||
logger.info(f"[{self.stream_name}] 请求中断当前回复生成任务。")
|
logger.info(f"[{self.stream_name}] 请求中断当前回复生成任务。")
|
||||||
|
|
||||||
await self._message_queue.put(message)
|
await self._message_queue.put(message)
|
||||||
|
|
||||||
async def _message_processor(self):
|
async def _message_processor(self):
|
||||||
@@ -207,12 +221,14 @@ class S4UChat:
|
|||||||
try:
|
try:
|
||||||
# 等待第一条消息
|
# 等待第一条消息
|
||||||
message = await self._message_queue.get()
|
message = await self._message_queue.get()
|
||||||
|
self._current_message_being_replied = message
|
||||||
|
|
||||||
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
|
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
|
||||||
while not self._message_queue.empty():
|
while not self._message_queue.empty():
|
||||||
drained_msg = self._message_queue.get_nowait()
|
drained_msg = self._message_queue.get_nowait()
|
||||||
self._message_queue.task_done() # 为取出的旧消息调用 task_done
|
self._message_queue.task_done() # 为取出的旧消息调用 task_done
|
||||||
message = drained_msg # 始终处理最新消息
|
message = drained_msg # 始终处理最新消息
|
||||||
|
self._current_message_being_replied = message
|
||||||
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
|
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
|
||||||
|
|
||||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||||
@@ -225,6 +241,7 @@ class S4UChat:
|
|||||||
logger.error(f"[{self.stream_name}] _generate_and_send 任务出现错误: {e}", exc_info=True)
|
logger.error(f"[{self.stream_name}] _generate_and_send 任务出现错误: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
self._current_generation_task = None
|
self._current_generation_task = None
|
||||||
|
self._current_message_being_replied = None
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
|
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
|
||||||
@@ -259,10 +276,10 @@ class S4UChat:
|
|||||||
await sender_container.add_message(chunk)
|
await sender_container.add_message(chunk)
|
||||||
|
|
||||||
# b. 为该文本块生成并播放音频
|
# b. 为该文本块生成并播放音频
|
||||||
if chunk.strip():
|
# if chunk.strip():
|
||||||
audio_data = await self.audio_generator.generate(chunk)
|
# audio_data = await self.audio_generator.generate(chunk)
|
||||||
player = MockAudioPlayer(audio_data)
|
# player = MockAudioPlayer(audio_data)
|
||||||
await player.play()
|
# await player.play()
|
||||||
|
|
||||||
# 等待所有文本消息发送完成
|
# 等待所有文本消息发送完成
|
||||||
await sender_container.close()
|
await sender_container.close()
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class S4UMessageProcessor:
|
|||||||
message_data: 原始消息字符串
|
message_data: 原始消息字符串
|
||||||
"""
|
"""
|
||||||
|
|
||||||
target_user_id = "1026294844"
|
target_user_id_list = ["1026294844", "964959351"]
|
||||||
|
|
||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
groupinfo = message.message_info.group_info
|
groupinfo = message.message_info.group_info
|
||||||
@@ -61,9 +61,10 @@ class S4UMessageProcessor:
|
|||||||
is_mentioned = is_mentioned_bot_in_message(message)
|
is_mentioned = is_mentioned_bot_in_message(message)
|
||||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||||
|
|
||||||
if userinfo.user_id == target_user_id:
|
if userinfo.user_id in target_user_id_list:
|
||||||
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0)
|
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0)
|
||||||
|
else:
|
||||||
|
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=0.0)
|
||||||
|
|
||||||
# 7. 日志记录
|
# 7. 日志记录
|
||||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ def init_prompt():
|
|||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。
|
你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。
|
||||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,但是你主要还是关注你和{sender_name}的聊天内容。
|
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||||
|
|
||||||
{background_dialogue_prompt}
|
{background_dialogue_prompt}
|
||||||
--------------------------------
|
--------------------------------
|
||||||
@@ -35,10 +35,13 @@ def init_prompt():
|
|||||||
这是你和{sender_name}的对话,你们正在交流中:
|
这是你和{sender_name}的对话,你们正在交流中:
|
||||||
{core_dialogue_prompt}
|
{core_dialogue_prompt}
|
||||||
|
|
||||||
{message_txt}
|
对方最新发送的内容:{message_txt}
|
||||||
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
|
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
|
||||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。""",
|
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||||
|
你的发言:
|
||||||
|
|
||||||
|
""",
|
||||||
"s4u_prompt", # New template for private CHAT chat
|
"s4u_prompt", # New template for private CHAT chat
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -96,19 +99,29 @@ class PromptBuilder:
|
|||||||
limit=100,
|
limit=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
||||||
|
print(f"talk_type: {talk_type}")
|
||||||
|
|
||||||
|
|
||||||
# 分别筛选核心对话和背景对话
|
# 分别筛选核心对话和背景对话
|
||||||
core_dialogue_list = []
|
core_dialogue_list = []
|
||||||
background_dialogue_list = []
|
background_dialogue_list = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||||
|
|
||||||
|
|
||||||
for msg_dict in message_list_before_now:
|
for msg_dict in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
# 直接通过字典访问
|
# 直接通过字典访问
|
||||||
msg_user_id = str(msg_dict.get('user_id'))
|
msg_user_id = str(msg_dict.get('user_id'))
|
||||||
|
if msg_user_id == bot_id:
|
||||||
if msg_user_id == bot_id or msg_user_id == target_user_id:
|
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
||||||
|
print(f"reply: {msg_dict.get('reply_to')}")
|
||||||
|
core_dialogue_list.append(msg_dict)
|
||||||
|
else:
|
||||||
|
background_dialogue_list.append(msg_dict)
|
||||||
|
elif msg_user_id == target_user_id:
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg_dict)
|
||||||
else:
|
else:
|
||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg_dict)
|
||||||
@@ -140,14 +153,14 @@ class PromptBuilder:
|
|||||||
last_speaking_user_id = start_speaking_user_id
|
last_speaking_user_id = start_speaking_user_id
|
||||||
msg_seg_str = "对方的发言:\n"
|
msg_seg_str = "对方的发言:\n"
|
||||||
|
|
||||||
msg_seg_str += f"{first_msg.get('processed_plain_text')}\n"
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
|
||||||
|
|
||||||
all_msg_seg_list = []
|
all_msg_seg_list = []
|
||||||
for msg in core_dialogue_list[1:]:
|
for msg in core_dialogue_list[1:]:
|
||||||
speaker = msg.get('user_id')
|
speaker = msg.get('user_id')
|
||||||
if speaker == last_speaking_user_id:
|
if speaker == last_speaking_user_id:
|
||||||
#还是同一个人讲话
|
#还是同一个人讲话
|
||||||
msg_seg_str += f"{msg.get('processed_plain_text')}\n"
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||||
else:
|
else:
|
||||||
#换人了
|
#换人了
|
||||||
msg_seg_str = f"{msg_seg_str}\n"
|
msg_seg_str = f"{msg_seg_str}\n"
|
||||||
@@ -158,7 +171,7 @@ class PromptBuilder:
|
|||||||
else:
|
else:
|
||||||
msg_seg_str = "对方的发言:\n"
|
msg_seg_str = "对方的发言:\n"
|
||||||
|
|
||||||
msg_seg_str += f"{msg.get('processed_plain_text')}\n"
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||||
last_speaking_user_id = speaker
|
last_speaking_user_id = speaker
|
||||||
|
|
||||||
all_msg_seg_list.append(msg_seg_str)
|
all_msg_seg_list.append(msg_seg_str)
|
||||||
|
|||||||
@@ -112,6 +112,7 @@ class S4UStreamGenerator:
|
|||||||
|
|
||||||
buffer = ""
|
buffer = ""
|
||||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||||
|
punctuation_buffer = ""
|
||||||
|
|
||||||
async for content in client.get_stream_content(
|
async for content in client.get_stream_content(
|
||||||
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
||||||
@@ -125,8 +126,19 @@ class S4UStreamGenerator:
|
|||||||
if sentence:
|
if sentence:
|
||||||
# 如果句子看起来完整(即不只是等待更多内容),则发送
|
# 如果句子看起来完整(即不只是等待更多内容),则发送
|
||||||
if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
|
if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
|
||||||
yield sentence
|
# 检查是否只是一个标点符号
|
||||||
await asyncio.sleep(0) # 允许其他任务运行
|
if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||||
|
punctuation_buffer += sentence
|
||||||
|
else:
|
||||||
|
# 发送之前累积的标点和当前句子
|
||||||
|
to_yield = punctuation_buffer + sentence
|
||||||
|
if to_yield.endswith((',', ',')):
|
||||||
|
to_yield = to_yield.rstrip(',,')
|
||||||
|
|
||||||
|
yield to_yield
|
||||||
|
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||||
|
await asyncio.sleep(0) # 允许其他任务运行
|
||||||
|
|
||||||
last_match_end = match.end(0)
|
last_match_end = match.end(0)
|
||||||
|
|
||||||
# 从缓冲区移除已发送的部分
|
# 从缓冲区移除已发送的部分
|
||||||
@@ -134,7 +146,10 @@ class S4UStreamGenerator:
|
|||||||
buffer = buffer[last_match_end:]
|
buffer = buffer[last_match_end:]
|
||||||
|
|
||||||
# 发送缓冲区中剩余的任何内容
|
# 发送缓冲区中剩余的任何内容
|
||||||
if buffer.strip():
|
to_yield = (punctuation_buffer + buffer).strip()
|
||||||
yield buffer.strip()
|
if to_yield:
|
||||||
await asyncio.sleep(0)
|
if to_yield.endswith((',', ',')):
|
||||||
|
to_yield = to_yield.rstrip(',,')
|
||||||
|
if to_yield:
|
||||||
|
yield to_yield
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user