feat:为s4u添加了优先队列和普通队列
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -9,6 +9,7 @@ tool_call_benchmark.py
|
|||||||
run_maibot_core.bat
|
run_maibot_core.bat
|
||||||
run_napcat_adapter.bat
|
run_napcat_adapter.bat
|
||||||
run_ad.bat
|
run_ad.bat
|
||||||
|
s4u.s4u
|
||||||
llm_tool_benchmark_results.json
|
llm_tool_benchmark_results.json
|
||||||
MaiBot-Napcat-Adapter-main
|
MaiBot-Napcat-Adapter-main
|
||||||
MaiBot-Napcat-Adapter
|
MaiBot-Napcat-Adapter
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import traceback
|
import traceback
|
||||||
|
import os
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -16,8 +17,14 @@ from src.plugin_system.base.base_command import BaseCommand
|
|||||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||||
# 定义日志配置
|
# 定义日志配置
|
||||||
|
|
||||||
ENABLE_S4U_CHAT = True
|
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||||
# 仅内部开启
|
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..'))
|
||||||
|
|
||||||
|
ENABLE_S4U_CHAT = os.path.isfile(os.path.join(PROJECT_ROOT, 's4u.s4u'))
|
||||||
|
|
||||||
|
if ENABLE_S4U_CHAT:
|
||||||
|
print('''\nS4U私聊模式已开启\n!!!!!!!!!!!!!!!!!\n''')
|
||||||
|
# 仅内部开启
|
||||||
|
|
||||||
# 配置主程序日志格式
|
# 配置主程序日志格式
|
||||||
logger = get_logger("chat")
|
logger = get_logger("chat")
|
||||||
@@ -180,19 +187,10 @@ class ChatBot:
|
|||||||
# 如果在私聊中
|
# 如果在私聊中
|
||||||
if group_info is None:
|
if group_info is None:
|
||||||
logger.debug("检测到私聊消息")
|
logger.debug("检测到私聊消息")
|
||||||
|
|
||||||
if ENABLE_S4U_CHAT:
|
if ENABLE_S4U_CHAT:
|
||||||
logger.debug("进入S4U私聊处理流程")
|
logger.debug("进入S4U私聊处理流程")
|
||||||
await self.s4u_message_processor.process_message(message)
|
await self.s4u_message_processor.process_message(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
if global_config.experimental.pfc_chatting:
|
|
||||||
logger.debug("进入PFC私聊处理流程")
|
|
||||||
# 创建聊天流
|
|
||||||
logger.debug(f"为{user_info.user_id}创建/获取聊天流")
|
|
||||||
await self.only_process_chat.process_message(message)
|
|
||||||
await self._create_pfc_chat(message)
|
|
||||||
# 禁止PFC,进入普通的心流消息处理逻辑
|
|
||||||
else:
|
else:
|
||||||
logger.debug("进入普通心流私聊处理")
|
logger.debug("进入普通心流私聊处理")
|
||||||
await self.heartflow_message_receiver.process_message(message)
|
await self.heartflow_message_receiver.process_message(message)
|
||||||
@@ -202,7 +200,7 @@ class ChatBot:
|
|||||||
logger.debug("进入S4U私聊处理流程")
|
logger.debug("进入S4U私聊处理流程")
|
||||||
await self.s4u_message_processor.process_message(message)
|
await self.s4u_message_processor.process_message(message)
|
||||||
return
|
return
|
||||||
|
else:
|
||||||
logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}")
|
logger.debug(f"检测到群聊消息,群ID: {group_info.group_id}")
|
||||||
await self.heartflow_message_receiver.process_message(message)
|
await self.heartflow_message_receiver.process_message(message)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
from typing import Optional, Dict # 导入类型提示
|
from typing import Optional, Dict, Tuple # 导入类型提示
|
||||||
from maim_message import UserInfo, Seg
|
from maim_message import UserInfo, Seg
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
@@ -142,6 +142,8 @@ def get_s4u_chat_manager() -> S4UChatManager:
|
|||||||
|
|
||||||
|
|
||||||
class S4UChat:
|
class S4UChat:
|
||||||
|
_MESSAGE_TIMEOUT_SECONDS = 60 # 普通消息存活时间(秒)
|
||||||
|
|
||||||
def __init__(self, chat_stream: ChatStream):
|
def __init__(self, chat_stream: ChatStream):
|
||||||
"""初始化 S4UChat 实例。"""
|
"""初始化 S4UChat 实例。"""
|
||||||
|
|
||||||
@@ -149,86 +151,141 @@ class S4UChat:
|
|||||||
self.stream_id = chat_stream.stream_id
|
self.stream_id = chat_stream.stream_id
|
||||||
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
self.stream_name = get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||||
|
|
||||||
self._message_queue = asyncio.Queue()
|
# 两个消息队列
|
||||||
|
self._vip_queue = asyncio.PriorityQueue()
|
||||||
|
self._normal_queue = asyncio.PriorityQueue()
|
||||||
|
|
||||||
|
self._entry_counter = 0 # 保证FIFO的全局计数器
|
||||||
|
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
||||||
|
|
||||||
self._processing_task = asyncio.create_task(self._message_processor())
|
self._processing_task = asyncio.create_task(self._message_processor())
|
||||||
self._current_generation_task: Optional[asyncio.Task] = None
|
self._current_generation_task: Optional[asyncio.Task] = None
|
||||||
self._current_message_being_replied: Optional[MessageRecv] = None
|
# 当前消息的元数据:(队列类型, 优先级, 计数器, 消息对象)
|
||||||
|
self._current_message_being_replied: Optional[Tuple[str, int, int, MessageRecv]] = None
|
||||||
|
|
||||||
self._is_replying = False
|
self._is_replying = False
|
||||||
|
|
||||||
self.gpt = S4UStreamGenerator()
|
self.gpt = S4UStreamGenerator()
|
||||||
# self.audio_generator = MockAudioGenerator()
|
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||||
|
|
||||||
logger.info(f"[{self.stream_name}] S4UChat")
|
def _is_vip(self, message: MessageRecv) -> bool:
|
||||||
|
"""检查消息是否来自VIP用户。"""
|
||||||
|
# 您需要修改此处或在配置文件中定义VIP用户
|
||||||
|
vip_user_ids = ["1026294844"]
|
||||||
|
vip_user_ids = [""]
|
||||||
|
return message.message_info.user_info.user_id in vip_user_ids
|
||||||
|
|
||||||
|
def _get_message_priority(self, message: MessageRecv) -> int:
|
||||||
|
"""为消息分配优先级。数字越小,优先级越高。"""
|
||||||
|
if f"@{global_config.bot.nickname}" in message.processed_plain_text or any(
|
||||||
|
f"@{alias}" in message.processed_plain_text for alias in global_config.bot.alias_names
|
||||||
|
):
|
||||||
|
return 0
|
||||||
|
return 1
|
||||||
|
|
||||||
|
async def add_message(self, message: MessageRecv) -> None:
|
||||||
|
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||||
|
is_vip = self._is_vip(message)
|
||||||
|
new_priority = self._get_message_priority(message)
|
||||||
|
|
||||||
# 改为实例方法, 移除 chat 参数
|
|
||||||
async def response(self, message: MessageRecv, is_mentioned: bool, interested_rate: float) -> None:
|
|
||||||
"""将消息放入队列并根据发信人决定是否中断当前处理。"""
|
|
||||||
should_interrupt = False
|
should_interrupt = False
|
||||||
if self._current_generation_task and not self._current_generation_task.done():
|
if self._current_generation_task and not self._current_generation_task.done():
|
||||||
if self._current_message_being_replied:
|
if self._current_message_being_replied:
|
||||||
# 检查新消息发送者和正在回复的消息发送者是否为同一人
|
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||||
new_sender_id = message.message_info.user_info.user_id
|
|
||||||
original_sender_id = self._current_message_being_replied.message_info.user_info.user_id
|
|
||||||
|
|
||||||
if new_sender_id == original_sender_id:
|
# 规则:VIP从不被打断
|
||||||
|
if current_queue == "vip":
|
||||||
|
pass # Do nothing
|
||||||
|
|
||||||
|
# 规则:普通消息可以被打断
|
||||||
|
elif current_queue == "normal":
|
||||||
|
# VIP消息可以打断普通消息
|
||||||
|
if is_vip:
|
||||||
should_interrupt = True
|
should_interrupt = True
|
||||||
logger.info(f"[{self.stream_name}] 来自同一用户的消息,中断当前回复。")
|
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
|
||||||
|
# 普通消息的内部打断逻辑
|
||||||
else:
|
else:
|
||||||
if random.random() < 0.2:
|
new_sender_id = message.message_info.user_info.user_id
|
||||||
|
current_sender_id = current_msg.message_info.user_info.user_id
|
||||||
|
# 新消息优先级更高
|
||||||
|
if new_priority < current_priority:
|
||||||
should_interrupt = True
|
should_interrupt = True
|
||||||
logger.info(f"[{self.stream_name}] 来自不同用户的消息,随机中断(20%)。")
|
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
|
||||||
else:
|
# 同用户,同级或更高级
|
||||||
logger.info(f"[{self.stream_name}] 来自不同用户的消息,不中断。")
|
elif new_sender_id == current_sender_id and new_priority <= current_priority:
|
||||||
else:
|
|
||||||
# Fallback: if we don't know who we are replying to, interrupt.
|
|
||||||
should_interrupt = True
|
should_interrupt = True
|
||||||
logger.warning(f"[{self.stream_name}] 正在生成回复,但无法获取原始消息发送者信息,将默认中断。")
|
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
||||||
|
|
||||||
if should_interrupt:
|
if should_interrupt:
|
||||||
|
if self.gpt.partial_response:
|
||||||
|
logger.warning(f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'")
|
||||||
self._current_generation_task.cancel()
|
self._current_generation_task.cancel()
|
||||||
logger.info(f"[{self.stream_name}] 请求中断当前回复生成任务。")
|
|
||||||
|
|
||||||
await self._message_queue.put(message)
|
# 将消息放入对应的队列
|
||||||
|
item = (new_priority, self._entry_counter, time.time(), message)
|
||||||
|
if is_vip:
|
||||||
|
await self._vip_queue.put(item)
|
||||||
|
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||||
|
else:
|
||||||
|
await self._normal_queue.put(item)
|
||||||
|
|
||||||
|
self._entry_counter += 1
|
||||||
|
self._new_message_event.set() # 唤醒处理器
|
||||||
|
|
||||||
async def _message_processor(self):
|
async def _message_processor(self):
|
||||||
"""从队列中处理消息,支持中断。"""
|
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# 等待第一条消息
|
# 等待有新消息的信号,避免空转
|
||||||
message = await self._message_queue.get()
|
await self._new_message_event.wait()
|
||||||
self._current_message_being_replied = message
|
self._new_message_event.clear()
|
||||||
|
|
||||||
# 如果因快速中断导致队列中积压了更多消息,则只处理最新的一条
|
# 优先处理VIP队列
|
||||||
while not self._message_queue.empty():
|
if not self._vip_queue.empty():
|
||||||
drained_msg = self._message_queue.get_nowait()
|
priority, entry_count, _, message = self._vip_queue.get_nowait()
|
||||||
self._message_queue.task_done() # 为取出的旧消息调用 task_done
|
queue_name = "vip"
|
||||||
message = drained_msg # 始终处理最新消息
|
# 其次处理普通队列
|
||||||
self._current_message_being_replied = message
|
elif not self._normal_queue.empty():
|
||||||
logger.info(f"[{self.stream_name}] 丢弃过时消息,处理最新消息: {message.processed_plain_text}")
|
priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||||
|
# 检查普通消息是否超时
|
||||||
|
if time.time() - timestamp > self._MESSAGE_TIMEOUT_SECONDS:
|
||||||
|
logger.info(f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}...")
|
||||||
|
self._normal_queue.task_done()
|
||||||
|
continue # 处理下一条
|
||||||
|
queue_name = "normal"
|
||||||
|
else:
|
||||||
|
continue # 没有消息了,回去等事件
|
||||||
|
|
||||||
|
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
||||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._current_generation_task
|
await self._current_generation_task
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"[{self.stream_name}] 回复生成被外部中断。")
|
logger.info(f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded.")
|
||||||
|
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
||||||
|
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.stream_name}] _generate_and_send 任务出现错误: {e}", exc_info=True)
|
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
|
||||||
finally:
|
finally:
|
||||||
self._current_generation_task = None
|
self._current_generation_task = None
|
||||||
self._current_message_being_replied = None
|
self._current_message_being_replied = None
|
||||||
|
# 标记任务完成
|
||||||
|
if queue_name == 'vip':
|
||||||
|
self._vip_queue.task_done()
|
||||||
|
else:
|
||||||
|
self._normal_queue.task_done()
|
||||||
|
|
||||||
|
# 检查是否还有任务,有则立即再次触发事件
|
||||||
|
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
||||||
|
self._new_message_event.set()
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info(f"[{self.stream_name}] 消息处理器正在关闭。")
|
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[{self.stream_name}] 消息处理器主循环发生未知错误: {e}", exc_info=True)
|
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||||
await asyncio.sleep(1) # 避免在未知错误下陷入CPU空转
|
await asyncio.sleep(1)
|
||||||
finally:
|
|
||||||
# 确保处理过的消息(无论是正常完成还是被丢弃)都被标记完成
|
|
||||||
if "message" in locals():
|
|
||||||
self._message_queue.task_done()
|
|
||||||
|
|
||||||
async def _generate_and_send(self, message: MessageRecv):
|
async def _generate_and_send(self, message: MessageRecv):
|
||||||
"""为单个消息生成文本和音频回复。整个过程可以被中断。"""
|
"""为单个消息生成文本和音频回复。整个过程可以被中断。"""
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from src.chat.message_receive.message import MessageRecv
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from .s4u_chat import get_s4u_chat_manager
|
from .s4u_chat import get_s4u_chat_manager
|
||||||
|
|
||||||
@@ -47,13 +46,12 @@ class S4UMessageProcessor:
|
|||||||
|
|
||||||
await self.storage.store_message(message, chat)
|
await self.storage.store_message(message, chat)
|
||||||
|
|
||||||
is_mentioned = is_mentioned_bot_in_message(message)
|
|
||||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||||
|
|
||||||
if userinfo.user_id in target_user_id_list:
|
if userinfo.user_id in target_user_id_list:
|
||||||
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=1.0)
|
await s4u_chat.add_message(message)
|
||||||
else:
|
else:
|
||||||
await s4u_chat.response(message, is_mentioned=is_mentioned, interested_rate=0.0)
|
await s4u_chat.add_message(message)
|
||||||
|
|
||||||
# 7. 日志记录
|
# 7. 日志记录
|
||||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||||
|
|||||||
@@ -7,7 +7,11 @@ import time
|
|||||||
from src.chat.utils.utils import get_recent_group_speaker
|
from src.chat.utils.utils import get_recent_group_speaker
|
||||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||||
import random
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
import asyncio
|
||||||
|
import ast
|
||||||
|
|
||||||
|
from src.person_info.person_info import get_person_info_manager
|
||||||
from src.person_info.relationship_manager import get_relationship_manager
|
from src.person_info.relationship_manager import get_relationship_manager
|
||||||
|
|
||||||
logger = get_logger("prompt")
|
logger = get_logger("prompt")
|
||||||
@@ -20,15 +24,20 @@ def init_prompt():
|
|||||||
Prompt("和{sender_name}私聊", "chat_target_private2")
|
Prompt("和{sender_name}私聊", "chat_target_private2")
|
||||||
|
|
||||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||||
|
Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
|
||||||
|
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""{identity_block}
|
||||||
你的名字叫{bot_name},昵称是:{bot_other_names},{prompt_personality}。
|
|
||||||
|
{relation_info_block}
|
||||||
|
{memory_block}
|
||||||
|
|
||||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||||
|
|
||||||
{background_dialogue_prompt}
|
{background_dialogue_prompt}
|
||||||
--------------------------------
|
--------------------------------
|
||||||
{now_time}
|
{time_block}
|
||||||
这是你和{sender_name}的对话,你们正在交流中:
|
这是你和{sender_name}的对话,你们正在交流中:
|
||||||
{core_dialogue_prompt}
|
{core_dialogue_prompt}
|
||||||
|
|
||||||
@@ -37,7 +46,6 @@ def init_prompt():
|
|||||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||||
你的发言:
|
你的发言:
|
||||||
|
|
||||||
""",
|
""",
|
||||||
"s4u_prompt", # New template for private CHAT chat
|
"s4u_prompt", # New template for private CHAT chat
|
||||||
)
|
)
|
||||||
@@ -48,22 +56,41 @@ class PromptBuilder:
|
|||||||
self.prompt_built = ""
|
self.prompt_built = ""
|
||||||
self.activate_messages = ""
|
self.activate_messages = ""
|
||||||
|
|
||||||
async def build_prompt_normal(
|
async def build_identity_block(self) -> str:
|
||||||
self,
|
person_info_manager = get_person_info_manager()
|
||||||
message,
|
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||||
chat_stream,
|
bot_name = global_config.bot.nickname
|
||||||
message_txt: str,
|
if global_config.bot.alias_names:
|
||||||
sender_name: str = "某人",
|
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||||
) -> str:
|
else:
|
||||||
prompt_personality = get_individuality().get_prompt(x_person=2, level=2)
|
bot_nickname = ""
|
||||||
is_group_chat = bool(chat_stream.group_info)
|
short_impression = await person_info_manager.get_value(bot_person_id, "short_impression")
|
||||||
|
try:
|
||||||
|
if isinstance(short_impression, str) and short_impression.strip():
|
||||||
|
short_impression = ast.literal_eval(short_impression)
|
||||||
|
elif not short_impression:
|
||||||
|
logger.warning("short_impression为空,使用默认值")
|
||||||
|
short_impression = ["友好活泼", "人类"]
|
||||||
|
except (ValueError, SyntaxError) as e:
|
||||||
|
logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}")
|
||||||
|
short_impression = ["友好活泼", "人类"]
|
||||||
|
|
||||||
|
if not isinstance(short_impression, list) or len(short_impression) < 2:
|
||||||
|
logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值")
|
||||||
|
short_impression = ["友好活泼", "人类"]
|
||||||
|
personality = short_impression[0]
|
||||||
|
identity = short_impression[1]
|
||||||
|
prompt_personality = personality + "," + identity
|
||||||
|
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||||
|
|
||||||
|
async def build_relation_info(self, chat_stream) -> str:
|
||||||
|
is_group_chat = bool(chat_stream.group_info)
|
||||||
who_chat_in_group = []
|
who_chat_in_group = []
|
||||||
if is_group_chat:
|
if is_group_chat:
|
||||||
who_chat_in_group = get_recent_group_speaker(
|
who_chat_in_group = get_recent_group_speaker(
|
||||||
chat_stream.stream_id,
|
chat_stream.stream_id,
|
||||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
|
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
|
||||||
limit=global_config.normal_chat.max_context_size,
|
limit=global_config.chat.max_context_size,
|
||||||
)
|
)
|
||||||
elif chat_stream.user_info:
|
elif chat_stream.user_info:
|
||||||
who_chat_in_group.append(
|
who_chat_in_group.append(
|
||||||
@@ -71,24 +98,29 @@ class PromptBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
relation_prompt = ""
|
relation_prompt = ""
|
||||||
if global_config.relationship.enable_relationship:
|
if global_config.relationship.enable_relationship and who_chat_in_group:
|
||||||
for person in who_chat_in_group:
|
|
||||||
relationship_manager = get_relationship_manager()
|
relationship_manager = get_relationship_manager()
|
||||||
relation_prompt += await relationship_manager.build_relationship_info(person)
|
relation_info_list = await asyncio.gather(
|
||||||
|
*[relationship_manager.build_relationship_info(person) for person in who_chat_in_group]
|
||||||
|
)
|
||||||
|
relation_info = "".join(relation_info_list)
|
||||||
|
if relation_info:
|
||||||
|
relation_prompt = await global_prompt_manager.format_prompt("relation_prompt", relation_info=relation_info)
|
||||||
|
return relation_prompt
|
||||||
|
|
||||||
memory_prompt = ""
|
async def build_memory_block(self, text: str) -> str:
|
||||||
related_memory = await hippocampus_manager.get_memory_from_text(
|
related_memory = await hippocampus_manager.get_memory_from_text(
|
||||||
text=message_txt, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
text=text, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
)
|
)
|
||||||
|
|
||||||
related_memory_info = ""
|
related_memory_info = ""
|
||||||
if related_memory:
|
if related_memory:
|
||||||
for memory in related_memory:
|
for memory in related_memory:
|
||||||
related_memory_info += memory[1]
|
related_memory_info += memory[1]
|
||||||
memory_prompt = await global_prompt_manager.format_prompt(
|
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
|
||||||
"memory_prompt", related_memory_info=related_memory_info
|
return ""
|
||||||
)
|
|
||||||
|
|
||||||
|
def build_chat_history_prompts(self, chat_stream, message) -> (str, str):
|
||||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=chat_stream.stream_id,
|
chat_id=chat_stream.stream_id,
|
||||||
timestamp=time.time(),
|
timestamp=time.time(),
|
||||||
@@ -96,9 +128,7 @@ class PromptBuilder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
||||||
print(f"talk_type: {talk_type}")
|
|
||||||
|
|
||||||
# 分别筛选核心对话和背景对话
|
|
||||||
core_dialogue_list = []
|
core_dialogue_list = []
|
||||||
background_dialogue_list = []
|
background_dialogue_list = []
|
||||||
bot_id = str(global_config.bot.qq_account)
|
bot_id = str(global_config.bot.qq_account)
|
||||||
@@ -106,11 +136,9 @@ class PromptBuilder:
|
|||||||
|
|
||||||
for msg_dict in message_list_before_now:
|
for msg_dict in message_list_before_now:
|
||||||
try:
|
try:
|
||||||
# 直接通过字典访问
|
|
||||||
msg_user_id = str(msg_dict.get("user_id"))
|
msg_user_id = str(msg_dict.get("user_id"))
|
||||||
if msg_user_id == bot_id:
|
if msg_user_id == bot_id:
|
||||||
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
||||||
print(f"reply: {msg_dict.get('reply_to')}")
|
|
||||||
core_dialogue_list.append(msg_dict)
|
core_dialogue_list.append(msg_dict)
|
||||||
else:
|
else:
|
||||||
background_dialogue_list.append(msg_dict)
|
background_dialogue_list.append(msg_dict)
|
||||||
@@ -121,19 +149,19 @@ class PromptBuilder:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||||
|
|
||||||
|
background_dialogue_prompt = ""
|
||||||
if background_dialogue_list:
|
if background_dialogue_list:
|
||||||
latest_25_msgs = background_dialogue_list[-25:]
|
latest_25_msgs = background_dialogue_list[-25:]
|
||||||
background_dialogue_prompt = build_readable_messages(
|
background_dialogue_prompt_str = build_readable_messages(
|
||||||
latest_25_msgs,
|
latest_25_msgs,
|
||||||
merge_messages=True,
|
merge_messages=True,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
show_pic=False,
|
show_pic=False,
|
||||||
)
|
)
|
||||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt}"
|
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
||||||
else:
|
|
||||||
background_dialogue_prompt = ""
|
|
||||||
|
|
||||||
# 分别获取最新50条和最新25条(从message_list_before_now截取)
|
core_msg_str = ""
|
||||||
|
if core_dialogue_list:
|
||||||
core_dialogue_list = core_dialogue_list[-50:]
|
core_dialogue_list = core_dialogue_list[-50:]
|
||||||
|
|
||||||
first_msg = core_dialogue_list[0]
|
first_msg = core_dialogue_list[0]
|
||||||
@@ -152,12 +180,8 @@ class PromptBuilder:
|
|||||||
for msg in core_dialogue_list[1:]:
|
for msg in core_dialogue_list[1:]:
|
||||||
speaker = msg.get("user_id")
|
speaker = msg.get("user_id")
|
||||||
if speaker == last_speaking_user_id:
|
if speaker == last_speaking_user_id:
|
||||||
# 还是同一个人讲话
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||||
msg_seg_str += (
|
|
||||||
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# 换人了
|
|
||||||
msg_seg_str = f"{msg_seg_str}\n"
|
msg_seg_str = f"{msg_seg_str}\n"
|
||||||
all_msg_seg_list.append(msg_seg_str)
|
all_msg_seg_list.append(msg_seg_str)
|
||||||
|
|
||||||
@@ -166,36 +190,46 @@ class PromptBuilder:
|
|||||||
else:
|
else:
|
||||||
msg_seg_str = "对方的发言:\n"
|
msg_seg_str = "对方的发言:\n"
|
||||||
|
|
||||||
msg_seg_str += (
|
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||||
f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
|
||||||
)
|
|
||||||
last_speaking_user_id = speaker
|
last_speaking_user_id = speaker
|
||||||
|
|
||||||
all_msg_seg_list.append(msg_seg_str)
|
all_msg_seg_list.append(msg_seg_str)
|
||||||
|
|
||||||
core_msg_str = ""
|
|
||||||
for msg in all_msg_seg_list:
|
for msg in all_msg_seg_list:
|
||||||
# print(f"msg: {msg}")
|
|
||||||
core_msg_str += msg
|
core_msg_str += msg
|
||||||
|
|
||||||
now_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
return core_msg_str, background_dialogue_prompt
|
||||||
now_time = f"现在的时间是:{now_time}"
|
|
||||||
|
|
||||||
|
async def build_prompt_normal(
|
||||||
|
self,
|
||||||
|
message,
|
||||||
|
chat_stream,
|
||||||
|
message_txt: str,
|
||||||
|
sender_name: str = "某人",
|
||||||
|
) -> str:
|
||||||
|
|
||||||
|
identity_block, relation_info_block, memory_block = await asyncio.gather(
|
||||||
|
self.build_identity_block(),
|
||||||
|
self.build_relation_info(chat_stream),
|
||||||
|
self.build_memory_block(message_txt)
|
||||||
|
)
|
||||||
|
|
||||||
|
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
|
||||||
|
|
||||||
|
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
|
|
||||||
template_name = "s4u_prompt"
|
template_name = "s4u_prompt"
|
||||||
effective_sender_name = sender_name
|
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
template_name,
|
template_name,
|
||||||
relation_prompt=relation_prompt,
|
identity_block=identity_block,
|
||||||
sender_name=effective_sender_name,
|
time_block=time_block,
|
||||||
memory_prompt=memory_prompt,
|
relation_info_block=relation_info_block,
|
||||||
core_dialogue_prompt=core_msg_str,
|
memory_block=memory_block,
|
||||||
|
sender_name=sender_name,
|
||||||
|
core_dialogue_prompt=core_dialogue_prompt,
|
||||||
background_dialogue_prompt=background_dialogue_prompt,
|
background_dialogue_prompt=background_dialogue_prompt,
|
||||||
message_txt=message_txt,
|
message_txt=message_txt,
|
||||||
bot_name=global_config.bot.nickname,
|
|
||||||
bot_other_names="/".join(global_config.bot.alias_names),
|
|
||||||
prompt_personality=prompt_personality,
|
|
||||||
now_time=now_time,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class S4UStreamGenerator:
|
|||||||
|
|
||||||
self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
|
self.model_sum = LLMRequest(model=global_config.model.memory_summary, temperature=0.7, request_type="relation")
|
||||||
self.current_model_name = "unknown model"
|
self.current_model_name = "unknown model"
|
||||||
|
self.partial_response = ""
|
||||||
|
|
||||||
# 正则表达式用于按句子切分,同时处理各种标点和边缘情况
|
# 正则表达式用于按句子切分,同时处理各种标点和边缘情况
|
||||||
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
||||||
@@ -52,6 +53,7 @@ class S4UStreamGenerator:
|
|||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""根据当前模型类型选择对应的生成函数"""
|
"""根据当前模型类型选择对应的生成函数"""
|
||||||
# 从global_config中获取模型概率值并选择模型
|
# 从global_config中获取模型概率值并选择模型
|
||||||
|
self.partial_response = ""
|
||||||
current_client = self.client_1
|
current_client = self.client_1
|
||||||
self.current_model_name = self.model_1_name
|
self.current_model_name = self.model_1_name
|
||||||
|
|
||||||
@@ -134,6 +136,7 @@ class S4UStreamGenerator:
|
|||||||
if to_yield.endswith((",", ",")):
|
if to_yield.endswith((",", ",")):
|
||||||
to_yield = to_yield.rstrip(",,")
|
to_yield = to_yield.rstrip(",,")
|
||||||
|
|
||||||
|
self.partial_response += to_yield
|
||||||
yield to_yield
|
yield to_yield
|
||||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||||
await asyncio.sleep(0) # 允许其他任务运行
|
await asyncio.sleep(0) # 允许其他任务运行
|
||||||
@@ -150,4 +153,5 @@ class S4UStreamGenerator:
|
|||||||
if to_yield.endswith((",", ",")):
|
if to_yield.endswith((",", ",")):
|
||||||
to_yield = to_yield.rstrip(",,")
|
to_yield = to_yield.rstrip(",,")
|
||||||
if to_yield:
|
if to_yield:
|
||||||
|
self.partial_response += to_yield
|
||||||
yield to_yield
|
yield to_yield
|
||||||
|
|||||||
Reference in New Issue
Block a user