🤖 自动格式化代码 [skip ci]

This commit is contained in:
github-actions[bot]
2025-07-01 11:33:16 +00:00
parent 3ef3923a8b
commit 324b294b5f
12 changed files with 157 additions and 225 deletions

1
bot.py
View File

@@ -326,7 +326,6 @@ if __name__ == "__main__":
# Wait for all tasks to complete (which they won't, normally) # Wait for all tasks to complete (which they won't, normally)
loop.run_until_complete(main_tasks) loop.run_until_complete(main_tasks)
except KeyboardInterrupt: except KeyboardInterrupt:
# loop.run_until_complete(get_global_api().stop()) # loop.run_until_complete(get_global_api().stop())
logger.warning("收到中断信号,正在优雅关闭...") logger.warning("收到中断信号,正在优雅关闭...")

View File

@@ -3,10 +3,12 @@ from src.common.logger import get_logger
logger = get_logger("MockAudio") logger = get_logger("MockAudio")
class MockAudioPlayer: class MockAudioPlayer:
""" """
一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。 一个模拟的音频播放器,它会根据音频数据的"长度"来模拟播放时间。
""" """
def __init__(self, audio_data: bytes): def __init__(self, audio_data: bytes):
self._audio_data = audio_data self._audio_data = audio_data
# 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频 # 模拟音频时长:假设每 1024 字节代表 0.5 秒的音频
@@ -22,12 +24,14 @@ class MockAudioPlayer:
logger.info("模拟音频播放完毕。") logger.info("模拟音频播放完毕。")
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("音频播放被中断。") logger.info("音频播放被中断。")
raise # 重新抛出异常,以便上层逻辑可以捕获它 raise # 重新抛出异常,以便上层逻辑可以捕获它
class MockAudioGenerator: class MockAudioGenerator:
""" """
一个模拟的文本到语音TTS生成器。 一个模拟的文本到语音TTS生成器。
""" """
def __init__(self): def __init__(self):
# 模拟生成速度:每秒生成的字符数 # 模拟生成速度:每秒生成的字符数
self.chars_per_second = 25.0 self.chars_per_second = 25.0
@@ -43,16 +47,16 @@ class MockAudioGenerator:
模拟的音频数据bytes 模拟的音频数据bytes
""" """
if not text: if not text:
return b'' return b""
generation_time = len(text) / self.chars_per_second generation_time = len(text) / self.chars_per_second
logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...") logger.info(f"模拟生成音频... 文本长度: {len(text)}, 预计耗时: {generation_time:.2f} 秒...")
try: try:
await asyncio.sleep(generation_time) await asyncio.sleep(generation_time)
# 生成虚拟的音频数据,其长度与文本长度成正比 # 生成虚拟的音频数据,其长度与文本长度成正比
mock_audio_data = b'\x01\x02\x03' * (len(text) * 40) mock_audio_data = b"\x01\x02\x03" * (len(text) * 40)
logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。") logger.info(f"模拟音频生成完毕,数据大小: {len(mock_audio_data) / 1024:.2f} KB。")
return mock_audio_data return mock_audio_data
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("音频生成被中断。") logger.info("音频生成被中断。")
raise # 重新抛出异常 raise # 重新抛出异常

View File

@@ -180,14 +180,12 @@ class ChatBot:
# 如果在私聊中 # 如果在私聊中
if group_info is None: if group_info is None:
logger.debug("检测到私聊消息") logger.debug("检测到私聊消息")
if ENABLE_S4U_CHAT: if ENABLE_S4U_CHAT:
logger.debug("进入S4U私聊处理流程") logger.debug("进入S4U私聊处理流程")
await self.s4u_message_processor.process_message(message) await self.s4u_message_processor.process_message(message)
return return
if global_config.experimental.pfc_chatting: if global_config.experimental.pfc_chatting:
logger.debug("进入PFC私聊处理流程") logger.debug("进入PFC私聊处理流程")
# 创建聊天流 # 创建聊天流
@@ -200,13 +198,11 @@ class ChatBot:
await self.heartflow_message_receiver.process_message(message) await self.heartflow_message_receiver.process_message(message)
# 群聊默认进入心流消息处理逻辑 # 群聊默认进入心流消息处理逻辑
else: else:
if ENABLE_S4U_CHAT: if ENABLE_S4U_CHAT:
logger.debug("进入S4U私聊处理流程") logger.debug("进入S4U私聊处理流程")
await self.s4u_message_processor.process_message(message) await self.s4u_message_processor.process_message(message)
return return
logger.debug(f"检测到群聊消息群ID: {group_info.group_id}") logger.debug(f"检测到群聊消息群ID: {group_info.group_id}")
await self.heartflow_message_receiver.process_message(message) await self.heartflow_message_receiver.process_message(message)

