🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -18,12 +18,12 @@ from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
|||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
|
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
|
|
||||||
ENABLE_S4U_CHAT = os.path.isfile(os.path.join(PROJECT_ROOT, 's4u.s4u'))
|
ENABLE_S4U_CHAT = os.path.isfile(os.path.join(PROJECT_ROOT, "s4u.s4u"))
|
||||||
|
|
||||||
if ENABLE_S4U_CHAT:
|
if ENABLE_S4U_CHAT:
|
||||||
print('''\nS4U私聊模式已开启\n!!!!!!!!!!!!!!!!!\n''')
|
print("""\nS4U私聊模式已开启\n!!!!!!!!!!!!!!!!!\n""")
|
||||||
# 仅内部开启
|
# 仅内部开启
|
||||||
|
|
||||||
# 配置主程序日志格式
|
# 配置主程序日志格式
|
||||||
|
|||||||
@@ -154,9 +154,9 @@ class S4UChat:
|
|||||||
# 两个消息队列
|
# 两个消息队列
|
||||||
self._vip_queue = asyncio.PriorityQueue()
|
self._vip_queue = asyncio.PriorityQueue()
|
||||||
self._normal_queue = asyncio.PriorityQueue()
|
self._normal_queue = asyncio.PriorityQueue()
|
||||||
|
|
||||||
self._entry_counter = 0 # 保证FIFO的全局计数器
|
self._entry_counter = 0 # 保证FIFO的全局计数器
|
||||||
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
||||||
|
|
||||||
self._processing_task = asyncio.create_task(self._message_processor())
|
self._processing_task = asyncio.create_task(self._message_processor())
|
||||||
self._current_generation_task: Optional[asyncio.Task] = None
|
self._current_generation_task: Optional[asyncio.Task] = None
|
||||||
@@ -186,16 +186,16 @@ class S4UChat:
|
|||||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||||
is_vip = self._is_vip(message)
|
is_vip = self._is_vip(message)
|
||||||
new_priority = self._get_message_priority(message)
|
new_priority = self._get_message_priority(message)
|
||||||
|
|
||||||
should_interrupt = False
|
should_interrupt = False
|
||||||
if self._current_generation_task and not self._current_generation_task.done():
|
if self._current_generation_task and not self._current_generation_task.done():
|
||||||
if self._current_message_being_replied:
|
if self._current_message_being_replied:
|
||||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||||
|
|
||||||
# 规则:VIP从不被打断
|
# 规则:VIP从不被打断
|
||||||
if current_queue == "vip":
|
if current_queue == "vip":
|
||||||
pass # Do nothing
|
pass # Do nothing
|
||||||
|
|
||||||
# 规则:普通消息可以被打断
|
# 规则:普通消息可以被打断
|
||||||
elif current_queue == "normal":
|
elif current_queue == "normal":
|
||||||
# VIP消息可以打断普通消息
|
# VIP消息可以打断普通消息
|
||||||
@@ -214,10 +214,12 @@ class S4UChat:
|
|||||||
elif new_sender_id == current_sender_id and new_priority <= current_priority:
|
elif new_sender_id == current_sender_id and new_priority <= current_priority:
|
||||||
should_interrupt = True
|
should_interrupt = True
|
||||||
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
||||||
|
|
||||||
if should_interrupt:
|
if should_interrupt:
|
||||||
if self.gpt.partial_response:
|
if self.gpt.partial_response:
|
||||||
logger.warning(f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'")
|
logger.warning(
|
||||||
|
f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
|
||||||
|
)
|
||||||
self._current_generation_task.cancel()
|
self._current_generation_task.cancel()
|
||||||
|
|
||||||
# 将消息放入对应的队列
|
# 将消息放入对应的队列
|
||||||
@@ -227,9 +229,9 @@ class S4UChat:
|
|||||||
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||||
else:
|
else:
|
||||||
await self._normal_queue.put(item)
|
await self._normal_queue.put(item)
|
||||||
|
|
||||||
self._entry_counter += 1
|
self._entry_counter += 1
|
||||||
self._new_message_event.set() # 唤醒处理器
|
self._new_message_event.set() # 唤醒处理器
|
||||||
|
|
||||||
async def _message_processor(self):
|
async def _message_processor(self):
|
||||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||||
@@ -248,12 +250,14 @@ class S4UChat:
|
|||||||
priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||||
# 检查普通消息是否超时
|
# 检查普通消息是否超时
|
||||||
if time.time() - timestamp > self._MESSAGE_TIMEOUT_SECONDS:
|
if time.time() - timestamp > self._MESSAGE_TIMEOUT_SECONDS:
|
||||||
logger.info(f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}...")
|
logger.info(
|
||||||
|
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
|
||||||
|
)
|
||||||
self._normal_queue.task_done()
|
self._normal_queue.task_done()
|
||||||
continue # 处理下一条
|
continue # 处理下一条
|
||||||
queue_name = "normal"
|
queue_name = "normal"
|
||||||
else:
|
else:
|
||||||
continue # 没有消息了,回去等事件
|
continue # 没有消息了,回去等事件
|
||||||
|
|
||||||
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
||||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||||
@@ -261,7 +265,9 @@ class S4UChat:
|
|||||||
try:
|
try:
|
||||||
await self._current_generation_task
|
await self._current_generation_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded.")
|
logger.info(
|
||||||
|
f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
|
||||||
|
)
|
||||||
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
||||||
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
||||||
|
|
||||||
@@ -271,11 +277,11 @@ class S4UChat:
|
|||||||
self._current_generation_task = None
|
self._current_generation_task = None
|
||||||
self._current_message_being_replied = None
|
self._current_message_being_replied = None
|
||||||
# 标记任务完成
|
# 标记任务完成
|
||||||
if queue_name == 'vip':
|
if queue_name == "vip":
|
||||||
self._vip_queue.task_done()
|
self._vip_queue.task_done()
|
||||||
else:
|
else:
|
||||||
self._normal_queue.task_done()
|
self._normal_queue.task_done()
|
||||||
|
|
||||||
# 检查是否还有任务,有则立即再次触发事件
|
# 检查是否还有任务,有则立即再次触发事件
|
||||||
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
||||||
self._new_message_event.set()
|
self._new_message_event.set()
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.individuality.individuality import get_individuality
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||||
import time
|
import time
|
||||||
@@ -105,7 +104,9 @@ class PromptBuilder:
|
|||||||
)
|
)
|
||||||
relation_info = "".join(relation_info_list)
|
relation_info = "".join(relation_info_list)
|
||||||
if relation_info:
|
if relation_info:
|
||||||
relation_prompt = await global_prompt_manager.format_prompt("relation_prompt", relation_info=relation_info)
|
relation_prompt = await global_prompt_manager.format_prompt(
|
||||||
|
"relation_prompt", relation_info=relation_info
|
||||||
|
)
|
||||||
return relation_prompt
|
return relation_prompt
|
||||||
|
|
||||||
async def build_memory_block(self, text: str) -> str:
|
async def build_memory_block(self, text: str) -> str:
|
||||||
@@ -128,7 +129,7 @@ class PromptBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
||||||
|
|
||||||
core_dialogue_list = []
|
core_dialogue_list = []
|
||||||
background_dialogue_list = []
|
background_dialogue_list = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
@@ -148,7 +149,7 @@ class PromptBuilder:
|
|||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||||
|
|
||||||
background_dialogue_prompt = ""
|
background_dialogue_prompt = ""
|
||||||
if background_dialogue_list:
|
if background_dialogue_list:
|
||||||
latest_25_msgs = background_dialogue_list[-25:]
|
latest_25_msgs = background_dialogue_list[-25:]
|
||||||
@@ -196,9 +197,8 @@ class PromptBuilder:
|
|||||||
all_msg_seg_list.append(msg_seg_str)
|
all_msg_seg_list.append(msg_seg_str)
|
||||||
for msg in all_msg_seg_list:
|
for msg in all_msg_seg_list:
|
||||||
core_msg_str += msg
|
core_msg_str += msg
|
||||||
|
|
||||||
return core_msg_str, background_dialogue_prompt
|
|
||||||
|
|
||||||
|
return core_msg_str, background_dialogue_prompt
|
||||||
|
|
||||||
async def build_prompt_normal(
|
async def build_prompt_normal(
|
||||||
self,
|
self,
|
||||||
@@ -207,19 +207,16 @@ class PromptBuilder:
|
|||||||
message_txt: str,
|
message_txt: str,
|
||||||
sender_name: str = "某人",
|
sender_name: str = "某人",
|
||||||
) -> str:
|
) -> str:
|
||||||
|
|
||||||
identity_block, relation_info_block, memory_block = await asyncio.gather(
|
identity_block, relation_info_block, memory_block = await asyncio.gather(
|
||||||
self.build_identity_block(),
|
self.build_identity_block(), self.build_relation_info(chat_stream), self.build_memory_block(message_txt)
|
||||||
self.build_relation_info(chat_stream),
|
|
||||||
self.build_memory_block(message_txt)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
|
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
|
||||||
|
|
||||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
template_name = "s4u_prompt"
|
template_name = "s4u_prompt"
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
identity_block=identity_block,
|
identity_block=identity_block,
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ class S4UStreamGenerator:
|
|||||||
to_yield = punctuation_buffer + sentence
|
to_yield = punctuation_buffer + sentence
|
||||||
if to_yield.endswith((",", ",")):
|
if to_yield.endswith((",", ",")):
|
||||||
to_yield = to_yield.rstrip(",,")
|
to_yield = to_yield.rstrip(",,")
|
||||||
|
|
||||||
self.partial_response += to_yield
|
self.partial_response += to_yield
|
||||||
yield to_yield
|
yield to_yield
|
||||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||||
|
|||||||
Reference in New Issue
Block a user