View File

@@ -323,7 +323,7 @@ class MessageSending(MessageProcessBase):
self.is_head = is_head self.is_head = is_head
self.is_emoji = is_emoji self.is_emoji = is_emoji
self.apply_set_reply_logic = apply_set_reply_logic self.apply_set_reply_logic = apply_set_reply_logic
self.reply_to = reply_to self.reply_to = reply_to
# 用于显示发送内容与显示不一致的情况 # 用于显示发送内容与显示不一致的情况

View File

@@ -35,11 +35,11 @@ class MessageStorage:
filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL) filtered_display_message = re.sub(pattern, "", display_message, flags=re.DOTALL)
else: else:
filtered_display_message = "" filtered_display_message = ""
reply_to = message.reply_to reply_to = message.reply_to
else: else:
filtered_display_message = "" filtered_display_message = ""
reply_to = "" reply_to = ""
chat_info_dict = chat_stream.to_dict() chat_info_dict = chat_stream.to_dict()

View File

@@ -263,7 +263,6 @@ def _build_readable_messages_internal(
# 处理图片ID # 处理图片ID
if show_pic: if show_pic:
content = process_pic_ids(content) content = process_pic_ids(content)
# 检查必要信息是否存在 # 检查必要信息是否存在
if not all([platform, user_id, timestamp is not None]): if not all([platform, user_id, timestamp is not None]):
@@ -632,10 +631,17 @@ def build_readable_messages(
truncate, truncate,
pic_id_mapping, pic_id_mapping,
pic_counter, pic_counter,
show_pic=show_pic show_pic=show_pic,
) )
formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal( formatted_after, _, pic_id_mapping, _ = _build_readable_messages_internal(
messages_after_mark, replace_bot_name, merge_messages, timestamp_mode, False, pic_id_mapping, pic_counter, show_pic=show_pic messages_after_mark,
replace_bot_name,
merge_messages,
timestamp_mode,
False,
pic_id_mapping,
pic_counter,
show_pic=show_pic,
) )
read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n" read_mark_line = "\n--- 以上消息是你已经看过,请关注以下未读的新消息---\n"

View File

@@ -126,7 +126,7 @@ class Messages(BaseModel):
time = DoubleField() # 消息时间戳 time = DoubleField() # 消息时间戳
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
reply_to = TextField(null=True) reply_to = TextField(null=True)
# 从 chat_info 扁平化而来的字段 # 从 chat_info 扁平化而来的字段

View File

@@ -1,39 +1,15 @@
import asyncio import asyncio
import time import time
import traceback
import random import random
from typing import List, Optional, Dict # 导入类型提示 from typing import Optional, Dict # 导入类型提示
import os
import pickle
from maim_message import UserInfo, Seg from maim_message import UserInfo, Seg
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.manager.mood_manager import mood_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.prompt_builder import global_prompt_manager
from .s4u_stream_generator import S4UStreamGenerator from .s4u_stream_generator import S4UStreamGenerator
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageThinking, MessageSet from src.chat.message_receive.message import MessageSending, MessageRecv
from src.chat.message_receive.message_sender import message_manager
from src.chat.normal_chat.willing.willing_manager import get_willing_manager
from src.chat.normal_chat.normal_chat_utils import get_recent_message_stats
from src.config.config import global_config from src.config.config import global_config
from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.normal_chat.normal_chat_planner import NormalChatPlanner
from src.chat.normal_chat.normal_chat_action_modifier import NormalChatActionModifier
from src.chat.normal_chat.normal_chat_expressor import NormalChatExpressor
from src.chat.focus_chat.replyer.default_generator import DefaultReplyer
from src.person_info.person_info import PersonInfoManager
from src.person_info.relationship_manager import get_relationship_manager
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
get_raw_msg_by_timestamp_with_chat_inclusive,
get_raw_msg_before_timestamp_with_chat,
num_new_messages_since,
)
from src.common.message.api import get_global_api from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.audio.mock_audio import MockAudioGenerator, MockAudioPlayer
logger = get_logger("S4U_chat") logger = get_logger("S4U_chat")
@@ -41,6 +17,7 @@ logger = get_logger("S4U_chat")
class MessageSenderContainer: class MessageSenderContainer:
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。""" """一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv): def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.original_message = original_message self.original_message = original_message
@@ -71,7 +48,7 @@ class MessageSenderContainer:
chars_per_second = 15.0 chars_per_second = 15.0
min_delay = 0.2 min_delay = 0.2
max_delay = 2.0 max_delay = 2.0
delay = len(text) / chars_per_second delay = len(text) / chars_per_second
return max(min_delay, min(delay, max_delay)) return max(min_delay, min(delay, max_delay))
@@ -98,7 +75,7 @@ class MessageSenderContainer:
current_time = time.time() current_time = time.time()
msg_id = f"{current_time}_{random.randint(1000, 9999)}" msg_id = f"{current_time}_{random.randint(1000, 9999)}"
text_to_send = chunk text_to_send = chunk
if global_config.experimental.debug_show_chat_mode: if global_config.experimental.debug_show_chat_mode:
text_to_send += "" text_to_send += ""
@@ -117,19 +94,19 @@ class MessageSenderContainer:
reply=self.original_message, reply=self.original_message,
is_emoji=False, is_emoji=False,
apply_set_reply_logic=True, apply_set_reply_logic=True,
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}" reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
) )
await bot_message.process() await bot_message.process()
await get_global_api().send_message(bot_message) await get_global_api().send_message(bot_message)
logger.info(f"已将消息 '{text_to_send}' 发往平台 '{bot_message.message_info.platform}'") logger.info(f"已将消息 '{text_to_send}' 发往平台 '{bot_message.message_info.platform}'")
await self.storage.store_message(bot_message, self.chat_stream) await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True) logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
finally: finally:
# CRUCIAL: Always call task_done() for any item that was successfully retrieved. # CRUCIAL: Always call task_done() for any item that was successfully retrieved.
self.queue.task_done() self.queue.task_done()
@@ -138,7 +115,7 @@ class MessageSenderContainer:
"""启动发送任务。""" """启动发送任务。"""
if self._task is None: if self._task is None:
self._task = asyncio.create_task(self._send_worker()) self._task = asyncio.create_task(self._send_worker())
async def join(self): async def join(self):
"""等待所有消息发送完毕。""" """等待所有消息发送完毕。"""
if self._task: if self._task:
@@ -156,8 +133,10 @@ class S4UChatManager:
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream) self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
return self.s4u_chats[chat_stream.stream_id] return self.s4u_chats[chat_stream.stream_id]
s4u_chat_manager = S4UChatManager() s4u_chat_manager = S4UChatManager()
def get_s4u_chat_manager() -> S4UChatManager: def get_s4u_chat_manager() -> S4UChatManager:
return s4u_chat_manager return s4u_chat_manager
@@ -169,22 +148,19 @@ class S4UChat:
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.stream_id = chat_stream.stream_id self.stream_id = chat_stream.stream_id
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
self._message_queue = asyncio.Queue() self._message_queue = asyncio.Queue()
self._processing_task = asyncio.create_task(self._message_processor()) self._processing_task = asyncio.create_task(self._message_processor())
self._current_generation_task: Optional[asyncio.Task] = None self._current_generation_task: Optional[asyncio.Task] = None
self._current_message_being_replied: Optional[MessageRecv] = None self._current_message_being_replied: Optional[MessageRecv] = None
self._is_replying = False self._is_replying = False
self.gpt = S4UStreamGenerator() self.gpt = S4UStreamGenerator()
# self.audio_generator = MockAudioGenerator() # self.audio_generator = MockAudioGenerator()
logger.info(f"[{self.stream_name}] S4UChat") logger.info(f"[{self.stream_name}] S4UChat")
# 改为实例方法, 移除 chat 参数 # 改为实例方法, 移除 chat 参数
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None: async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
"""将消息放入队列并根据发信人决定是否中断当前处理。""" """将消息放入队列并根据发信人决定是否中断当前处理。"""
@@ -226,8 +202,8 @@ class S4UChat:
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条 # 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
while not self._message_queue.empty(): while not self._message_queue.empty():
drained_msg = self._message_queue.get_nowait() drained_msg = self._message_queue.get_nowait()
self._message_queue.task_done() # 为取出的旧消息调用 task_done self._message_queue.task_done() # 为取出的旧消息调用 task_done
message = drained_msg # 始终处理最新消息 message = drained_msg # 始终处理最新消息
self._current_message_being_replied = message self._current_message_being_replied = message
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}") logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
@@ -242,44 +218,40 @@ class S4UChat:
finally: finally:
self._current_generation_task = None self._current_generation_task = None
self._current_message_being_replied = None self._current_message_being_replied = None
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。") logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
break break
except Exception as e: except Exception as e:
logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True) logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True)
await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转 await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转
finally: finally:
# 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成 # 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成
if 'message' in locals(): if "message" in locals():
self._message_queue.task_done() self._message_queue.task_done()
async def _generate_and_send(self, message: MessageRecv): async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本和音频回复。整个过程可以被中断。""" """为单个消息生成文本和音频回复。整个过程可以被中断。"""
self._is_replying = True self._is_replying = True
sender_container = MessageSenderContainer(self.chat_stream, message) sender_container = MessageSenderContainer(self.chat_stream, message)
sender_container.start() sender_container.start()
try: try:
logger.info( logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
f"[S4U] 开始为消息生成文本和音频流: "
f"'{message.processed_plain_text[:30]}...'"
)
# 1. 逐句生成文本、发送并播放音频 # 1. 逐句生成文本、发送并播放音频
gen = self.gpt.generate_response(message, "") gen = self.gpt.generate_response(message, "")
async for chunk in gen: async for chunk in gen:
# 如果任务被取消await 会在此处引发 CancelledError # 如果任务被取消await 会在此处引发 CancelledError
# a. 发送文本块 # a. 发送文本块
await sender_container.add_message(chunk) await sender_container.add_message(chunk)
# b. 为该文本块生成并播放音频 # b. 为该文本块生成并播放音频
# if chunk.strip(): # if chunk.strip():
# audio_data = await self.audio_generator.generate(chunk) # audio_data = await self.audio_generator.generate(chunk)
# player = MockAudioPlayer(audio_data) # player = MockAudioPlayer(audio_data)
# await player.play() # await player.play()
# 等待所有文本消息发送完成 # 等待所有文本消息发送完成
await sender_container.close() await sender_container.close()
@@ -300,20 +272,19 @@ class S4UChat:
await sender_container.join() await sender_container.join()
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。") logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
async def shutdown(self): async def shutdown(self):
"""平滑关闭处理任务。""" """平滑关闭处理任务。"""
logger.info(f"正在关闭 S4UChat: {self.stream_name}") logger.info(f"正在关闭 S4UChat: {self.stream_name}")
# 取消正在运行的任务 # 取消正在运行的任务
if self._current_generation_task and not self._current_generation_task.done(): if self._current_generation_task and not self._current_generation_task.done():
self._current_generation_task.cancel() self._current_generation_task.cancel()
if self._processing_task and not self._processing_task.done(): if self._processing_task and not self._processing_task.done():
self._processing_task.cancel() self._processing_task.cancel()
# 等待任务响应取消 # 等待任务响应取消
try: try:
await self._processing_task await self._processing_task
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"处理任务已成功取消: {self.stream_name}") logger.info(f"处理任务已成功取消: {self.stream_name}")

View File

@@ -1,21 +1,10 @@
from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.config.config import global_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.common.logger import get_logger from src.common.logger import get_logger
from .s4u_chat import get_s4u_chat_manager from .s4u_chat import get_s4u_chat_manager
import math
import re
import traceback
from typing import Optional, Tuple
from maim_message import UserInfo
from src.person_info.relationship_manager import get_relationship_manager
# from ..message_receive.message_buffer import message_buffer # from ..message_receive.message_buffer import message_buffer
@@ -44,7 +33,7 @@ class S4UMessageProcessor:
""" """
target_user_id_list = ["1026294844", "964959351"] target_user_id_list = ["1026294844", "964959351"]
# 1. 消息解析与初始化 # 1. 消息解析与初始化
groupinfo = message.message_info.group_info groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info userinfo = message.message_info.user_info
@@ -60,7 +49,7 @@ class S4UMessageProcessor:
is_mentioned = is_mentioned_bot_in_message(message) is_mentioned = is_mentioned_bot_in_message(message)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
if userinfo.user_id in target_user_id_list: if userinfo.user_id in target_user_id_list:
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0) await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0)
else: else:
@@ -68,4 +57,3 @@ class S4UMessageProcessor:
# 7. 日志记录 # 7. 日志记录
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")

View File

@@ -1,10 +1,8 @@
from src.config.config import global_config from src.config.config import global_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.individuality.individuality import get_individuality from src.individuality.individuality import get_individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.message_receive.message import MessageRecv
import time import time
from src.chat.utils.utils import get_recent_group_speaker from src.chat.utils.utils import get_recent_group_speaker
from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.memory_system.Hippocampus import hippocampus_manager
@@ -23,7 +21,6 @@ def init_prompt():
Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt") Prompt("\n你有以下这些**知识**\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
Prompt( Prompt(
""" """
你的名字叫{bot_name},昵称是:{bot_other_names}{prompt_personality} 你的名字叫{bot_name},昵称是:{bot_other_names}{prompt_personality}
@@ -79,7 +76,6 @@ class PromptBuilder:
relationship_manager = get_relationship_manager() relationship_manager = get_relationship_manager()
relation_prompt += await relationship_manager.build_relationship_info(person) relation_prompt += await relationship_manager.build_relationship_info(person)
memory_prompt = "" memory_prompt = ""
related_memory = await hippocampus_manager.get_memory_from_text( related_memory = await hippocampus_manager.get_memory_from_text(
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
@@ -98,23 +94,20 @@ class PromptBuilder:
timestamp=time.time(), timestamp=time.time(),
limit=100, limit=100,
) )
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
print(f"talk_type: {talk_type}") print(f"talk_type: {talk_type}")
# 分别筛选核心对话和背景对话 # 分别筛选核心对话和背景对话
core_dialogue_list = [] core_dialogue_list = []
background_dialogue_list = [] background_dialogue_list = []
bot_id = str(global_config.bot.qq_account) bot_id = str(global_config.bot.qq_account)
target_user_id = str(message.chat_stream.user_info.user_id) target_user_id = str(message.chat_stream.user_info.user_id)
for msg_dict in message_list_before_now: for msg_dict in message_list_before_now:
try: try:
# 直接通过字典访问 # 直接通过字典访问
msg_user_id = str(msg_dict.get('user_id')) msg_user_id = str(msg_dict.get("user_id"))
if msg_user_id == bot_id: if msg_user_id == bot_id:
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"): if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
print(f"reply: {msg_dict.get('reply_to')}") print(f"reply: {msg_dict.get('reply_to')}")
@@ -127,24 +120,24 @@ class PromptBuilder:
background_dialogue_list.append(msg_dict) background_dialogue_list.append(msg_dict)
except Exception as e: except Exception as e:
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}") logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
if background_dialogue_list: if background_dialogue_list:
latest_25_msgs = background_dialogue_list[-25:] latest_25_msgs = background_dialogue_list[-25:]
background_dialogue_prompt = build_readable_messages( background_dialogue_prompt = build_readable_messages(
latest_25_msgs, latest_25_msgs,
merge_messages=True, merge_messages=True,
timestamp_mode = "normal_no_YMD", timestamp_mode="normal_no_YMD",
show_pic = False, show_pic=False,
) )
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}" background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}"
else: else:
background_dialogue_prompt = "" background_dialogue_prompt = ""
# 分别获取最新50条和最新25条从message_list_before_now截取 # 分别获取最新50条和最新25条从message_list_before_now截取
core_dialogue_list = core_dialogue_list[-50:] core_dialogue_list = core_dialogue_list[-50:]
first_msg = core_dialogue_list[0] first_msg = core_dialogue_list[0]
start_speaking_user_id = first_msg.get('user_id') start_speaking_user_id = first_msg.get("user_id")
if start_speaking_user_id == bot_id: if start_speaking_user_id == bot_id:
last_speaking_user_id = bot_id last_speaking_user_id = bot_id
msg_seg_str = "你的发言:\n" msg_seg_str = "你的发言:\n"
@@ -152,30 +145,33 @@ class PromptBuilder:
start_speaking_user_id = target_user_id start_speaking_user_id = target_user_id
last_speaking_user_id = start_speaking_user_id last_speaking_user_id = start_speaking_user_id
msg_seg_str = "对方的发言:\n" msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n" msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
all_msg_seg_list = [] all_msg_seg_list = []
for msg in core_dialogue_list[1:]: for msg in core_dialogue_list[1:]:
speaker = msg.get('user_id') speaker = msg.get("user_id")
if speaker == last_speaking_user_id: if speaker == last_speaking_user_id:
#还是同一个人讲话 # 还是同一个人讲话
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n" msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
)
else: else:
#换人了 # 换人了
msg_seg_str = f"{msg_seg_str}\n" msg_seg_str = f"{msg_seg_str}\n"
all_msg_seg_list.append(msg_seg_str) all_msg_seg_list.append(msg_seg_str)
if speaker == bot_id: if speaker == bot_id:
msg_seg_str = "你的发言:\n" msg_seg_str = "你的发言:\n"
else: else:
msg_seg_str = "对方的发言:\n" msg_seg_str = "对方的发言:\n"
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
last_speaking_user_id = speaker
all_msg_seg_list.append(msg_seg_str)
msg_seg_str += (
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
)
last_speaking_user_id = speaker
all_msg_seg_list.append(msg_seg_str)
core_msg_str = "" core_msg_str = ""
for msg in all_msg_seg_list: for msg in all_msg_seg_list:

View File

@@ -43,8 +43,8 @@ class S4UStreamGenerator:
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点 # 匹配常见的句子结束符,但会忽略引号内和数字中的标点
self.sentence_split_pattern = re.compile( self.sentence_split_pattern = re.compile(
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容 r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))' # 匹配直到句子结束符 r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))', # 匹配直到句子结束符
, re.UNICODE | re.DOTALL re.UNICODE | re.DOTALL,
) )
async def generate_response( async def generate_response(
@@ -68,7 +68,7 @@ class S4UStreamGenerator:
# 构建prompt # 构建prompt
if previous_reply_context: if previous_reply_context:
message_txt = f""" message_txt = f"""
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文: 你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
[你已经对上一条消息说的话]: {previous_reply_context} [你已经对上一条消息说的话]: {previous_reply_context}
--- ---
@@ -78,9 +78,8 @@ class S4UStreamGenerator:
else: else:
message_txt = message.processed_plain_text message_txt = message.processed_plain_text
prompt = await prompt_builder.build_prompt_normal( prompt = await prompt_builder.build_prompt_normal(
message = message, message=message,
message_txt=message_txt, message_txt=message_txt,
sender_name=sender_name, sender_name=sender_name,
chat_stream=message.chat_stream, chat_stream=message.chat_stream,
@@ -109,16 +108,16 @@ class S4UStreamGenerator:
**kwargs, **kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
print(prompt) print(prompt)
buffer = "" buffer = ""
delimiters = ",。!?,.!?\n\r" # For final trimming delimiters = ",。!?,.!?\n\r" # For final trimming
punctuation_buffer = "" punctuation_buffer = ""
async for content in client.get_stream_content( async for content in client.get_stream_content(
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
): ):
buffer += content buffer += content
# 使用正则表达式匹配句子 # 使用正则表达式匹配句子
last_match_end = 0 last_match_end = 0
for match in self.sentence_split_pattern.finditer(buffer): for match in self.sentence_split_pattern.finditer(buffer):
@@ -132,24 +131,23 @@ class S4UStreamGenerator:
else: else:
# 发送之前累积的标点和当前句子 # 发送之前累积的标点和当前句子
to_yield = punctuation_buffer + sentence to_yield = punctuation_buffer + sentence
if to_yield.endswith((',', '')): if to_yield.endswith((",", "")):
to_yield = to_yield.rstrip(',') to_yield = to_yield.rstrip(",")
yield to_yield yield to_yield
punctuation_buffer = "" # 清空标点符号缓冲区 punctuation_buffer = "" # 清空标点符号缓冲区
await asyncio.sleep(0) # 允许其他任务运行 await asyncio.sleep(0) # 允许其他任务运行
last_match_end = match.end(0) last_match_end = match.end(0)
# 从缓冲区移除已发送的部分 # 从缓冲区移除已发送的部分
if last_match_end > 0: if last_match_end > 0:
buffer = buffer[last_match_end:] buffer = buffer[last_match_end:]
# 发送缓冲区中剩余的任何内容 # 发送缓冲区中剩余的任何内容
to_yield = (punctuation_buffer + buffer).strip() to_yield = (punctuation_buffer + buffer).strip()
if to_yield: if to_yield:
if to_yield.endswith(('', ',')): if to_yield.endswith(("", ",")):
to_yield = to_yield.rstrip(',') to_yield = to_yield.rstrip(",")
if to_yield: if to_yield:
yield to_yield yield to_yield

View File

@@ -1,8 +1,5 @@
import asyncio from typing import AsyncGenerator, Dict, List, Optional, Union
import json
from typing import AsyncGenerator, Dict, List, Optional, Union, Any
from dataclasses import dataclass from dataclasses import dataclass
import aiohttp
from openai import AsyncOpenAI from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
@@ -10,20 +7,21 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
@dataclass @dataclass
class ChatMessage: class ChatMessage:
"""聊天消息数据类""" """聊天消息数据类"""
role: str role: str
content: str content: str
def to_dict(self) -> Dict[str, str]: def to_dict(self) -> Dict[str, str]:
return {"role": self.role, "content": self.content} return {"role": self.role, "content": self.content}
class AsyncOpenAIClient: class AsyncOpenAIClient:
"""异步OpenAI客户端支持流式传输""" """异步OpenAI客户端支持流式传输"""
def __init__(self, api_key: str, base_url: Optional[str] = None): def __init__(self, api_key: str, base_url: Optional[str] = None):
""" """
初始化客户端 初始化客户端
Args: Args:
api_key: OpenAI API密钥 api_key: OpenAI API密钥
base_url: 可选的API基础URL用于自定义端点 base_url: 可选的API基础URL用于自定义端点
@@ -33,25 +31,25 @@ class AsyncOpenAIClient:
base_url=base_url, base_url=base_url,
timeout=10.0, # 设置60秒的全局超时 timeout=10.0, # 设置60秒的全局超时
) )
async def chat_completion( async def chat_completion(
self, self,
messages: List[Union[ChatMessage, Dict[str, str]]], messages: List[Union[ChatMessage, Dict[str, str]]],
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs **kwargs,
) -> ChatCompletion: ) -> ChatCompletion:
""" """
非流式聊天完成 非流式聊天完成
Args: Args:
messages: 消息列表 messages: 消息列表
model: 模型名称 model: 模型名称
temperature: 温度参数 temperature: 温度参数
max_tokens: 最大token数 max_tokens: 最大token数
**kwargs: 其他参数 **kwargs: 其他参数
Returns: Returns:
完整的聊天回复 完整的聊天回复
""" """
@@ -62,7 +60,7 @@ class AsyncOpenAIClient:
formatted_messages.append(msg.to_dict()) formatted_messages.append(msg.to_dict())
else: else:
formatted_messages.append(msg) formatted_messages.append(msg)
extra_body = {} extra_body = {}
if kwargs.get("enable_thinking") is not None: if kwargs.get("enable_thinking") is not None:
extra_body["enable_thinking"] = kwargs.pop("enable_thinking") extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
@@ -76,29 +74,29 @@ class AsyncOpenAIClient:
max_tokens=max_tokens, max_tokens=max_tokens,
stream=False, stream=False,
extra_body=extra_body if extra_body else None, extra_body=extra_body if extra_body else None,
**kwargs **kwargs,
) )
return response return response
async def chat_completion_stream( async def chat_completion_stream(
self, self,
messages: List[Union[ChatMessage, Dict[str, str]]], messages: List[Union[ChatMessage, Dict[str, str]]],
model: str = "gpt-3.5-turbo", model: str = "gpt-3.5-turbo",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: Optional[int] = None, max_tokens: Optional[int] = None,
**kwargs **kwargs,
) -> AsyncGenerator[ChatCompletionChunk, None]: ) -> AsyncGenerator[ChatCompletionChunk, None]:
""" """
流式聊天完成 流式聊天完成
Args: Args:
messages: 消息列表 messages: 消息列表
model: 模型名称 model: 模型名称
temperature: 温度参数 temperature: 温度参数
max_tokens: 最大token数 max_tokens: 最大token数
**kwargs: 其他参数 **kwargs: 其他参数
Yields: Yields:
ChatCompletionChunk: 流式响应块 ChatCompletionChunk: 流式响应块
""" """
@@ -109,7 +107,7 @@ class AsyncOpenAIClient:
formatted_messages.append(msg.to_dict()) formatted_messages.append(msg.to_dict())
else: else:
formatted_messages.append(msg) formatted_messages.append(msg)
extra_body = {} extra_body = {}
if kwargs.get("enable_thinking") is not None: if kwargs.get("enable_thinking") is not None:
extra_body["enable_thinking"] = kwargs.pop("enable_thinking") extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
@@ -123,84 +121,76 @@ class AsyncOpenAIClient:
max_tokens=max_tokens, max_tokens=max_tokens,
stream=True, stream=True,
extra_body=extra_body if extra_body else None, extra_body=extra_body if extra_body else None,
**kwargs **kwargs,
) )
async for chunk in stream: async for chunk in stream:
yield chunk yield chunk
async def get_stream_content(
    self,
    messages: List[Union[ChatMessage, Dict[str, str]]],
    model: str = "gpt-3.5-turbo",
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    **kwargs,
) -> AsyncGenerator[str, None]:
    """Stream a chat completion, yielding only the text fragments.

    Args:
        messages: Conversation messages (ChatMessage objects or role/content dicts).
        model: Model name.
        temperature: Sampling temperature.
        max_tokens: Optional cap on generated tokens.
        **kwargs: Extra options forwarded to ``chat_completion_stream``.

    Yields:
        str: Each non-empty text fragment as it arrives.
    """
    stream = self.chat_completion_stream(
        messages=messages,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        **kwargs,
    )
    async for chunk in stream:
        choices = chunk.choices
        if not choices:
            continue
        fragment = choices[0].delta.content
        if fragment:
            yield fragment
async def collect_stream_response(
    self,
    messages: List[Union[ChatMessage, Dict[str, str]]],
    model: str = "gpt-3.5-turbo",
    temperature: float = 0.7,
    max_tokens: Optional[int] = None,
    **kwargs,
) -> str:
    """Consume the whole stream and return the concatenated response text.

    Args:
        messages: Conversation messages (ChatMessage objects or role/content dicts).
        model: Model name.
        temperature: Sampling temperature.
        max_tokens: Optional cap on generated tokens.
        **kwargs: Extra options forwarded to ``get_stream_content``.

    Returns:
        str: The complete response text.
    """
    # Accumulate fragments in a list and join once at the end: repeated
    # ``str +=`` inside the loop is quadratic in the worst case for long
    # streamed responses, while ``"".join`` is linear.
    parts: List[str] = []
    async for content in self.get_stream_content(
        messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
    ):
        parts.append(content)
    return "".join(parts)
async def close(self):
    """Shut down the wrapped AsyncOpenAI client, releasing its HTTP resources."""
    underlying = self.client
    await underlying.close()
async def __aenter__(self):
    """Async context-manager entry: hand back the client instance itself."""
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
    """Async context-manager exit: always close the client, even on error."""
    await self.close()
@@ -208,93 +198,77 @@ class AsyncOpenAIClient:
class ConversationManager:
    """Keeps per-conversation message history on top of an AsyncOpenAIClient."""

    def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None):
        """Bind the manager to *client* and optionally seed a system prompt.

        Args:
            client: The AsyncOpenAIClient instance used to send requests.
            system_prompt: If given, inserted as the first (system-role) message.
        """
        self.client = client
        # Ordered full history; index 0 holds the system message when one was given.
        self.messages: List[ChatMessage] = []
        if system_prompt:
            self.messages.append(ChatMessage(role="system", content=system_prompt))
def add_user_message(self, content: str):
    """Append a user-role message to the conversation history."""
    message = ChatMessage(role="user", content=content)
    self.messages.append(message)
def add_assistant_message(self, content: str):
    """Append an assistant-role message to the conversation history."""
    message = ChatMessage(role="assistant", content=content)
    self.messages.append(message)
async def send_message_stream(
    self, content: str, model: str = "gpt-3.5-turbo", **kwargs
) -> AsyncGenerator[str, None]:
    """Send *content* and re-yield the assistant's reply as it streams in.

    The user message is recorded before the request is made; the fully
    accumulated assistant reply is recorded once the stream is exhausted.

    Args:
        content: User message text.
        model: Model name.
        **kwargs: Extra options forwarded to ``get_stream_content``.

    Yields:
        str: Response fragments in arrival order.
    """
    self.add_user_message(content)
    received = []
    stream = self.client.get_stream_content(messages=self.messages, model=model, **kwargs)
    async for fragment in stream:
        received.append(fragment)
        yield fragment
    self.add_assistant_message("".join(received))
async def send_message( async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str:
self,
content: str,
model: str = "gpt-3.5-turbo",
**kwargs
) -> str:
""" """
发送消息并获取完整响应 发送消息并获取完整响应
Args: Args:
content: 用户消息内容 content: 用户消息内容
model: 模型名称 model: 模型名称
**kwargs: 其他参数 **kwargs: 其他参数
Returns: Returns:
str: 完整响应 str: 完整响应
""" """
self.add_user_message(content) self.add_user_message(content)
response = await self.client.chat_completion( response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs)
messages=self.messages,
model=model,
**kwargs
)
response_content = response.choices[0].message.content response_content = response.choices[0].message.content
self.add_assistant_message(response_content) self.add_assistant_message(response_content)
return response_content return response_content
def clear_history(self, keep_system: bool = True): def clear_history(self, keep_system: bool = True):
""" """
清除对话历史 清除对话历史
Args: Args:
keep_system: 是否保留系统消息 keep_system: 是否保留系统消息
""" """
@@ -302,11 +276,11 @@ class ConversationManager:
self.messages = [self.messages[0]] self.messages = [self.messages[0]]
else: else:
self.messages = [] self.messages = []
def get_message_count(self) -> int:
    """Return how many messages (including any system prompt) are stored."""
    history = self.messages
    return len(history)
def get_conversation_history(self) -> List[Dict[str, str]]:
    """Return the history as a list of role/content dicts (API wire format)."""
    history = []
    for message in self.messages:
        history.append(message.to_dict())
    return history