移除mai4u:s4u_watching_manager.py, screen_manager.py, super_chat_manager.py, yes_or_no.py, openai_client.py, and s4u_config.py. These changes streamline the codebase by eliminating unused components and improving maintainability.
This commit is contained in:
@@ -9,13 +9,12 @@ from maim_message import UserInfo
|
||||
from src.chat.antipromptinjector import initialize_anti_injector
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
@@ -73,9 +72,6 @@ class ChatBot:
|
||||
self.bot = None # bot 实例引用
|
||||
self._started = False
|
||||
self.mood_manager = mood_manager # 获取情绪管理器单例
|
||||
# 亲和力流消息处理器 - 直接使用全局afc_manager
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
# 初始化反注入系统
|
||||
self._initialize_anti_injector()
|
||||
@@ -364,27 +360,6 @@ class ChatBot:
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
async def do_s4u(self, message_data: dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容
|
||||
await message.process()
|
||||
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
|
||||
return
|
||||
|
||||
async def message_process(self, message_data: dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息"""
|
||||
try:
|
||||
@@ -419,10 +394,6 @@ class ChatBot:
|
||||
|
||||
platform = message_info.get("platform")
|
||||
|
||||
if platform == "amaidesu_default":
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_info.get("group_info") is not None:
|
||||
message_info["group_info"]["group_id"] = str(
|
||||
message_info["group_info"]["group_id"]
|
||||
|
||||
@@ -309,206 +309,6 @@ class MessageRecv(Message):
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
super().__init__(message_dict)
|
||||
self.is_gift = False
|
||||
self.is_fake_gift = False
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: int | None = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
self.is_screen = False
|
||||
self.is_internal = False
|
||||
self.voice_done = None
|
||||
|
||||
self.chat_info = None
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
self.is_voice = False
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "voice":
|
||||
self.has_picid = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_voice = True
|
||||
|
||||
# 检查消息是否由机器人自己发送
|
||||
# 检查消息是否由机器人自己发送
|
||||
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):
|
||||
logger.info(f"检测到机器人自身发送的语音消息 (User ID: {self.message_info.user_info.user_id}),尝试从缓存获取文本。")
|
||||
if isinstance(segment.data, str):
|
||||
cached_text = consume_self_voice_text(segment.data)
|
||||
if cached_text:
|
||||
logger.info(f"成功从缓存中获取语音文本: '{cached_text[:70]}...'")
|
||||
return f"[语音:{cached_text}]"
|
||||
else:
|
||||
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
|
||||
|
||||
# 标准语音识别流程 (也作为缓存未命中的后备方案)
|
||||
if isinstance(segment.data, str):
|
||||
return await get_voice_text(segment.data)
|
||||
return "[发了一段语音,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_voice = False
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1) # type: ignore
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
return ""
|
||||
elif segment.type == "voice_done":
|
||||
msg_id = segment.data
|
||||
logger.info(f"voice_done: {msg_id}")
|
||||
self.voice_done = msg_id
|
||||
return ""
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price, message_text = segment.data.split(":", 1) # type: ignore
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += (
|
||||
f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
)
|
||||
|
||||
return self.processed_plain_text
|
||||
elif segment.type == "screen":
|
||||
self.is_screen = True
|
||||
self.screen_info = segment.data
|
||||
return "屏幕信息"
|
||||
elif segment.type == "file":
|
||||
if isinstance(segment.data, dict):
|
||||
file_name = segment.data.get('name', '未知文件')
|
||||
file_size = segment.data.get('size', '未知大小')
|
||||
return f"[文件:{file_name} ({file_size}字节)]"
|
||||
return "[收到一个文件]"
|
||||
elif segment.type == "video":
|
||||
self.is_voice = False
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
|
||||
logger.info(f"接收到视频消息,数据类型: {type(segment.data)}")
|
||||
|
||||
# 检查视频分析功能是否可用
|
||||
if not is_video_analysis_available():
|
||||
logger.warning("⚠️ Rust视频处理模块不可用,跳过视频分析")
|
||||
return "[视频]"
|
||||
|
||||
if global_config.video_analysis.enable:
|
||||
logger.info("已启用视频识别,开始识别")
|
||||
if isinstance(segment.data, dict):
|
||||
try:
|
||||
# 从Adapter接收的视频数据
|
||||
video_base64 = segment.data.get("base64")
|
||||
filename = segment.data.get("filename", "video.mp4")
|
||||
|
||||
logger.info(f"视频文件名: {filename}")
|
||||
logger.info(f"Base64数据长度: {len(video_base64) if video_base64 else 0}")
|
||||
|
||||
if video_base64:
|
||||
# 解码base64视频数据
|
||||
video_bytes = base64.b64decode(video_base64)
|
||||
logger.info(f"解码后视频大小: {len(video_bytes)} 字节")
|
||||
|
||||
# 使用video analyzer分析视频
|
||||
video_analyzer = get_video_analyzer()
|
||||
result = await video_analyzer.analyze_video_from_bytes(
|
||||
video_bytes, filename
|
||||
)
|
||||
|
||||
logger.info(f"视频分析结果: {result}")
|
||||
|
||||
# 返回视频分析结果
|
||||
summary = result.get("summary", "")
|
||||
if summary:
|
||||
return f"[视频内容] {summary}"
|
||||
else:
|
||||
return "[已收到视频,但分析失败]"
|
||||
else:
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
return "[收到视频,但处理时出现错误]"
|
||||
else:
|
||||
logger.warning(f"视频消息数据不是字典格式: {type(segment.data)}")
|
||||
return "[发了一个视频,但格式不支持]"
|
||||
else:
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"未知的消息段类型: {segment.type}")
|
||||
return f"[{segment.type} 消息]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageProcessBase(Message):
|
||||
"""消息处理基类,用于处理中和发送中的消息"""
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
[inner]
|
||||
version = "1.0.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
@@ -1,132 +0,0 @@
|
||||
[inner]
|
||||
version = "1.1.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 80 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 8 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
enable_streaming_output = false # 是否启用流式输出,false时全部生成后一次性发送
|
||||
|
||||
max_context_message_length = 30
|
||||
max_core_message_length = 20
|
||||
|
||||
# 模型配置
|
||||
[models]
|
||||
# 主要对话模型配置
|
||||
[models.chat]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 规划模型配置
|
||||
[models.motion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 情感分析模型配置
|
||||
[models.emotion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 记忆模型配置
|
||||
[models.memory]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 工具使用模型配置
|
||||
[models.tool_use]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 嵌入模型配置
|
||||
[models.embedding]
|
||||
name = "text-embedding-v1"
|
||||
provider = "OPENAI"
|
||||
dimension = 1024
|
||||
|
||||
# 视觉语言模型配置
|
||||
[models.vlm]
|
||||
name = "qwen-vl-plus"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 知识库模型配置
|
||||
[models.knowledge]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 实体提取模型配置
|
||||
[models.entity_extract]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 问答模型配置
|
||||
[models.qa]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
|
||||
# 兼容性配置(已废弃,请使用models.motion)
|
||||
[model_motion] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
|
||||
# 强烈建议使用免费的小模型
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false # 是否启用思考
|
||||
@@ -1,67 +0,0 @@
|
||||
[inner]
|
||||
version = "1.1.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
|
||||
enable_streaming_output = true # 是否启用流式输出,false时全部生成后一次性发送
|
||||
|
||||
max_context_message_length = 20
|
||||
max_core_message_length = 30
|
||||
|
||||
# 模型配置
|
||||
[models]
|
||||
# 主要对话模型配置
|
||||
[models.chat]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 规划模型配置
|
||||
[models.motion]
|
||||
name = "qwen3-32b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
enable_thinking = false
|
||||
|
||||
# 情感分析模型配置
|
||||
[models.emotion]
|
||||
name = "qwen3-8b"
|
||||
provider = "BAILIAN"
|
||||
pri_in = 0.5
|
||||
pri_out = 2
|
||||
temp = 0.7
|
||||
@@ -1 +0,0 @@
|
||||
ENABLE_S4U = False
|
||||
@@ -1,178 +0,0 @@
|
||||
import time
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
你之前的内心想法是:{mind}
|
||||
|
||||
{memory_block}
|
||||
{relation_info_block}
|
||||
|
||||
{chat_target}
|
||||
{time_block}
|
||||
{chat_info}
|
||||
{identity}
|
||||
|
||||
你刚刚在{chat_target_2},你你刚刚的心情是:{mood_state}
|
||||
---------------------
|
||||
在这样的情况下,你对上面的内容,你对 {sender} 发送的 消息 “{target}” 进行了回复
|
||||
你刚刚选择回复的内容是:{reponse}
|
||||
现在,根据你之前的想法和回复的内容,推测你现在的想法,思考你现在的想法是什么,为什么做出上面的回复内容
|
||||
请不要浮夸和夸张修辞,不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出想法:""",
|
||||
"after_response_think_prompt",
|
||||
)
|
||||
|
||||
|
||||
class MaiThinking:
|
||||
def __init__(self, chat_id):
|
||||
self.chat_id = chat_id
|
||||
# 这些将在异步初始化中设置
|
||||
self.chat_stream = None # type: ignore
|
||||
self.platform = None
|
||||
self.is_group = False
|
||||
self._initialized = False
|
||||
|
||||
self.s4u_message_processor = S4UMessageProcessor()
|
||||
|
||||
self.mind = ""
|
||||
|
||||
self.memory_block = ""
|
||||
self.relation_info_block = ""
|
||||
self.time_block = ""
|
||||
self.chat_target = ""
|
||||
self.chat_target_2 = ""
|
||||
self.chat_info = ""
|
||||
self.mood_state = ""
|
||||
self.identity = ""
|
||||
self.sender = ""
|
||||
self.target = ""
|
||||
|
||||
self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type="thinking")
|
||||
|
||||
async def _initialize(self):
|
||||
"""异步初始化方法"""
|
||||
if not self._initialized:
|
||||
self.chat_stream = await get_chat_manager().get_stream(self.chat_id)
|
||||
if self.chat_stream:
|
||||
self.platform = self.chat_stream.platform
|
||||
self.is_group = bool(self.chat_stream.group_info)
|
||||
self._initialized = True
|
||||
|
||||
async def do_think_before_response(self):
|
||||
pass
|
||||
|
||||
async def do_think_after_response(self, reponse: str):
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"after_response_think_prompt",
|
||||
mind=self.mind,
|
||||
reponse=reponse,
|
||||
memory_block=self.memory_block,
|
||||
relation_info_block=self.relation_info_block,
|
||||
time_block=self.time_block,
|
||||
chat_target=self.chat_target,
|
||||
chat_target_2=self.chat_target_2,
|
||||
chat_info=self.chat_info,
|
||||
mood_state=self.mood_state,
|
||||
identity=self.identity,
|
||||
sender=self.sender,
|
||||
target=self.target,
|
||||
)
|
||||
|
||||
result, _ = await self.thinking_model.generate_response_async(prompt)
|
||||
self.mind = result
|
||||
|
||||
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
|
||||
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
|
||||
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
|
||||
|
||||
msg_recv = await self.build_internal_message_recv(self.mind)
|
||||
await self.s4u_message_processor.process_message(msg_recv)
|
||||
internal_manager.set_internal_state(self.mind)
|
||||
|
||||
async def do_think_when_receive_message(self):
|
||||
pass
|
||||
|
||||
async def build_internal_message_recv(self, message_text: str):
|
||||
# 初始化
|
||||
await self._initialize()
|
||||
|
||||
msg_id = f"internal_{time.time()}"
|
||||
|
||||
message_dict = {
|
||||
"message_info": {
|
||||
"message_id": msg_id,
|
||||
"time": time.time(),
|
||||
"user_info": {
|
||||
"user_id": "internal", # 内部用户ID
|
||||
"user_nickname": "内心", # 内部昵称
|
||||
"platform": self.platform, # 平台标记为 internal
|
||||
# 其他 user_info 字段按需补充
|
||||
},
|
||||
"platform": self.platform, # 平台
|
||||
# 其他 message_info 字段按需补充
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "text", # 消息类型
|
||||
"data": message_text, # 消息内容
|
||||
# 其他 segment 字段按需补充
|
||||
},
|
||||
"raw_message": message_text, # 原始消息内容
|
||||
"processed_plain_text": message_text, # 处理后的纯文本
|
||||
# 下面这些字段可选,根据 MessageRecv 需要
|
||||
"is_emoji": False,
|
||||
"has_emoji": False,
|
||||
"is_picid": False,
|
||||
"has_picid": False,
|
||||
"is_voice": False,
|
||||
"is_mentioned": False,
|
||||
"is_command": False,
|
||||
"is_internal": True,
|
||||
"priority_mode": "interest",
|
||||
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
|
||||
"interest_value": 1.0,
|
||||
}
|
||||
|
||||
if self.is_group:
|
||||
message_dict["message_info"]["group_info"] = {
|
||||
"platform": self.platform,
|
||||
"group_id": self.chat_stream.group_info.group_id,
|
||||
"group_name": self.chat_stream.group_info.group_name,
|
||||
}
|
||||
|
||||
msg_recv = MessageRecvS4U(message_dict)
|
||||
msg_recv.chat_info = self.chat_info
|
||||
msg_recv.chat_stream = self.chat_stream
|
||||
msg_recv.is_internal = True
|
||||
|
||||
return msg_recv
|
||||
|
||||
|
||||
class MaiThinkingManager:
|
||||
def __init__(self):
|
||||
self.mai_think_list = []
|
||||
|
||||
def get_mai_think(self, chat_id):
|
||||
for mai_think in self.mai_think_list:
|
||||
if mai_think.chat_id == chat_id:
|
||||
return mai_think
|
||||
mai_think = MaiThinking(chat_id)
|
||||
self.mai_think_list.append(mai_think)
|
||||
return mai_think
|
||||
|
||||
|
||||
mai_thinking_manager = MaiThinkingManager()
|
||||
|
||||
|
||||
init_prompt()
|
||||
@@ -1,306 +0,0 @@
|
||||
import time
|
||||
|
||||
import orjson
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger("action")
|
||||
|
||||
HEAD_CODE = {
|
||||
"看向上方": "(0,0.5,0)",
|
||||
"看向下方": "(0,-0.5,0)",
|
||||
"看向左边": "(-1,0,0)",
|
||||
"看向右边": "(1,0,0)",
|
||||
"随意朝向": "random",
|
||||
"看向摄像机": "camera",
|
||||
"注视对方": "(0,0,0)",
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
BODY_CODE = {
|
||||
"双手背后向前弯腰": "010_0070",
|
||||
"歪头双手合十": "010_0100",
|
||||
"标准文静站立": "010_0101",
|
||||
"双手交叠腹部站立": "010_0150",
|
||||
"帅气的姿势": "010_0190",
|
||||
"另一个帅气的姿势": "010_0191",
|
||||
"手掌朝前可爱": "010_0210",
|
||||
"平静,双手后放": "平静,双手后放",
|
||||
"思考": "思考",
|
||||
"优雅,左手放在腰上": "优雅,左手放在腰上",
|
||||
"一般": "一般",
|
||||
"可爱,双手前放": "可爱,双手前放",
|
||||
}
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里正在进行的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你现在的动作状态是:
|
||||
- 身体动作:{body_action}
|
||||
|
||||
现在,因为你发送了消息,或者群里其他人发送了消息,引起了你的注意,你对其进行了阅读和思考,请你更新你的动作状态。
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"change_action_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是群里最近的聊天记录
|
||||
|
||||
{indentify_block}
|
||||
你之前的动作状态是
|
||||
- 身体动作:{body_action}
|
||||
|
||||
身体动作可选:
|
||||
{all_actions}
|
||||
|
||||
距离你上次关注群里消息已经过去了一段时间,你冷静了下来,你的动作会趋于平缓或静止,请你输出你现在新的动作状态,用中文。
|
||||
请只按照以下json格式输出,描述你新的动作状态,确保每个字段都存在:
|
||||
{{
|
||||
"body_action": "..."
|
||||
}}
|
||||
""",
|
||||
"regress_action_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatAction:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.body_action: str = "一般"
|
||||
self.head_action: str = "注视摄像机"
|
||||
|
||||
self.regression_count: int = 0
|
||||
# 新增:body_action冷却池,key为动作名,value为剩余冷却次数
|
||||
self.body_action_cooldown: dict[str, int] = {}
|
||||
|
||||
print(s4u_config.models.motion)
|
||||
print(model_config.model_task_config.emotion)
|
||||
|
||||
self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
async def send_action_update(self):
|
||||
"""发送动作更新到前端"""
|
||||
|
||||
body_code = BODY_CODE.get(self.body_action, "")
|
||||
await send_api.custom_to_stream(
|
||||
message_type="body_action",
|
||||
content=body_code,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
async def update_action_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=15,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := orjson.loads(repair_json(response)):
|
||||
# 记录原动作,切换后进入冷却
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 3
|
||||
self.body_action = new_body_action
|
||||
self.head_action = action_data.get("head_action", self.head_action)
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"update_action_by_message error: {e}")
|
||||
|
||||
async def regress_action(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
try:
|
||||
# 冷却池处理:过滤掉冷却中的动作
|
||||
self._update_body_action_cooldown()
|
||||
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
|
||||
all_actions = "\n".join(available_actions)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_action_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
body_action=self.body_action,
|
||||
all_actions=all_actions,
|
||||
)
|
||||
|
||||
logger.info(f"prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"response: {response}")
|
||||
logger.info(f"reasoning_content: {reasoning_content}")
|
||||
|
||||
if action_data := orjson.loads(repair_json(response)):
|
||||
prev_body_action = self.body_action
|
||||
new_body_action = action_data.get("body_action", self.body_action)
|
||||
if new_body_action != prev_body_action and prev_body_action:
|
||||
self.body_action_cooldown[prev_body_action] = 6
|
||||
self.body_action = new_body_action
|
||||
# 发送动作更新
|
||||
await self.send_action_update()
|
||||
|
||||
self.regression_count += 1
|
||||
self.last_change_time = message_time
|
||||
except Exception as e:
|
||||
logger.error(f"regress_action error: {e}")
|
||||
|
||||
# 新增:冷却池维护方法
|
||||
def _update_body_action_cooldown(self):
|
||||
remove_keys = []
|
||||
for k in self.body_action_cooldown:
|
||||
self.body_action_cooldown[k] -= 1
|
||||
if self.body_action_cooldown[k] <= 0:
|
||||
remove_keys.append(k)
|
||||
for k in remove_keys:
|
||||
del self.body_action_cooldown[k]
|
||||
|
||||
|
||||
class ActionRegressionTask(AsyncTask):
|
||||
def __init__(self, action_manager: "ActionManager"):
|
||||
super().__init__(task_name="ActionRegressionTask", run_interval=3)
|
||||
self.action_manager = action_manager
|
||||
|
||||
async def run(self):
|
||||
logger.debug("Running action regression task...")
|
||||
now = time.time()
|
||||
for action_state in self.action_manager.action_state_list:
|
||||
if action_state.last_change_time == 0:
|
||||
continue
|
||||
|
||||
if now - action_state.last_change_time > 10:
|
||||
if action_state.regression_count >= 3:
|
||||
continue
|
||||
|
||||
logger.info(f"chat {action_state.chat_id} 开始动作回归, 这是第 {action_state.regression_count + 1} 次")
|
||||
await action_state.regress_action()
|
||||
|
||||
|
||||
class ActionManager:
|
||||
def __init__(self):
|
||||
self.action_state_list: list[ChatAction] = []
|
||||
"""当前动作状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动动作回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动动作回归任务...")
|
||||
task = ActionRegressionTask(self)
|
||||
await async_task_manager.add_task(task)
|
||||
self.task_started = True
|
||||
logger.info("动作回归任务已启动")
|
||||
|
||||
def get_action_state_by_chat_id(self, chat_id: str) -> ChatAction:
|
||||
for action_state in self.action_state_list:
|
||||
if action_state.chat_id == chat_id:
|
||||
return action_state
|
||||
|
||||
new_action_state = ChatAction(chat_id)
|
||||
self.action_state_list.append(new_action_state)
|
||||
return new_action_state
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
action_manager = ActionManager()
|
||||
"""全局动作管理器"""
|
||||
@@ -1,692 +0,0 @@
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
|
||||
import aiohttp_cors
|
||||
import orjson
|
||||
from aiohttp import WSMsgType, web
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("context_web")
|
||||
|
||||
|
||||
class ContextMessage:
|
||||
"""上下文消息类"""
|
||||
|
||||
def __init__(self, message: MessageRecv):
|
||||
self.user_name = message.message_info.user_info.user_nickname
|
||||
self.user_id = message.message_info.user_info.user_id
|
||||
self.content = message.processed_plain_text
|
||||
self.timestamp = datetime.now()
|
||||
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
|
||||
|
||||
# 识别消息类型
|
||||
self.is_gift = getattr(message, "is_gift", False)
|
||||
self.is_superchat = getattr(message, "is_superchat", False)
|
||||
|
||||
# 添加礼物和SC相关信息
|
||||
if self.is_gift:
|
||||
self.gift_name = getattr(message, "gift_name", "")
|
||||
self.gift_count = getattr(message, "gift_count", "1")
|
||||
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
|
||||
elif self.is_superchat:
|
||||
self.superchat_price = getattr(message, "superchat_price", "0")
|
||||
self.superchat_message = getattr(message, "superchat_message_text", "")
|
||||
if self.superchat_message:
|
||||
self.content = f"[¥{self.superchat_price}] {self.superchat_message}"
|
||||
else:
|
||||
self.content = f"[¥{self.superchat_price}] {self.content}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"user_name": self.user_name,
|
||||
"user_id": self.user_id,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
|
||||
"group_name": self.group_name,
|
||||
"is_gift": self.is_gift,
|
||||
"is_superchat": self.is_superchat,
|
||||
}
|
||||
|
||||
|
||||
class ContextWebManager:
|
||||
"""上下文网页管理器"""
|
||||
|
||||
def __init__(self, max_messages: int = 10, port: int = 8765):
|
||||
self.max_messages = max_messages
|
||||
self.port = port
|
||||
self.contexts: dict[str, deque] = {} # chat_id -> deque of ContextMessage
|
||||
self.websockets: list[web.WebSocketResponse] = []
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False # 添加启动标志防止并发
|
||||
|
||||
async def start_server(self):
|
||||
"""启动web服务器"""
|
||||
if self.site is not None:
|
||||
logger.debug("Web服务器已经启动,跳过重复启动")
|
||||
return
|
||||
|
||||
if self._server_starting:
|
||||
logger.debug("Web服务器正在启动中,等待启动完成...")
|
||||
# 等待启动完成
|
||||
while self._server_starting and self.site is None:
|
||||
await asyncio.sleep(0.1)
|
||||
return
|
||||
|
||||
self._server_starting = True
|
||||
|
||||
try:
|
||||
self.app = web.Application()
|
||||
|
||||
# 设置CORS
|
||||
cors = aiohttp_cors.setup(
|
||||
self.app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True, expose_headers="*", allow_headers="*", allow_methods="*"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# 添加路由
|
||||
self.app.router.add_get("/", self.index_handler)
|
||||
self.app.router.add_get("/ws", self.websocket_handler)
|
||||
self.app.router.add_get("/api/contexts", self.get_contexts_handler)
|
||||
self.app.router.add_get("/debug", self.debug_handler)
|
||||
|
||||
# 为所有路由添加CORS
|
||||
for route in list(self.app.router.routes()):
|
||||
cors.add(route)
|
||||
|
||||
self.runner = web.AppRunner(self.app)
|
||||
await self.runner.setup()
|
||||
|
||||
self.site = web.TCPSite(self.runner, "localhost", self.port)
|
||||
await self.site.start()
|
||||
|
||||
logger.info(f"🌐 上下文网页服务器启动成功在 http://localhost:{self.port}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 启动Web服务器失败: {e}")
|
||||
# 清理部分启动的资源
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
raise
|
||||
finally:
|
||||
self._server_starting = False
|
||||
|
||||
async def stop_server(self):
|
||||
"""停止web服务器"""
|
||||
if self.site:
|
||||
await self.site.stop()
|
||||
if self.runner:
|
||||
await self.runner.cleanup()
|
||||
self.app = None
|
||||
self.runner = None
|
||||
self.site = None
|
||||
self._server_starting = False
|
||||
|
||||
async def index_handler(self, request):
|
||||
"""主页处理器"""
|
||||
html_content = (
|
||||
"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>聊天上下文</title>
|
||||
<style>
|
||||
html, body {
|
||||
background: transparent !important;
|
||||
background-color: transparent !important;
|
||||
margin: 0;
|
||||
padding: 20px;
|
||||
font-family: 'Microsoft YaHei', Arial, sans-serif;
|
||||
color: #ffffff;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.8);
|
||||
}
|
||||
.container {
|
||||
max-width: 800px;
|
||||
margin: 0 auto;
|
||||
background: transparent !important;
|
||||
}
|
||||
.message {
|
||||
background: rgba(0, 0, 0, 0.3);
|
||||
margin: 10px 0;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
border-left: 4px solid #00ff88;
|
||||
backdrop-filter: blur(5px);
|
||||
animation: slideIn 0.3s ease-out;
|
||||
transform: translateY(0);
|
||||
transition: transform 0.5s ease, opacity 0.5s ease;
|
||||
}
|
||||
.message:hover {
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
transform: translateX(5px);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.message.gift {
|
||||
border-left: 4px solid #ff8800;
|
||||
background: rgba(255, 136, 0, 0.2);
|
||||
}
|
||||
.message.gift:hover {
|
||||
background: rgba(255, 136, 0, 0.3);
|
||||
}
|
||||
.message.gift .username {
|
||||
color: #ff8800;
|
||||
}
|
||||
.message.superchat {
|
||||
border-left: 4px solid #ff6b6b;
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
|
||||
background-size: 200% 200%;
|
||||
animation: rainbow 3s ease infinite;
|
||||
}
|
||||
.message.superchat:hover {
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
|
||||
background-size: 200% 200%;
|
||||
}
|
||||
.message.superchat .username {
|
||||
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
|
||||
background-size: 300% 300%;
|
||||
animation: rainbow-text 2s ease infinite;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
background-clip: text;
|
||||
}
|
||||
@keyframes rainbow {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
@keyframes rainbow-text {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
.message-line {
|
||||
line-height: 1.4;
|
||||
word-wrap: break-word;
|
||||
font-size: 24px;
|
||||
}
|
||||
.username {
|
||||
color: #00ff88;
|
||||
}
|
||||
.content {
|
||||
color: #ffffff;
|
||||
}
|
||||
|
||||
.new-message {
|
||||
animation: slideInNew 0.6s ease-out;
|
||||
}
|
||||
|
||||
.debug-btn {
|
||||
position: fixed;
|
||||
bottom: 20px;
|
||||
right: 20px;
|
||||
background: rgba(0, 0, 0, 0.7);
|
||||
color: #00ff88;
|
||||
font-size: 12px;
|
||||
padding: 8px 12px;
|
||||
border-radius: 20px;
|
||||
backdrop-filter: blur(10px);
|
||||
z-index: 1000;
|
||||
text-decoration: none;
|
||||
border: 1px solid #00ff88;
|
||||
}
|
||||
.debug-btn:hover {
|
||||
background: rgba(0, 255, 136, 0.2);
|
||||
}
|
||||
@keyframes slideIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(-20px);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0);
|
||||
}
|
||||
}
|
||||
@keyframes slideInNew {
|
||||
from {
|
||||
opacity: 0;
|
||||
transform: translateY(50px) scale(0.95);
|
||||
}
|
||||
to {
|
||||
opacity: 1;
|
||||
transform: translateY(0) scale(1);
|
||||
}
|
||||
}
|
||||
.no-messages {
|
||||
text-align: center;
|
||||
color: #666;
|
||||
font-style: italic;
|
||||
margin-top: 50px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<a href="/debug" class="debug-btn">🔧 调试</a>
|
||||
<div id="messages">
|
||||
<div class="no-messages">暂无消息</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
let ws;
|
||||
let reconnectInterval;
|
||||
let currentMessages = []; // 存储当前显示的消息
|
||||
|
||||
function connectWebSocket() {
|
||||
console.log('正在连接WebSocket...');
|
||||
ws = new WebSocket('ws://localhost:"""
|
||||
+ str(self.port)
|
||||
+ """/ws');
|
||||
|
||||
ws.onopen = function() {
|
||||
console.log('WebSocket连接已建立');
|
||||
if (reconnectInterval) {
|
||||
clearInterval(reconnectInterval);
|
||||
reconnectInterval = null;
|
||||
}
|
||||
};
|
||||
|
||||
ws.onmessage = function(event) {
|
||||
console.log('收到WebSocket消息:', event.data);
|
||||
try {
|
||||
const data = orjson.parse(event.data);
|
||||
updateMessages(data.contexts);
|
||||
} catch (e) {
|
||||
console.error('解析消息失败:', e, event.data);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onclose = function(event) {
|
||||
console.log('WebSocket连接关闭:', event.code, event.reason);
|
||||
|
||||
if (!reconnectInterval) {
|
||||
reconnectInterval = setInterval(connectWebSocket, 3000);
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = function(error) {
|
||||
console.error('WebSocket错误:', error);
|
||||
};
|
||||
}
|
||||
|
||||
function updateMessages(contexts) {
|
||||
const messagesDiv = document.getElementById('messages');
|
||||
|
||||
if (!contexts || contexts.length === 0) {
|
||||
messagesDiv.innerHTML = '<div class="no-messages">暂无消息</div>';
|
||||
currentMessages = [];
|
||||
return;
|
||||
}
|
||||
|
||||
// 如果是第一次加载或者消息完全不同,进行完全重新渲染
|
||||
if (currentMessages.length === 0) {
|
||||
console.log('首次加载消息,数量:', contexts.length);
|
||||
messagesDiv.innerHTML = '';
|
||||
|
||||
contexts.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg);
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
});
|
||||
|
||||
currentMessages = [...contexts];
|
||||
window.scrollTo(0, document.body.scrollHeight);
|
||||
return;
|
||||
}
|
||||
|
||||
// 检测新消息 - 使用更可靠的方法
|
||||
const newMessages = findNewMessages(contexts, currentMessages);
|
||||
|
||||
if (newMessages.length > 0) {
|
||||
console.log('添加新消息,数量:', newMessages.length);
|
||||
|
||||
// 先检查是否需要移除老消息(保持DOM清洁)
|
||||
const maxDisplayMessages = 15; // 比服务器端稍多一些,确保流畅性
|
||||
const currentMessageElements = messagesDiv.querySelectorAll('.message');
|
||||
const willExceedLimit = currentMessageElements.length + newMessages.length > maxDisplayMessages;
|
||||
|
||||
if (willExceedLimit) {
|
||||
const removeCount = (currentMessageElements.length + newMessages.length) - maxDisplayMessages;
|
||||
console.log('需要移除老消息数量:', removeCount);
|
||||
|
||||
for (let i = 0; i < removeCount && i < currentMessageElements.length; i++) {
|
||||
const oldMessage = currentMessageElements[i];
|
||||
oldMessage.style.transition = 'opacity 0.3s ease, transform 0.3s ease';
|
||||
oldMessage.style.opacity = '0';
|
||||
oldMessage.style.transform = 'translateY(-20px)';
|
||||
|
||||
setTimeout(() => {
|
||||
if (oldMessage.parentNode) {
|
||||
oldMessage.parentNode.removeChild(oldMessage);
|
||||
}
|
||||
}, 300);
|
||||
}
|
||||
}
|
||||
|
||||
// 添加新消息
|
||||
newMessages.forEach(function(msg) {
|
||||
const messageDiv = createMessageElement(msg, true); // true表示是新消息
|
||||
messagesDiv.appendChild(messageDiv);
|
||||
|
||||
// 移除动画类,避免重复动画
|
||||
setTimeout(() => {
|
||||
messageDiv.classList.remove('new-message');
|
||||
}, 600);
|
||||
});
|
||||
|
||||
// 更新当前消息列表
|
||||
currentMessages = [...contexts];
|
||||
|
||||
// 平滑滚动到底部
|
||||
setTimeout(() => {
|
||||
window.scrollTo({
|
||||
top: document.body.scrollHeight,
|
||||
behavior: 'smooth'
|
||||
});
|
||||
}, 100);
|
||||
}
|
||||
}
|
||||
|
||||
function findNewMessages(contexts, currentMessages) {
|
||||
// 如果当前消息为空,所有消息都是新的
|
||||
if (currentMessages.length === 0) {
|
||||
return contexts;
|
||||
}
|
||||
|
||||
// 找到最后一条当前消息在新消息列表中的位置
|
||||
const lastCurrentMsg = currentMessages[currentMessages.length - 1];
|
||||
let lastIndex = -1;
|
||||
|
||||
// 从后往前找,因为新消息通常在末尾
|
||||
for (let i = contexts.length - 1; i >= 0; i--) {
|
||||
const msg = contexts[i];
|
||||
if (msg.user_id === lastCurrentMsg.user_id &&
|
||||
msg.content === lastCurrentMsg.content &&
|
||||
msg.timestamp === lastCurrentMsg.timestamp) {
|
||||
lastIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到了,返回之后的消息;否则返回所有消息(可能是完全刷新)
|
||||
if (lastIndex >= 0) {
|
||||
return contexts.slice(lastIndex + 1);
|
||||
} else {
|
||||
console.log('未找到匹配的最后消息,可能需要完全刷新');
|
||||
return contexts.slice(Math.max(0, contexts.length - (currentMessages.length + 1)));
|
||||
}
|
||||
}
|
||||
|
||||
function createMessageElement(msg, isNew = false) {
|
||||
const messageDiv = document.createElement('div');
|
||||
let className = 'message';
|
||||
|
||||
// 根据消息类型添加对应的CSS类
|
||||
if (msg.is_gift) {
|
||||
className += ' gift';
|
||||
} else if (msg.is_superchat) {
|
||||
className += ' superchat';
|
||||
}
|
||||
|
||||
if (isNew) {
|
||||
className += ' new-message';
|
||||
}
|
||||
|
||||
messageDiv.className = className;
|
||||
messageDiv.innerHTML = `
|
||||
<div class="message-line">
|
||||
<span class="username">${escapeHtml(msg.user_name)}:</span><span class="content">${escapeHtml(msg.content)}</span>
|
||||
</div>
|
||||
`;
|
||||
return messageDiv;
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
|
||||
// 初始加载数据
|
||||
fetch('/api/contexts')
|
||||
.then(response => response.json())
|
||||
.then(data => {
|
||||
console.log('初始数据加载成功:', data);
|
||||
updateMessages(data.contexts);
|
||||
})
|
||||
.catch(err => console.error('加载初始数据失败:', err));
|
||||
|
||||
// 连接WebSocket
|
||||
connectWebSocket();
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def websocket_handler(self, request):
|
||||
"""WebSocket处理器"""
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
self.websockets.append(ws)
|
||||
logger.debug(f"WebSocket连接建立,当前连接数: {len(self.websockets)}")
|
||||
|
||||
# 发送初始数据
|
||||
await self.send_contexts_to_websocket(ws)
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.ERROR:
|
||||
logger.error(f"WebSocket错误: {ws.exception()}")
|
||||
break
|
||||
|
||||
# 清理断开的连接
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
logger.debug(f"WebSocket连接断开,当前连接数: {len(self.websockets)}")
|
||||
|
||||
return ws
|
||||
|
||||
async def get_contexts_handler(self, request):
|
||||
"""获取上下文API"""
|
||||
all_context_msgs = []
|
||||
for contexts in self.contexts.values():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
logger.debug(f"返回上下文数据,共 {len(contexts_data)} 条消息")
|
||||
return web.json_response({"contexts": contexts_data})
|
||||
|
||||
async def debug_handler(self, request):
|
||||
"""调试信息处理器"""
|
||||
debug_info = {
|
||||
"server_status": "running",
|
||||
"websocket_connections": len(self.websockets),
|
||||
"total_chats": len(self.contexts),
|
||||
"total_messages": sum(len(contexts) for contexts in self.contexts.values()),
|
||||
}
|
||||
|
||||
# 构建聊天详情HTML
|
||||
chats_html = ""
|
||||
for chat_id, contexts in self.contexts.items():
|
||||
messages_html = ""
|
||||
for msg in contexts:
|
||||
timestamp = msg.timestamp.strftime("%H:%M:%S")
|
||||
content = msg.content[:50] + "..." if len(msg.content) > 50 else msg.content
|
||||
messages_html += f'<div class="message">[{timestamp}] {msg.user_name}: {content}</div>'
|
||||
|
||||
chats_html += f"""
|
||||
<div class="chat">
|
||||
<h3>聊天 {chat_id} ({len(contexts)} 条消息)</h3>
|
||||
{messages_html}
|
||||
</div>
|
||||
"""
|
||||
|
||||
html_content = f"""
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>调试信息</title>
|
||||
<style>
|
||||
body {{ font-family: monospace; margin: 20px; }}
|
||||
.section {{ margin: 20px 0; padding: 10px; border: 1px solid #ccc; }}
|
||||
.chat {{ margin: 10px 0; padding: 10px; background: #f5f5f5; }}
|
||||
.message {{ margin: 5px 0; padding: 5px; background: white; }}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1>上下文网页管理器调试信息</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>服务器状态</h2>
|
||||
<p>状态: {debug_info["server_status"]}</p>
|
||||
<p>WebSocket连接数: {debug_info["websocket_connections"]}</p>
|
||||
<p>聊天总数: {debug_info["total_chats"]}</p>
|
||||
<p>消息总数: {debug_info["total_messages"]}</p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>聊天详情</h2>
|
||||
{chats_html}
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>操作</h2>
|
||||
<button onclick="location.reload()">刷新页面</button>
|
||||
<button onclick="window.location.href='/'">返回主页</button>
|
||||
<button onclick="window.location.href='/api/contexts'">查看API数据</button>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
console.log('调试信息:', {orjson.dumps(debug_info, option=orjson.OPT_INDENT_2).decode("utf-8")});
|
||||
setTimeout(() => location.reload(), 5000); // 5秒自动刷新
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
return web.Response(text=html_content, content_type="text/html")
|
||||
|
||||
async def add_message(self, chat_id: str, message: MessageRecv):
|
||||
"""添加新消息到上下文"""
|
||||
if chat_id not in self.contexts:
|
||||
self.contexts[chat_id] = deque(maxlen=self.max_messages)
|
||||
logger.debug(f"为聊天 {chat_id} 创建新的上下文队列")
|
||||
|
||||
context_msg = ContextMessage(message)
|
||||
self.contexts[chat_id].append(context_msg)
|
||||
|
||||
# 统计当前总消息数
|
||||
total_messages = sum(len(contexts) for contexts in self.contexts.values())
|
||||
|
||||
logger.info(
|
||||
f"✅ 添加消息到上下文 [总数: {total_messages}]: [{context_msg.group_name}] {context_msg.user_name}: {context_msg.content}"
|
||||
)
|
||||
|
||||
# 调试:打印当前所有消息
|
||||
logger.info("📝 当前上下文中的所有消息:")
|
||||
for cid, contexts in self.contexts.items():
|
||||
logger.info(f" 聊天 {cid}: {len(contexts)} 条消息")
|
||||
for i, msg in enumerate(contexts):
|
||||
logger.info(
|
||||
f" {i + 1}. [{msg.timestamp.strftime('%H:%M:%S')}] {msg.user_name}: {msg.content[:30]}..."
|
||||
)
|
||||
|
||||
# 广播更新给所有WebSocket连接
|
||||
await self.broadcast_contexts()
|
||||
|
||||
async def send_contexts_to_websocket(self, ws: web.WebSocketResponse):
|
||||
"""向单个WebSocket发送上下文数据"""
|
||||
all_context_msgs = []
|
||||
for contexts in self.contexts.values():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
await ws.send_str(orjson.dumps(data).decode("utf-8"))
|
||||
|
||||
async def broadcast_contexts(self):
|
||||
"""向所有WebSocket连接广播上下文更新"""
|
||||
if not self.websockets:
|
||||
logger.debug("没有WebSocket连接,跳过广播")
|
||||
return
|
||||
|
||||
all_context_msgs = []
|
||||
for contexts in self.contexts.values():
|
||||
all_context_msgs.extend(list(contexts))
|
||||
|
||||
# 按时间排序,最新的在最后
|
||||
all_context_msgs.sort(key=lambda x: x.timestamp)
|
||||
|
||||
# 转换为字典格式
|
||||
contexts_data = [msg.to_dict() for msg in all_context_msgs[-self.max_messages :]]
|
||||
|
||||
data = {"contexts": contexts_data}
|
||||
message = orjson.dumps(data).decode("utf-8")
|
||||
|
||||
logger.info(f"广播 {len(contexts_data)} 条消息到 {len(self.websockets)} 个WebSocket连接")
|
||||
|
||||
# 创建WebSocket列表的副本,避免在遍历时修改
|
||||
websockets_copy = self.websockets.copy()
|
||||
removed_count = 0
|
||||
|
||||
for ws in websockets_copy:
|
||||
if ws.closed:
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
else:
|
||||
try:
|
||||
await ws.send_str(message)
|
||||
logger.debug("消息发送成功")
|
||||
except Exception as e:
|
||||
logger.error(f"发送WebSocket消息失败: {e}")
|
||||
if ws in self.websockets:
|
||||
self.websockets.remove(ws)
|
||||
removed_count += 1
|
||||
|
||||
if removed_count > 0:
|
||||
logger.debug(f"清理了 {removed_count} 个断开的WebSocket连接")
|
||||
|
||||
|
||||
# 全局实例
|
||||
_context_web_manager: ContextWebManager | None = None
|
||||
|
||||
|
||||
def get_context_web_manager() -> ContextWebManager:
|
||||
"""获取上下文网页管理器实例"""
|
||||
global _context_web_manager
|
||||
if _context_web_manager is None:
|
||||
_context_web_manager = ContextWebManager()
|
||||
return _context_web_manager
|
||||
|
||||
|
||||
async def init_context_web_manager():
|
||||
"""初始化上下文网页管理器"""
|
||||
manager = get_context_web_manager()
|
||||
await manager.start_server()
|
||||
return manager
|
||||
@@ -1,147 +0,0 @@
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("gift_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingGift:
|
||||
"""等待中的礼物消息"""
|
||||
|
||||
message: MessageRecvS4U
|
||||
total_count: int
|
||||
timer_task: asyncio.Task
|
||||
callback: Callable[[MessageRecvS4U], None]
|
||||
|
||||
|
||||
class GiftManager:
|
||||
"""礼物管理器,提供防抖功能"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化礼物管理器"""
|
||||
self.pending_gifts: dict[tuple[str, str], PendingGift] = {}
|
||||
self.debounce_timeout = 5.0 # 3秒防抖时间
|
||||
|
||||
async def handle_gift(
|
||||
self, message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None = None
|
||||
) -> bool:
|
||||
"""处理礼物消息,返回是否应该立即处理
|
||||
|
||||
Args:
|
||||
message: 礼物消息
|
||||
callback: 防抖完成后的回调函数
|
||||
|
||||
Returns:
|
||||
bool: False表示消息被暂存等待防抖,True表示应该立即处理
|
||||
"""
|
||||
if not message.is_gift:
|
||||
return True
|
||||
|
||||
# 构建礼物的唯一键:(发送人ID, 礼物名称)
|
||||
gift_key = (message.message_info.user_info.user_id, message.gift_name)
|
||||
|
||||
# 如果已经有相同的礼物在等待中,则合并
|
||||
if gift_key in self.pending_gifts:
|
||||
await self._merge_gift(gift_key, message)
|
||||
return False
|
||||
|
||||
# 创建新的等待礼物
|
||||
await self._create_pending_gift(gift_key, message, callback)
|
||||
return False
|
||||
|
||||
async def _merge_gift(self, gift_key: tuple[str, str], new_message: MessageRecvS4U) -> None:
|
||||
"""合并礼物消息"""
|
||||
pending_gift = self.pending_gifts[gift_key]
|
||||
|
||||
# 取消之前的定时器
|
||||
if not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
|
||||
# 累加礼物数量
|
||||
try:
|
||||
new_count = int(new_message.gift_count)
|
||||
pending_gift.total_count += new_count
|
||||
|
||||
# 更新消息为最新的(保留最新的消息,但累加数量)
|
||||
pending_gift.message = new_message
|
||||
pending_gift.message.gift_count = str(pending_gift.total_count)
|
||||
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
|
||||
# 如果无法解析数量,保持原有数量不变
|
||||
|
||||
# 重新创建定时器
|
||||
pending_gift.timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
|
||||
|
||||
async def _create_pending_gift(
|
||||
self, gift_key: tuple[str, str], message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None
|
||||
) -> None:
|
||||
"""创建新的等待礼物"""
|
||||
try:
|
||||
initial_count = int(message.gift_count)
|
||||
except ValueError:
|
||||
initial_count = 1
|
||||
logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1")
|
||||
|
||||
# 创建定时器任务
|
||||
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
# 创建等待礼物对象
|
||||
pending_gift = PendingGift(message=message, total_count=initial_count, timer_task=timer_task, callback=callback)
|
||||
|
||||
self.pending_gifts[gift_key] = pending_gift
|
||||
|
||||
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
|
||||
|
||||
async def _gift_timeout(self, gift_key: tuple[str, str]) -> None:
|
||||
"""礼物防抖超时处理"""
|
||||
try:
|
||||
# 等待防抖时间
|
||||
await asyncio.sleep(self.debounce_timeout)
|
||||
|
||||
# 获取等待中的礼物
|
||||
if gift_key not in self.pending_gifts:
|
||||
return
|
||||
|
||||
pending_gift = self.pending_gifts.pop(gift_key)
|
||||
|
||||
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
|
||||
|
||||
message = pending_gift.message
|
||||
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
|
||||
|
||||
# 执行回调
|
||||
if pending_gift.callback:
|
||||
try:
|
||||
pending_gift.callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 定时器被取消,不需要处理
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
|
||||
|
||||
def get_pending_count(self) -> int:
|
||||
"""获取当前等待中的礼物数量"""
|
||||
return len(self.pending_gifts)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
"""立即处理所有等待中的礼物"""
|
||||
for gift_key in list(self.pending_gifts.keys()):
|
||||
pending_gift = self.pending_gifts.get(gift_key)
|
||||
if pending_gift and not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
await self._gift_timeout(gift_key)
|
||||
|
||||
|
||||
# 创建全局礼物管理器实例
|
||||
gift_manager = GiftManager()
|
||||
@@ -1,15 +0,0 @@
|
||||
class InternalManager:
|
||||
def __init__(self):
|
||||
self.now_internal_state = ""
|
||||
|
||||
def set_internal_state(self, internal_state: str):
|
||||
self.now_internal_state = internal_state
|
||||
|
||||
def get_internal_state(self):
|
||||
return self.now_internal_state
|
||||
|
||||
def get_internal_state_str(self):
|
||||
return f"你今天的直播内容是直播QQ水群,你正在一边回复弹幕,一边在QQ群聊天,你在QQ群聊天中产生的想法是:{self.now_internal_state}"
|
||||
|
||||
|
||||
internal_manager = InternalManager()
|
||||
@@ -1,611 +0,0 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import orjson
|
||||
from maim_message import Seg, UserInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U, MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message.api import get_global_api
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import PersonInfoManager
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
|
||||
from .s4u_mood_manager import mood_manager
|
||||
from .s4u_stream_generator import S4UStreamGenerator
|
||||
from .s4u_watching_manager import watching_manager
|
||||
from .super_chat_manager import get_super_chat_manager
|
||||
from .yes_or_no import yes_or_no_head
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
|
||||
|
||||
class MessageSenderContainer:
|
||||
"""一个简单的容器,用于按顺序发送消息并模拟打字效果。"""
|
||||
|
||||
def __init__(self, chat_stream: ChatStream, original_message: MessageRecv):
|
||||
self.chat_stream = chat_stream
|
||||
self.original_message = original_message
|
||||
self.queue = asyncio.Queue()
|
||||
self.storage = MessageStorage()
|
||||
self._task: asyncio.Task | None = None
|
||||
self._paused_event = asyncio.Event()
|
||||
self._paused_event.set() # 默认设置为非暂停状态
|
||||
|
||||
self.msg_id = ""
|
||||
|
||||
self.last_msg_id = ""
|
||||
|
||||
self.voice_done = ""
|
||||
|
||||
async def add_message(self, chunk: str):
|
||||
"""向队列中添加一个消息块。"""
|
||||
await self.queue.put(chunk)
|
||||
|
||||
async def close(self):
|
||||
"""表示没有更多消息了,关闭队列。"""
|
||||
await self.queue.put(None) # Sentinel
|
||||
|
||||
def pause(self):
|
||||
"""暂停发送。"""
|
||||
self._paused_event.clear()
|
||||
|
||||
def resume(self):
|
||||
"""恢复发送。"""
|
||||
self._paused_event.set()
|
||||
|
||||
@staticmethod
|
||||
def _calculate_typing_delay(text: str) -> float:
|
||||
"""根据文本长度计算模拟打字延迟。"""
|
||||
chars_per_second = s4u_config.chars_per_second
|
||||
min_delay = s4u_config.min_typing_delay
|
||||
max_delay = s4u_config.max_typing_delay
|
||||
|
||||
delay = len(text) / chars_per_second
|
||||
return max(min_delay, min(delay, max_delay))
|
||||
|
||||
async def _send_worker(self):
|
||||
"""从队列中取出消息并发送。"""
|
||||
while True:
|
||||
try:
|
||||
# This structure ensures that task_done() is called for every item retrieved,
|
||||
# even if the worker is cancelled while processing the item.
|
||||
chunk = await self.queue.get()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
try:
|
||||
if chunk is None:
|
||||
break
|
||||
|
||||
# Check for pause signal *after* getting an item.
|
||||
await self._paused_event.wait()
|
||||
|
||||
# 根据配置选择延迟模式
|
||||
if s4u_config.enable_dynamic_typing_delay:
|
||||
delay = self._calculate_typing_delay(chunk)
|
||||
else:
|
||||
delay = s4u_config.typing_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
message_segment = Seg(type="tts_text", data=f"{self.msg_id}:{chunk}")
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
|
||||
await bot_message.process()
|
||||
|
||||
await get_global_api().send_message(bot_message)
|
||||
logger.info(f"已将消息 '{self.msg_id}:{chunk}' 发往平台 '{bot_message.message_info.platform}'")
|
||||
|
||||
message_segment = Seg(type="text", data=chunk)
|
||||
bot_message = MessageSending(
|
||||
message_id=self.msg_id,
|
||||
chat_stream=self.chat_stream,
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.bot.qq_account,
|
||||
user_nickname=global_config.bot.nickname,
|
||||
platform=self.original_message.message_info.platform,
|
||||
),
|
||||
sender_info=self.original_message.message_info.user_info,
|
||||
message_segment=message_segment,
|
||||
reply=self.original_message,
|
||||
is_emoji=False,
|
||||
apply_set_reply_logic=True,
|
||||
reply_to=f"{self.original_message.message_info.user_info.platform}:{self.original_message.message_info.user_info.user_id}",
|
||||
)
|
||||
await bot_message.process()
|
||||
|
||||
await self.storage.store_message(bot_message, self.chat_stream)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
|
||||
|
||||
finally:
|
||||
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
|
||||
self.queue.task_done()
|
||||
|
||||
def start(self):
|
||||
"""启动发送任务。"""
|
||||
if self._task is None:
|
||||
self._task = asyncio.create_task(self._send_worker())
|
||||
|
||||
async def join(self):
|
||||
"""等待所有消息发送完毕。"""
|
||||
if self._task:
|
||||
await self._task
|
||||
|
||||
@property
|
||||
def task(self):
|
||||
return self._task
|
||||
|
||||
|
||||
class S4UChatManager:
|
||||
def __init__(self):
|
||||
self.s4u_chats: dict[str, "S4UChat"] = {}
|
||||
|
||||
async def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
|
||||
if chat_stream.stream_id not in self.s4u_chats:
|
||||
stream_name = await get_chat_manager().get_stream_name(chat_stream.stream_id) or chat_stream.stream_id
|
||||
logger.info(f"Creating new S4UChat for stream: {stream_name}")
|
||||
self.s4u_chats[chat_stream.stream_id] = S4UChat(chat_stream)
|
||||
return self.s4u_chats[chat_stream.stream_id]
|
||||
|
||||
|
||||
if not ENABLE_S4U:
|
||||
s4u_chat_manager = None
|
||||
else:
|
||||
s4u_chat_manager = S4UChatManager()
|
||||
|
||||
|
||||
def get_s4u_chat_manager() -> S4UChatManager:
|
||||
return s4u_chat_manager
|
||||
|
||||
|
||||
class S4UChat:
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化 S4UChat 实例。"""
|
||||
|
||||
self.last_msg_id = self.msg_id
|
||||
self.chat_stream = chat_stream
|
||||
self.stream_id = chat_stream.stream_id
|
||||
self.stream_name = self.stream_id # 初始化时使用stream_id,稍后异步更新
|
||||
self.relationship_builder = relationship_builder_manager.get_or_create_builder(self.stream_id)
|
||||
|
||||
# 两个消息队列
|
||||
self._vip_queue = asyncio.PriorityQueue()
|
||||
self._normal_queue = asyncio.PriorityQueue()
|
||||
|
||||
self._entry_counter = 0 # 保证FIFO的全局计数器
|
||||
self._new_message_event = asyncio.Event() # 用于唤醒处理器
|
||||
|
||||
self._processing_task = asyncio.create_task(self._message_processor())
|
||||
self._current_generation_task: asyncio.Task | None = None
|
||||
# 当前消息的元数据:(队列类型, 优先级分数, 计数器, 消息对象)
|
||||
self._current_message_being_replied: tuple[str, float, int, MessageRecv] | None = None
|
||||
|
||||
self._is_replying = False
|
||||
self.gpt = S4UStreamGenerator()
|
||||
self.gpt.chat_stream = self.chat_stream
|
||||
self.interest_dict: dict[str, float] = {} # 用户兴趣分
|
||||
|
||||
self.internal_message: list[MessageRecvS4U] = []
|
||||
|
||||
self.msg_id = ""
|
||||
self.voice_done = ""
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||
self._stream_name_initialized = False
|
||||
|
||||
async def _initialize_stream_name(self):
|
||||
"""异步初始化stream_name"""
|
||||
if not self._stream_name_initialized:
|
||||
self.stream_name = await get_chat_manager().get_stream_name(self.stream_id) or self.stream_id
|
||||
self._stream_name_initialized = True
|
||||
|
||||
@staticmethod
|
||||
def _get_priority_info(message: MessageRecv) -> dict:
|
||||
"""安全地从消息中提取和解析 priority_info"""
|
||||
priority_info_raw = message.priority_info
|
||||
priority_info = {}
|
||||
if isinstance(priority_info_raw, str):
|
||||
try:
|
||||
priority_info = orjson.loads(priority_info_raw)
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse priority_info JSON: {priority_info_raw}")
|
||||
elif isinstance(priority_info_raw, dict):
|
||||
priority_info = priority_info_raw
|
||||
return priority_info
|
||||
|
||||
@staticmethod
|
||||
def _is_vip(priority_info: dict) -> bool:
|
||||
"""检查消息是否来自VIP用户。"""
|
||||
return priority_info.get("message_type") == "vip"
|
||||
|
||||
def _get_interest_score(self, user_id: str) -> float:
|
||||
"""获取用户的兴趣分,默认为1.0"""
|
||||
return self.interest_dict.get(user_id, 1.0)
|
||||
|
||||
def go_processing(self):
|
||||
if self.voice_done == self.last_msg_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
|
||||
"""
|
||||
为消息计算基础优先级分数。分数越高,优先级越高。
|
||||
"""
|
||||
score = 0.0
|
||||
|
||||
# 加上消息自带的优先级
|
||||
score += priority_info.get("message_priority", 0.0)
|
||||
|
||||
# 加上用户的固有兴趣分
|
||||
score += self._get_interest_score(message.message_info.user_info.user_id)
|
||||
return score
|
||||
|
||||
def decay_interest_score(self):
|
||||
for person_id, score in self.interest_dict.items():
|
||||
if score > 0:
|
||||
self.interest_dict[person_id] = score * 0.95
|
||||
else:
|
||||
self.interest_dict[person_id] = 0
|
||||
|
||||
async def add_message(self, message: MessageRecvS4U | MessageRecv) -> None:
|
||||
# 初始化stream_name
|
||||
await self._initialize_stream_name()
|
||||
|
||||
self.decay_interest_score()
|
||||
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
|
||||
try:
|
||||
is_gift = message.is_gift
|
||||
is_superchat = message.is_superchat
|
||||
# print(is_gift)
|
||||
# print(is_superchat)
|
||||
if is_gift:
|
||||
await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
current_score = self.interest_dict.get(person_id, 1.0)
|
||||
self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
|
||||
elif is_superchat:
|
||||
await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
current_score = self.interest_dict.get(person_id, 1.0)
|
||||
self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
|
||||
|
||||
# 添加SuperChat到管理器
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
await super_chat_manager.add_superchat(message)
|
||||
else:
|
||||
await self.relationship_builder.build_relation(20)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
|
||||
|
||||
priority_info = self._get_priority_info(message)
|
||||
is_vip = self._is_vip(priority_info)
|
||||
new_priority_score = self._calculate_base_priority_score(message, priority_info)
|
||||
|
||||
should_interrupt = False
|
||||
if (
|
||||
s4u_config.enable_message_interruption
|
||||
and self._current_generation_task
|
||||
and not self._current_generation_task.done()
|
||||
):
|
||||
if self._current_message_being_replied:
|
||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||
|
||||
# 规则:VIP从不被打断
|
||||
if current_queue == "vip":
|
||||
pass # Do nothing
|
||||
|
||||
# 规则:普通消息可以被打断
|
||||
elif current_queue == "normal":
|
||||
# VIP消息可以打断普通消息
|
||||
if is_vip:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] VIP message received, interrupting current normal task.")
|
||||
# 普通消息的内部打断逻辑
|
||||
else:
|
||||
new_sender_id = message.message_info.user_info.user_id
|
||||
current_sender_id = current_msg.message_info.user_info.user_id
|
||||
# 新消息优先级更高
|
||||
if new_priority_score > current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] New normal message has higher priority, interrupting.")
|
||||
# 同用户,新消息的优先级不能更低
|
||||
elif new_sender_id == current_sender_id and new_priority_score >= current_priority:
|
||||
should_interrupt = True
|
||||
logger.info(f"[{self.stream_name}] Same user sent new message, interrupting.")
|
||||
|
||||
if should_interrupt:
|
||||
if self.gpt.partial_response:
|
||||
logger.warning(
|
||||
f"[{self.stream_name}] Interrupting reply. Already generated: '{self.gpt.partial_response}'"
|
||||
)
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
# asyncio.PriorityQueue 是最小堆,所以我们存入分数的相反数
|
||||
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
|
||||
item = (-new_priority_score, self._entry_counter, time.time(), message)
|
||||
|
||||
if is_vip and s4u_config.vip_queue_priority:
|
||||
await self._vip_queue.put(item)
|
||||
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||
else:
|
||||
await self._normal_queue.put(item)
|
||||
|
||||
self._entry_counter += 1
|
||||
self._new_message_event.set() # 唤醒处理器
|
||||
|
||||
def _cleanup_old_normal_messages(self):
|
||||
"""清理普通队列中不在最近N条消息范围内的消息"""
|
||||
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
|
||||
return
|
||||
|
||||
# 计算阈值:保留最近 recent_message_keep_count 条消息
|
||||
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
|
||||
|
||||
# 临时存储需要保留的消息
|
||||
temp_messages = []
|
||||
removed_count = 0
|
||||
|
||||
# 取出所有普通队列中的消息
|
||||
while not self._normal_queue.empty():
|
||||
try:
|
||||
item = self._normal_queue.get_nowait()
|
||||
neg_priority, entry_count, timestamp, message = item
|
||||
|
||||
# 如果消息在最近N条消息范围内,保留它
|
||||
logger.info(
|
||||
f"检查消息:{message.processed_plain_text},entry_count:{entry_count} cutoff_counter:{cutoff_counter}"
|
||||
)
|
||||
|
||||
if entry_count >= cutoff_counter:
|
||||
temp_messages.append(item)
|
||||
else:
|
||||
removed_count += 1
|
||||
self._normal_queue.task_done() # 标记被移除的任务为完成
|
||||
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
# 将保留的消息重新放入队列
|
||||
for item in temp_messages:
|
||||
self._normal_queue.put_nowait(item)
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(
|
||||
f"消息{message.processed_plain_text}超过{s4u_config.recent_message_keep_count}条,现在counter:{self._entry_counter}被移除"
|
||||
)
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range."
|
||||
)
|
||||
|
||||
async def _message_processor(self):
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
while True:
|
||||
try:
|
||||
# 等待有新消息的信号,避免空转
|
||||
await self._new_message_event.wait()
|
||||
self._new_message_event.clear()
|
||||
|
||||
# 清理普通队列中的过旧消息
|
||||
self._cleanup_old_normal_messages()
|
||||
|
||||
# 优先处理VIP队列
|
||||
if not self._vip_queue.empty():
|
||||
neg_priority, entry_count, _, message = self._vip_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
queue_name = "vip"
|
||||
# 其次处理普通队列
|
||||
elif not self._normal_queue.empty():
|
||||
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
# 检查普通消息是否超时
|
||||
if time.time() - timestamp > s4u_config.message_timeout_seconds:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
|
||||
)
|
||||
self._normal_queue.task_done()
|
||||
continue # 处理下一条
|
||||
queue_name = "normal"
|
||||
else:
|
||||
if self.internal_message:
|
||||
message = self.internal_message[-1]
|
||||
self.internal_message = []
|
||||
|
||||
priority = 0
|
||||
neg_priority = 0
|
||||
entry_count = 0
|
||||
queue_name = "internal"
|
||||
|
||||
logger.info(
|
||||
f"[{self.stream_name}] normal/vip 队列都空,触发 internal_message 回复: {getattr(message, 'processed_plain_text', str(message))[:20]}..."
|
||||
)
|
||||
else:
|
||||
continue # 没有消息了,回去等事件
|
||||
|
||||
self._current_message_being_replied = (queue_name, priority, entry_count, message)
|
||||
self._current_generation_task = asyncio.create_task(self._generate_and_send(message))
|
||||
|
||||
try:
|
||||
await self._current_generation_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Reply generation was interrupted externally for {queue_name} message. The message will be discarded."
|
||||
)
|
||||
# 被中断的消息应该被丢弃,而不是重新排队,以响应最新的用户输入。
|
||||
# 旧的重新入队逻辑会导致所有中断的消息最终都被回复。
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] _generate_and_send task error: {e}", exc_info=True)
|
||||
finally:
|
||||
self._current_generation_task = None
|
||||
self._current_message_being_replied = None
|
||||
# 标记任务完成
|
||||
if queue_name == "vip":
|
||||
self._vip_queue.task_done()
|
||||
elif queue_name == "internal":
|
||||
# 如果使用 internal_message 生成回复,则不从 normal 队列中移除
|
||||
pass
|
||||
else:
|
||||
self._normal_queue.task_done()
|
||||
|
||||
# 检查是否还有任务,有则立即再次触发事件
|
||||
if not self._vip_queue.empty() or not self._normal_queue.empty():
|
||||
self._new_message_event.set()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] Message processor is shutting down.")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
def get_processing_message_id(self):
|
||||
self.msg_id = f"{time.time()}_{random.randint(1000, 9999)}"
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本回复。整个过程可以被中断。"""
|
||||
self._is_replying = True
|
||||
total_chars_sent = 0 # 跟踪发送的总字符数
|
||||
|
||||
self.get_processing_message_id()
|
||||
|
||||
# 视线管理:开始生成回复时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
|
||||
if message.is_internal:
|
||||
await chat_watching.on_internal_message_start()
|
||||
else:
|
||||
await chat_watching.on_reply_start()
|
||||
|
||||
sender_container = MessageSenderContainer(self.chat_stream, message)
|
||||
sender_container.start()
|
||||
|
||||
async def generate_and_send_inner():
|
||||
nonlocal total_chars_sent
|
||||
logger.info(f"[S4U] 开始为消息生成文本和音频流: '{message.processed_plain_text[:30]}...'")
|
||||
|
||||
if s4u_config.enable_streaming_output:
|
||||
logger.info("[S4U] 开始流式输出")
|
||||
# 流式输出,边生成边发送
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
else:
|
||||
logger.info("[S4U] 开始一次性输出")
|
||||
# 一次性输出,先收集所有chunk
|
||||
all_chunks = []
|
||||
gen = self.gpt.generate_response(message, "")
|
||||
async for chunk in gen:
|
||||
all_chunks.append(chunk)
|
||||
total_chars_sent += len(chunk)
|
||||
# 一次性发送
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("".join(all_chunks))
|
||||
|
||||
try:
|
||||
try:
|
||||
await asyncio.wait_for(generate_and_send_inner(), timeout=10)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"[{self.stream_name}] 回复生成超时,发送默认回复。")
|
||||
sender_container.msg_id = self.msg_id
|
||||
await sender_container.add_message("麦麦不知道哦")
|
||||
total_chars_sent = len("麦麦不知道哦")
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(self.stream_id)
|
||||
await yes_or_no_head(
|
||||
text=total_chars_sent,
|
||||
emotion=mood.mood_state,
|
||||
chat_history=message.processed_plain_text,
|
||||
chat_id=self.stream_id,
|
||||
)
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
|
||||
await chat_watching.on_thinking_finished()
|
||||
|
||||
start_time = time.time()
|
||||
logged = False
|
||||
while not self.go_processing():
|
||||
if time.time() - start_time > 60:
|
||||
logger.warning(f"[{self.stream_name}] 等待消息发送超时(60秒),强制跳出循环。")
|
||||
break
|
||||
if not logged:
|
||||
logger.info(f"[{self.stream_name}] 等待消息发送完成...")
|
||||
logged = True
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
|
||||
raise # 将取消异常向上传播
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
|
||||
# 回复生成实时展示:清空内容(出错时)
|
||||
finally:
|
||||
self._is_replying = False
|
||||
|
||||
# 视线管理:回复结束时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
await chat_watching.on_reply_finished()
|
||||
|
||||
# 确保发送器被妥善关闭(即使已关闭,再次调用也是安全的)
|
||||
sender_container.resume()
|
||||
if not sender_container.task.done():
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
logger.info(f"[{self.stream_name}] _generate_and_send 任务结束,资源已清理。")
|
||||
|
||||
async def shutdown(self):
|
||||
"""平滑关闭处理任务。"""
|
||||
logger.info(f"正在关闭 S4UChat: {self.stream_name}")
|
||||
|
||||
# 取消正在运行的任务
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
self._current_generation_task.cancel()
|
||||
|
||||
if self._processing_task and not self._processing_task.done():
|
||||
self._processing_task.cancel()
|
||||
|
||||
# 等待任务响应取消
|
||||
try:
|
||||
await self._processing_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||
|
||||
@property
|
||||
def new_message_event(self):
|
||||
return self._new_message_event
|
||||
@@ -1,458 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
import orjson
|
||||
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
from src.manager.async_task_manager import AsyncTask, async_task_manager
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
"""
|
||||
情绪管理系统使用说明:
|
||||
|
||||
1. 情绪数值系统:
|
||||
- 情绪包含四个维度:joy(喜), anger(怒), sorrow(哀), fear(惧)
|
||||
- 每个维度的取值范围为1-10
|
||||
- 当情绪发生变化时,会自动发送到ws端处理
|
||||
|
||||
2. 情绪更新机制:
|
||||
- 接收到新消息时会更新情绪状态
|
||||
- 定期进行情绪回归(冷静下来)
|
||||
- 每次情绪变化都会发送到ws端,格式为:
|
||||
type: "emotion"
|
||||
data: {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
3. ws端处理:
|
||||
- 本地只负责情绪计算和发送情绪数值
|
||||
- 表情渲染和动作由ws端根据情绪数值处理
|
||||
"""
|
||||
|
||||
logger = get_logger("mood")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考,请你输出一句话描述你新的情绪状态,不要输出任何其他内容
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"change_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来,请你输出一句话描述你现在的情绪状态
|
||||
请只输出情绪状态,不要输出其他内容:
|
||||
""",
|
||||
"regress_mood_prompt_vtb",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里正在进行的对话
|
||||
|
||||
{indentify_block}
|
||||
你刚刚的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
现在,发送了消息,引起了你的注意,你对其进行了阅读和思考。请基于对话内容,评估你新的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"change_mood_numerical_prompt",
|
||||
)
|
||||
Prompt(
|
||||
"""
|
||||
{chat_talking_prompt}
|
||||
以上是直播间里最近的对话
|
||||
|
||||
{indentify_block}
|
||||
你之前的情绪状态是:{mood_state}
|
||||
具体来说,从1-10分,你的情绪状态是:
|
||||
喜(Joy): {joy}
|
||||
怒(Anger): {anger}
|
||||
哀(Sorrow): {sorrow}
|
||||
惧(Fear): {fear}
|
||||
|
||||
距离你上次关注直播间消息已经过去了一段时间,你冷静了下来。请基于此,评估你现在的情绪状态。
|
||||
请以JSON格式输出你新的情绪状态,包含"喜怒哀惧"四个维度,每个维度的取值范围为1-10。
|
||||
键值请使用英文: "joy", "anger", "sorrow", "fear".
|
||||
例如: {{"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}}
|
||||
不要输出任何其他内容,只输出JSON。
|
||||
""",
|
||||
"regress_mood_numerical_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChatMood:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
self.mood_state: str = "感觉很平静"
|
||||
self.mood_values: dict[str, int] = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
|
||||
self.regression_count: int = 0
|
||||
|
||||
self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
|
||||
self.mood_model_numerical = LLMRequest(
|
||||
model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
|
||||
)
|
||||
|
||||
self.last_change_time: float = 0
|
||||
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(self.send_emotion_update(self.mood_values))
|
||||
|
||||
@staticmethod
|
||||
def _parse_numerical_mood(response: str) -> dict[str, int] | None:
|
||||
try:
|
||||
# The LLM might output markdown with json inside
|
||||
if "```json" in response:
|
||||
response = response.split("```json")[1].split("```")[0]
|
||||
elif "```" in response:
|
||||
response = response.split("```")[1].split("```")[0]
|
||||
|
||||
data = orjson.loads(response)
|
||||
|
||||
# Validate
|
||||
required_keys = {"joy", "anger", "sorrow", "fear"}
|
||||
if not required_keys.issubset(data.keys()):
|
||||
logger.warning(f"Numerical mood response missing keys: {response}")
|
||||
return None
|
||||
|
||||
for key in required_keys:
|
||||
value = data[key]
|
||||
if not isinstance(value, int) or not (1 <= value <= 10):
|
||||
logger.warning(f"Numerical mood response invalid value for {key}: {value} in {response}")
|
||||
return None
|
||||
|
||||
return {key: data[key] for key in required_keys}
|
||||
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"Failed to parse numerical mood JSON: {response}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing numerical mood: {e}, response: {response}")
|
||||
return None
|
||||
|
||||
async def update_mood_by_message(self, message: MessageRecv):
|
||||
self.regression_count = 0
|
||||
|
||||
message_time: float = message.message_info.time # type: ignore
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=10,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _update_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text mood response: {response}")
|
||||
logger.debug(f"text mood reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _update_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"change_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical mood prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt, temperature=0.4
|
||||
)
|
||||
logger.info(f"numerical mood response: {response}")
|
||||
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_update_text_mood(), _update_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪变化: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.last_change_time = message_time
|
||||
|
||||
async def regress_mood(self):
|
||||
message_time = time.time()
|
||||
message_list_before_now = await get_raw_msg_by_timestamp_with_chat_inclusive(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_change_time,
|
||||
timestamp_end=message_time,
|
||||
limit=5,
|
||||
limit_mode="last",
|
||||
)
|
||||
chat_talking_prompt = await build_readable_messages(
|
||||
message_list_before_now,
|
||||
replace_bot_name=True,
|
||||
merge_messages=False,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0.0,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
|
||||
prompt_personality = global_config.personality.personality_core
|
||||
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
|
||||
async def _regress_text_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_prompt_vtb",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
)
|
||||
logger.debug(f"text regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
|
||||
prompt=prompt, temperature=0.7
|
||||
)
|
||||
logger.info(f"text regress response: {response}")
|
||||
logger.debug(f"text regress reasoning_content: {reasoning_content}")
|
||||
return response
|
||||
|
||||
async def _regress_numerical_mood():
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"regress_mood_numerical_prompt",
|
||||
chat_talking_prompt=chat_talking_prompt,
|
||||
indentify_block=indentify_block,
|
||||
mood_state=self.mood_state,
|
||||
joy=self.mood_values["joy"],
|
||||
anger=self.mood_values["anger"],
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.debug(f"numerical regress prompt: {prompt}")
|
||||
response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt,
|
||||
temperature=0.4,
|
||||
)
|
||||
logger.info(f"numerical regress response: {response}")
|
||||
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
|
||||
return self._parse_numerical_mood(response)
|
||||
|
||||
results = await asyncio.gather(_regress_text_mood(), _regress_numerical_mood())
|
||||
text_mood_response, numerical_mood_response = results
|
||||
|
||||
if text_mood_response:
|
||||
self.mood_state = text_mood_response
|
||||
|
||||
if numerical_mood_response:
|
||||
_old_mood_values = self.mood_values.copy()
|
||||
self.mood_values = numerical_mood_response
|
||||
|
||||
# 发送情绪更新到ws端
|
||||
await self.send_emotion_update(self.mood_values)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 情绪回归: {_old_mood_values} -> {self.mood_values}")
|
||||
|
||||
self.regression_count += 1
|
||||
|
||||
async def send_emotion_update(self, mood_values: dict[str, int]):
|
||||
"""发送情绪更新到ws端"""
|
||||
emotion_data = {
|
||||
"joy": mood_values.get("joy", 5),
|
||||
"anger": mood_values.get("anger", 1),
|
||||
"sorrow": mood_values.get("sorrow", 1),
|
||||
"fear": mood_values.get("fear", 1),
|
||||
}
|
||||
|
||||
await send_api.custom_to_stream(
|
||||
message_type="emotion",
|
||||
content=emotion_data,
|
||||
stream_id=self.chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
logger.info(f"[{self.chat_id}] 发送情绪更新: {emotion_data}")
|
||||
|
||||
|
||||
class MoodRegressionTask(AsyncTask):
|
||||
def __init__(self, mood_manager: "MoodManager"):
|
||||
super().__init__(task_name="MoodRegressionTask", run_interval=30)
|
||||
self.mood_manager = mood_manager
|
||||
self.run_count = 0
|
||||
|
||||
async def run(self):
|
||||
self.run_count += 1
|
||||
logger.info(f"[回归任务] 第{self.run_count}次检查,当前管理{len(self.mood_manager.mood_list)}个聊天的情绪状态")
|
||||
|
||||
now = time.time()
|
||||
regression_executed = 0
|
||||
|
||||
for mood in self.mood_manager.mood_list:
|
||||
chat_info = f"chat {mood.chat_id}"
|
||||
|
||||
if mood.last_change_time == 0:
|
||||
logger.debug(f"[回归任务] {chat_info} 尚未有情绪变化,跳过回归")
|
||||
continue
|
||||
|
||||
time_since_last_change = now - mood.last_change_time
|
||||
|
||||
# 检查是否有极端情绪需要快速回归
|
||||
high_emotions = {k: v for k, v in mood.mood_values.items() if v >= 8}
|
||||
has_extreme_emotion = len(high_emotions) > 0
|
||||
|
||||
# 回归条件:1. 正常时间间隔(120s) 或 2. 有极端情绪且距上次变化>=30s
|
||||
should_regress = False
|
||||
regress_reason = ""
|
||||
|
||||
if time_since_last_change > 120:
|
||||
should_regress = True
|
||||
regress_reason = f"常规回归(距上次变化{int(time_since_last_change)}秒)"
|
||||
elif has_extreme_emotion and time_since_last_change > 30:
|
||||
should_regress = True
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
regress_reason = f"极端情绪快速回归({high_emotion_str}, 距上次变化{int(time_since_last_change)}秒)"
|
||||
|
||||
if should_regress:
|
||||
if mood.regression_count >= 3:
|
||||
logger.debug(f"[回归任务] {chat_info} 已达到最大回归次数(3次),停止回归")
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"[回归任务] {chat_info} 开始情绪回归 ({regress_reason},第{mood.regression_count + 1}次回归)"
|
||||
)
|
||||
await mood.regress_mood()
|
||||
regression_executed += 1
|
||||
else:
|
||||
if has_extreme_emotion:
|
||||
remaining_time = 5 - time_since_last_change
|
||||
high_emotion_str = ", ".join([f"{k}={v}" for k, v in high_emotions.items()])
|
||||
logger.debug(
|
||||
f"[回归任务] {chat_info} 存在极端情绪({high_emotion_str}),距离快速回归还需等待{int(remaining_time)}秒"
|
||||
)
|
||||
else:
|
||||
remaining_time = 120 - time_since_last_change
|
||||
logger.debug(f"[回归任务] {chat_info} 距离回归还需等待{int(remaining_time)}秒")
|
||||
|
||||
if regression_executed > 0:
|
||||
logger.info(f"[回归任务] 本次执行了{regression_executed}个聊天的情绪回归")
|
||||
else:
|
||||
logger.debug("[回归任务] 本次没有符合回归条件的聊天")
|
||||
|
||||
|
||||
class MoodManager:
|
||||
def __init__(self):
|
||||
self.mood_list: list[ChatMood] = []
|
||||
"""当前情绪状态"""
|
||||
self.task_started: bool = False
|
||||
|
||||
async def start(self):
|
||||
"""启动情绪回归后台任务"""
|
||||
if self.task_started:
|
||||
return
|
||||
|
||||
logger.info("启动情绪管理任务...")
|
||||
|
||||
# 启动情绪回归任务
|
||||
regression_task = MoodRegressionTask(self)
|
||||
await async_task_manager.add_task(regression_task)
|
||||
|
||||
self.task_started = True
|
||||
logger.info("情绪管理任务已启动(情绪回归)")
|
||||
|
||||
def get_mood_by_chat_id(self, chat_id: str) -> ChatMood:
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
return mood
|
||||
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
return new_mood
|
||||
|
||||
def reset_mood_by_chat_id(self, chat_id: str):
|
||||
for mood in self.mood_list:
|
||||
if mood.chat_id == chat_id:
|
||||
mood.mood_state = "感觉很平静"
|
||||
mood.mood_values = {"joy": 5, "anger": 1, "sorrow": 1, "fear": 1}
|
||||
mood.regression_count = 0
|
||||
# 发送重置后的情绪状态到ws端
|
||||
asyncio.create_task(mood.send_emotion_update(mood.mood_values))
|
||||
return
|
||||
|
||||
# 如果没有找到现有的mood,创建新的
|
||||
new_mood = ChatMood(chat_id)
|
||||
self.mood_list.append(new_mood)
|
||||
# 发送初始情绪状态到ws端
|
||||
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
|
||||
|
||||
|
||||
if ENABLE_S4U:
|
||||
init_prompt()
|
||||
mood_manager = MoodManager()
|
||||
else:
|
||||
mood_manager = None
|
||||
|
||||
"""全局情绪管理器"""
|
||||
@@ -1,282 +0,0 @@
|
||||
import asyncio
|
||||
import math
|
||||
|
||||
from maim_message.message_base import GroupInfo
|
||||
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
|
||||
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
|
||||
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
|
||||
from src.mais4u.mais4u_chat.gift_manager import gift_manager
|
||||
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
|
||||
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
|
||||
from .s4u_chat import get_s4u_chat_manager
|
||||
|
||||
# from ..message_receive.message_buffer import message_buffer
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> tuple[float, bool]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool]: (兴趣度, 是否被提及)
|
||||
"""
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
# 使用新的统一记忆系统计算兴趣度
|
||||
try:
|
||||
from src.chat.memory_system import get_memory_system
|
||||
|
||||
memory_system = get_memory_system()
|
||||
enhanced_memories = await memory_system.retrieve_relevant_memories(
|
||||
query_text=message.processed_plain_text,
|
||||
user_id=str(message.user_info.user_id),
|
||||
scope_id=message.chat_id,
|
||||
limit=5,
|
||||
)
|
||||
|
||||
# 基于检索结果计算兴趣度
|
||||
if enhanced_memories:
|
||||
# 有相关记忆,兴趣度基于相似度计算
|
||||
max_score = max(getattr(memory, "relevance_score", 0.5) for memory in enhanced_memories)
|
||||
interested_rate = min(max_score, 1.0) # 限制在0-1之间
|
||||
else:
|
||||
# 没有相关记忆,给予基础兴趣度
|
||||
interested_rate = 0.1
|
||||
|
||||
logger.debug(f"增强记忆系统兴趣度: {interested_rate:.2f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"增强记忆系统兴趣度计算失败: {e}")
|
||||
interested_rate = 0.1 # 默认基础兴趣度
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度分布调整兴趣度,采用分段函数实现更精确的兴趣度计算
|
||||
# 基于实际分布:0-5字符(26.57%), 6-10字符(27.18%), 11-20字符(22.76%), 21-30字符(10.33%), 31+字符(13.86%)
|
||||
|
||||
if text_len == 0:
|
||||
base_interest = 0.01 # 空消息最低兴趣度
|
||||
elif text_len <= 5:
|
||||
# 1-5字符:线性增长 0.01 -> 0.03
|
||||
base_interest = 0.01 + (text_len - 1) * (0.03 - 0.01) / 4
|
||||
elif text_len <= 10:
|
||||
# 6-10字符:线性增长 0.03 -> 0.06
|
||||
base_interest = 0.03 + (text_len - 5) * (0.06 - 0.03) / 5
|
||||
elif text_len <= 20:
|
||||
# 11-20字符:线性增长 0.06 -> 0.12
|
||||
base_interest = 0.06 + (text_len - 10) * (0.12 - 0.06) / 10
|
||||
elif text_len <= 30:
|
||||
# 21-30字符:线性增长 0.12 -> 0.18
|
||||
base_interest = 0.12 + (text_len - 20) * (0.18 - 0.12) / 10
|
||||
elif text_len <= 50:
|
||||
# 31-50字符:线性增长 0.18 -> 0.22
|
||||
base_interest = 0.18 + (text_len - 30) * (0.22 - 0.18) / 20
|
||||
elif text_len <= 100:
|
||||
# 51-100字符:线性增长 0.22 -> 0.26
|
||||
base_interest = 0.22 + (text_len - 50) * (0.26 - 0.22) / 50
|
||||
else:
|
||||
# 100+字符:对数增长 0.26 -> 0.3,增长率递减
|
||||
base_interest = 0.26 + (0.3 - 0.26) * (math.log10(text_len - 99) / math.log10(901)) # 1000-99=901
|
||||
|
||||
# 确保在范围内
|
||||
base_interest = min(max(base_interest, 0.01), 0.3)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
interested_rate += interest_increase_on_mention
|
||||
|
||||
return interested_rate, is_mentioned
|
||||
|
||||
|
||||
class S4UMessageProcessor:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
message_info = message.message_info
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message_info.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
if await self.handle_internal_message(message):
|
||||
return
|
||||
|
||||
if await self.hadle_if_voice_done(message):
|
||||
return
|
||||
|
||||
# 处理礼物消息,如果消息被暂存则停止当前处理流程
|
||||
if not skip_gift_debounce and not await self.handle_if_gift(message):
|
||||
return
|
||||
await self.check_if_fake_gift(message)
|
||||
|
||||
# 处理屏幕消息
|
||||
if await self.handle_screen_message(message):
|
||||
return
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
|
||||
await s4u_chat.add_message(message)
|
||||
|
||||
_interested_rate, _ = await _calculate_interest(message)
|
||||
|
||||
await mood_manager.start()
|
||||
|
||||
# 一系列llm驱动的前处理
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message))
|
||||
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_action.update_action_by_message(message))
|
||||
# 视线管理:收到消息时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
|
||||
await chat_watching.on_message_received()
|
||||
|
||||
# 上下文网页管理:启动独立task处理消息上下文
|
||||
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
|
||||
|
||||
# 日志记录
|
||||
if message.is_gift:
|
||||
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
|
||||
else:
|
||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
@staticmethod
|
||||
async def handle_internal_message(message: MessageRecvS4U):
|
||||
if message.is_internal:
|
||||
group_info = GroupInfo(platform="amaidesu_default", group_id=660154, group_name="内心")
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform="amaidesu_default", user_info=message.message_info.user_info, group_info=group_info
|
||||
)
|
||||
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
message.message_info.group_info = s4u_chat.chat_stream.group_info
|
||||
message.message_info.platform = s4u_chat.chat_stream.platform
|
||||
|
||||
s4u_chat.internal_message.append(message)
|
||||
s4u_chat.new_message_event.set()
|
||||
|
||||
logger.info(
|
||||
f"[{s4u_chat.stream_name}] 添加内部消息-------------------------------------------------------: {message.processed_plain_text}"
|
||||
)
|
||||
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def handle_screen_message(message: MessageRecvS4U):
|
||||
if message.is_screen:
|
||||
screen_manager.set_screen(message.screen_info)
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def hadle_if_voice_done(message: MessageRecvS4U):
|
||||
if message.voice_done:
|
||||
s4u_chat = await get_s4u_chat_manager().get_or_create_chat(message.chat_stream)
|
||||
s4u_chat.voice_done = message.voice_done
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
async def check_if_fake_gift(message: MessageRecvS4U) -> bool:
|
||||
"""检查消息是否为假礼物"""
|
||||
if message.is_gift:
|
||||
return False
|
||||
|
||||
gift_keywords = ["送出了礼物", "礼物", "送出了", "投喂"]
|
||||
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
|
||||
message.is_fake_gift = True
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""处理礼物消息
|
||||
|
||||
Returns:
|
||||
bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理
|
||||
"""
|
||||
if message.is_gift:
|
||||
# 定义防抖完成后的回调函数
|
||||
def gift_callback(merged_message: MessageRecvS4U):
|
||||
"""礼物防抖完成后的回调"""
|
||||
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
|
||||
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
|
||||
|
||||
# 交给礼物管理器处理,并传入回调函数
|
||||
# 对于礼物消息,handle_gift 总是返回 False(消息被暂存)
|
||||
await gift_manager.handle_gift(message, gift_callback)
|
||||
return False # 消息被暂存,不继续处理
|
||||
|
||||
return True # 非礼物消息,继续正常处理
|
||||
|
||||
@staticmethod
|
||||
async def _handle_context_web_update(chat_id: str, message: MessageRecv):
|
||||
"""处理上下文网页更新的独立task
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
message: 消息对象
|
||||
"""
|
||||
try:
|
||||
logger.debug(f"🔄 开始处理上下文网页更新: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
context_manager = get_context_web_manager()
|
||||
|
||||
# 只在服务器未启动时启动(避免重复启动)
|
||||
if context_manager.site is None:
|
||||
logger.info("🚀 首次启动上下文网页服务器...")
|
||||
await context_manager.start_server()
|
||||
|
||||
# 添加消息到上下文并更新网页
|
||||
await asyncio.sleep(1.5)
|
||||
|
||||
await context_manager.add_message(chat_id, message)
|
||||
|
||||
logger.debug(f"✅ 上下文网页更新完成: {message.message_info.user_info.user_nickname}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 处理上下文网页更新失败: {e}", exc_info=True)
|
||||
@@ -1,443 +0,0 @@
|
||||
import asyncio
|
||||
|
||||
# 旧的Hippocampus系统已被移除,现在使用增强记忆系统
|
||||
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from src.chat.express.expression_selector import expression_selector
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.utils import get_recent_group_speaker
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.internal_manager import internal_manager
|
||||
from src.mais4u.mais4u_chat.screen_manager import screen_manager
|
||||
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
from src.person_info.relationship_fetcher import relationship_fetcher_manager
|
||||
|
||||
from .s4u_mood_manager import mood_manager
|
||||
|
||||
logger = get_logger("prompt")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
Prompt("\n你有以下这些**知识**:\n{prompt_info}\n请你**记住上面的知识**,之后可能会用到。\n", "knowledge_prompt")
|
||||
Prompt("\n关于你们的关系,你需要知道:\n{relation_info}\n", "relation_prompt")
|
||||
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
{screen_info}
|
||||
{internal_state}
|
||||
|
||||
{relation_info_block}
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
|
||||
{sc_info}
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
{time_block}
|
||||
这是你和{sender_name}的对话,你们正在交流中:
|
||||
{core_dialogue_prompt}
|
||||
|
||||
对方最新发送的内容:{message_txt}
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。不过不要刻意提及这些内容,自然回复
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
你可以看见面前的屏幕,目前屏幕的内容是:
|
||||
{screen_info}
|
||||
|
||||
{memory_block}
|
||||
{expression_habits_block}
|
||||
|
||||
{sc_info}
|
||||
|
||||
{time_block}
|
||||
{chat_info_danmu}
|
||||
--------------------------------
|
||||
以上是你和弹幕的对话,与此同时,你在与QQ群友聊天,聊天记录如下:
|
||||
{chat_info_qq}
|
||||
--------------------------------
|
||||
你刚刚回复了QQ群,你内心的想法是:{mind}
|
||||
请根据你内心的想法,组织一条回复,在直播间进行发言,可以点名吐槽对象,让观众知道你在说谁
|
||||
{gift_info}
|
||||
回复简短一些,平淡一些,可以参考贴吧,知乎和微博的回复风格。不要浮夸,有逻辑和条理。
|
||||
表现的有个性,不要随意服从他人要求,积极互动。你现在的心情是:{mood_state}
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。
|
||||
你的发言:
|
||||
""",
|
||||
"s4u_prompt_internal", # New template for private CHAT chat
|
||||
)
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
def __init__(self):
|
||||
self.prompt_built = ""
|
||||
self.activate_messages = ""
|
||||
|
||||
@staticmethod
|
||||
async def build_expression_habits(chat_stream: ChatStream, chat_history, target):
|
||||
style_habits = []
|
||||
grammar_habits = []
|
||||
|
||||
# 使用统一的表达方式选择入口(支持classic和exp_model模式)
|
||||
selected_expressions = await expression_selector.select_suitable_expressions(
|
||||
chat_id=chat_stream.stream_id,
|
||||
chat_history=chat_history,
|
||||
target_message=target,
|
||||
max_num=12,
|
||||
min_num=5
|
||||
)
|
||||
|
||||
if selected_expressions:
|
||||
logger.debug(f" 使用处理器选中的{len(selected_expressions)}个表达方式")
|
||||
for expr in selected_expressions:
|
||||
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
|
||||
expr_type = expr.get("type", "style")
|
||||
if expr_type == "grammar":
|
||||
grammar_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
style_habits.append(f"当{expr['situation']}时,使用 {expr['style']}")
|
||||
else:
|
||||
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
|
||||
# 不再在replyer中进行随机选择,全部交给处理器处理
|
||||
|
||||
style_habits_str = "\n".join(style_habits)
|
||||
grammar_habits_str = "\n".join(grammar_habits)
|
||||
|
||||
# 动态构建expression habits块
|
||||
expression_habits_block = ""
|
||||
if style_habits_str.strip():
|
||||
expression_habits_block += f"你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:\n{style_habits_str}\n\n"
|
||||
if grammar_habits_str.strip():
|
||||
expression_habits_block += f"请你根据情景使用以下句法:\n{grammar_habits_str}\n"
|
||||
|
||||
return expression_habits_block
|
||||
|
||||
@staticmethod
|
||||
async def build_relation_info(chat_stream) -> str:
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
who_chat_in_group = []
|
||||
if is_group_chat:
|
||||
who_chat_in_group = get_recent_group_speaker(
|
||||
chat_stream.stream_id,
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id) if chat_stream.user_info else None,
|
||||
limit=global_config.chat.max_context_size,
|
||||
)
|
||||
elif chat_stream.user_info:
|
||||
who_chat_in_group.append(
|
||||
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
|
||||
)
|
||||
|
||||
relation_prompt = ""
|
||||
if global_config.affinity_flow.enable_relationship_tracking and who_chat_in_group:
|
||||
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
|
||||
|
||||
# 将 (platform, user_id, nickname) 转换为 person_id
|
||||
person_ids = []
|
||||
for person in who_chat_in_group:
|
||||
person_id = PersonInfoManager.get_person_id(person[0], person[1])
|
||||
person_ids.append(person_id)
|
||||
|
||||
# 构建用户关系信息和聊天流印象信息
|
||||
user_relation_tasks = [relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
|
||||
stream_impression_task = relationship_fetcher.build_chat_stream_impression(chat_stream.stream_id)
|
||||
|
||||
# 并行获取所有信息
|
||||
results = await asyncio.gather(*user_relation_tasks, stream_impression_task)
|
||||
relation_info_list = results[:-1] # 用户关系信息
|
||||
stream_impression = results[-1] # 聊天流印象
|
||||
|
||||
# 组合用户关系信息和聊天流印象
|
||||
combined_info_parts = []
|
||||
if user_relation_info := "".join(relation_info_list):
|
||||
combined_info_parts.append(user_relation_info)
|
||||
if stream_impression:
|
||||
combined_info_parts.append(stream_impression)
|
||||
|
||||
if combined_info := "\n\n".join(combined_info_parts):
|
||||
relation_prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_prompt", relation_info=combined_info
|
||||
)
|
||||
return relation_prompt
|
||||
|
||||
@staticmethod
|
||||
async def build_memory_block(text: str) -> str:
|
||||
# 使用新的统一记忆系统检索记忆
|
||||
try:
|
||||
from src.chat.memory_system import get_memory_system
|
||||
|
||||
memory_system = get_memory_system()
|
||||
enhanced_memories = await memory_system.retrieve_relevant_memories(
|
||||
query_text=text,
|
||||
user_id="system", # 系统查询
|
||||
scope_id="system",
|
||||
limit=5,
|
||||
)
|
||||
|
||||
related_memory_info = ""
|
||||
if enhanced_memories:
|
||||
for memory_chunk in enhanced_memories:
|
||||
related_memory_info += memory_chunk.display or memory_chunk.text_content or ""
|
||||
return await global_prompt_manager.format_prompt(
|
||||
"memory_prompt", memory_info=related_memory_info.strip()
|
||||
)
|
||||
return ""
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"增强记忆系统检索失败: {e}")
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
async def build_chat_history_prompts(chat_stream: ChatStream, message: MessageRecvS4U):
|
||||
message_list_before_now = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=300,
|
||||
)
|
||||
|
||||
talk_type = f"{message.message_info.platform}:{message.chat_stream.user_info.user_id!s}"
|
||||
|
||||
core_dialogue_list = []
|
||||
background_dialogue_list = []
|
||||
bot_id = str(global_config.bot.qq_account)
|
||||
target_user_id = str(message.chat_stream.user_info.user_id)
|
||||
|
||||
for msg_dict in message_list_before_now:
|
||||
try:
|
||||
msg_user_id = str(msg_dict.get("user_id"))
|
||||
if msg_user_id == bot_id:
|
||||
if msg_dict.get("reply_to") and talk_type == msg_dict.get("reply_to"):
|
||||
core_dialogue_list.append(msg_dict)
|
||||
elif msg_dict.get("reply_to") and talk_type != msg_dict.get("reply_to"):
|
||||
background_dialogue_list.append(msg_dict)
|
||||
# else:
|
||||
# background_dialogue_list.append(msg_dict)
|
||||
elif msg_user_id == target_user_id:
|
||||
core_dialogue_list.append(msg_dict)
|
||||
else:
|
||||
background_dialogue_list.append(msg_dict)
|
||||
except Exception as e:
|
||||
logger.error(f"无法处理历史消息记录: {msg_dict}, 错误: {e}")
|
||||
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length :]
|
||||
background_dialogue_prompt_str = await build_readable_messages(
|
||||
context_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
background_dialogue_prompt = f"这是其他用户的发言:\n{background_dialogue_prompt_str}"
|
||||
|
||||
core_msg_str = ""
|
||||
if core_dialogue_list:
|
||||
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length :]
|
||||
|
||||
first_msg = core_dialogue_list[0]
|
||||
start_speaking_user_id = first_msg.get("user_id")
|
||||
if start_speaking_user_id == bot_id:
|
||||
last_speaking_user_id = bot_id
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
start_speaking_user_id = target_user_id
|
||||
last_speaking_user_id = start_speaking_user_id
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(first_msg.get('time')))}: {first_msg.get('processed_plain_text')}\n"
|
||||
|
||||
all_msg_seg_list = []
|
||||
for msg in core_dialogue_list[1:]:
|
||||
speaker = msg.get("user_id")
|
||||
if speaker == last_speaking_user_id:
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||
else:
|
||||
msg_seg_str = f"{msg_seg_str}\n"
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
|
||||
if speaker == bot_id:
|
||||
msg_seg_str = "你的发言:\n"
|
||||
else:
|
||||
msg_seg_str = "对方的发言:\n"
|
||||
|
||||
msg_seg_str += f"{time.strftime('%H:%M:%S', time.localtime(msg.get('time')))}: {msg.get('processed_plain_text')}\n"
|
||||
last_speaking_user_id = speaker
|
||||
|
||||
all_msg_seg_list.append(msg_seg_str)
|
||||
for msg in all_msg_seg_list:
|
||||
core_msg_str += msg
|
||||
|
||||
all_dialogue_prompt = await get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=20,
|
||||
)
|
||||
all_dialogue_prompt_str = await build_readable_messages(
|
||||
all_dialogue_prompt,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
|
||||
return core_msg_str, background_dialogue_prompt, all_dialogue_prompt_str
|
||||
|
||||
@staticmethod
|
||||
def build_gift_info(message: MessageRecvS4U):
|
||||
if message.is_gift:
|
||||
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
|
||||
else:
|
||||
if message.is_fake_gift:
|
||||
return f"{message.processed_plain_text}(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分,如果对方在发假的礼物骗你,请反击)"
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def build_sc_info(message: MessageRecvS4U):
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
|
||||
|
||||
async def build_prompt_normal(
|
||||
self,
|
||||
message: MessageRecvS4U,
|
||||
message_txt: str,
|
||||
) -> str:
|
||||
chat_stream = message.chat_stream
|
||||
|
||||
person_id = PersonInfoManager.get_person_id(
|
||||
message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
)
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
if message.chat_stream.user_info.user_nickname:
|
||||
if person_name:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
else:
|
||||
sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
else:
|
||||
sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
relation_info_block, memory_block, expression_habits_block = await asyncio.gather(
|
||||
self.build_relation_info(chat_stream),
|
||||
self.build_memory_block(message_txt),
|
||||
self.build_expression_habits(chat_stream, message_txt, sender_name),
|
||||
)
|
||||
|
||||
core_dialogue_prompt, background_dialogue_prompt, all_dialogue_prompt = await self.build_chat_history_prompts(
|
||||
chat_stream, message
|
||||
)
|
||||
|
||||
gift_info = self.build_gift_info(message)
|
||||
|
||||
sc_info = self.build_sc_info(message)
|
||||
|
||||
screen_info = screen_manager.get_screen_str()
|
||||
|
||||
internal_state = internal_manager.get_internal_state_str()
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
mood = mood_manager.get_mood_by_chat_id(chat_stream.stream_id)
|
||||
|
||||
template_name = "s4u_prompt"
|
||||
|
||||
if not message.is_internal:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
template_name,
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
internal_state=internal_state,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
sender_name=sender_name,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
message_txt=message_txt,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
else:
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"s4u_prompt_internal",
|
||||
time_block=time_block,
|
||||
expression_habits_block=expression_habits_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
screen_info=screen_info,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
chat_info_danmu=all_dialogue_prompt,
|
||||
chat_info_qq=message.chat_info,
|
||||
mind=message.processed_plain_text,
|
||||
mood_state=mood.mood_state,
|
||||
)
|
||||
|
||||
# print(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||
"""
|
||||
加权且不放回地随机抽取k个元素。
|
||||
|
||||
参数:
|
||||
items: 待抽取的元素列表
|
||||
weights: 每个元素对应的权重(与items等长,且为正数)
|
||||
k: 需要抽取的元素个数
|
||||
返回:
|
||||
selected: 按权重加权且不重复抽取的k个元素组成的列表
|
||||
|
||||
如果 items 中的元素不足 k 个,就只会返回所有可用的元素
|
||||
|
||||
实现思路:
|
||||
每次从当前池中按权重加权随机选出一个元素,选中后将其从池中移除,重复k次。
|
||||
这样保证了:
|
||||
1. count越大被选中概率越高
|
||||
2. 不会重复选中同一个元素
|
||||
"""
|
||||
selected = []
|
||||
pool = list(zip(items, weights, strict=False))
|
||||
for _ in range(min(k, len(pool))):
|
||||
total = sum(w for _, w in pool)
|
||||
r = random.uniform(0, total)
|
||||
upto = 0
|
||||
for idx, (item, weight) in enumerate(pool):
|
||||
upto += weight
|
||||
if upto >= r:
|
||||
selected.append(item)
|
||||
pool.pop(idx)
|
||||
break
|
||||
return selected
|
||||
|
||||
|
||||
init_prompt()
|
||||
prompt_builder = PromptBuilder()
|
||||
@@ -1,168 +0,0 @@
|
||||
import asyncio
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||
from src.mais4u.openai_client import AsyncOpenAIClient
|
||||
|
||||
logger = get_logger("s4u_stream_generator")
|
||||
|
||||
|
||||
class S4UStreamGenerator:
|
||||
def __init__(self):
|
||||
replyer_config = model_config.model_task_config.replyer
|
||||
model_to_use = replyer_config.model_list[0]
|
||||
model_info = model_config.get_model_info(model_to_use)
|
||||
if not model_info:
|
||||
logger.error(f"模型 {model_to_use} 在配置中未找到")
|
||||
raise ValueError(f"模型 {model_to_use} 在配置中未找到")
|
||||
provider_name = model_info.api_provider
|
||||
provider_info = model_config.get_provider(provider_name)
|
||||
if not provider_info:
|
||||
logger.error("`replyer` 找不到对应的Provider")
|
||||
raise ValueError("`replyer` 找不到对应的Provider")
|
||||
|
||||
api_key = provider_info.api_key
|
||||
base_url = provider_info.base_url
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"{provider_name}没有配置API KEY")
|
||||
raise ValueError(f"{provider_name}没有配置API KEY")
|
||||
|
||||
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
|
||||
self.model_1_name = model_to_use
|
||||
self.replyer_config = replyer_config
|
||||
|
||||
self.current_model_name = "unknown model"
|
||||
self.partial_response = ""
|
||||
|
||||
# 正则表达式用于按句子切分,同时处理各种标点和边缘情况
|
||||
# 匹配常见的句子结束符,但会忽略引号内和数字中的标点
|
||||
self.sentence_split_pattern = re.compile(
|
||||
r'([^\s\w"\'([{]*["\'([{].*?["\'}\])][^\s\w"\'([{]*|' # 匹配被引号/括号包裹的内容
|
||||
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符
|
||||
re.UNICODE | re.DOTALL,
|
||||
)
|
||||
|
||||
self.chat_stream = None
|
||||
|
||||
@staticmethod
|
||||
async def build_last_internal_message(message: MessageRecvS4U, previous_reply_context: str = ""):
|
||||
# person_id = PersonInfoManager.get_person_id(
|
||||
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
|
||||
# )
|
||||
# person_info_manager = get_person_info_manager()
|
||||
# person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
|
||||
# if message.chat_stream.user_info.user_nickname:
|
||||
# if person_name:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}](你叫ta{person_name})"
|
||||
# else:
|
||||
# sender_name = f"[{message.chat_stream.user_info.user_nickname}]"
|
||||
# else:
|
||||
# sender_name = f"用户({message.chat_stream.user_info.user_id})"
|
||||
|
||||
# 构建prompt
|
||||
if previous_reply_context:
|
||||
message_txt = f"""
|
||||
你正在回复用户的消息,但中途被打断了。这是已有的对话上下文:
|
||||
[你已经对上一条消息说的话]: {previous_reply_context}
|
||||
---
|
||||
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
|
||||
{message.processed_plain_text}
|
||||
"""
|
||||
return True, message_txt
|
||||
else:
|
||||
message_txt = message.processed_plain_text
|
||||
return False, message_txt
|
||||
|
||||
async def generate_response(
|
||||
self, message: MessageRecvS4U, previous_reply_context: str = ""
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
# 从global_config中获取模型概率值并选择模型
|
||||
self.partial_response = ""
|
||||
message_txt = message.processed_plain_text
|
||||
if not message.is_internal:
|
||||
interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
|
||||
if interupted:
|
||||
message_txt = message_txt_added
|
||||
|
||||
message.chat_stream = self.chat_stream
|
||||
prompt = await prompt_builder.build_prompt_normal(
|
||||
message=message,
|
||||
message_txt=message_txt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
|
||||
)
|
||||
|
||||
current_client = self.client_1
|
||||
self.current_model_name = self.model_1_name
|
||||
|
||||
extra_kwargs = {}
|
||||
if self.replyer_config.get("enable_thinking") is not None:
|
||||
extra_kwargs["enable_thinking"] = self.replyer_config.get("enable_thinking")
|
||||
if self.replyer_config.get("thinking_budget") is not None:
|
||||
extra_kwargs["thinking_budget"] = self.replyer_config.get("thinking_budget")
|
||||
|
||||
async for chunk in self._generate_response_with_model(
|
||||
prompt, current_client, self.current_model_name, **extra_kwargs
|
||||
):
|
||||
yield chunk
|
||||
|
||||
async def _generate_response_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
client: AsyncOpenAIClient,
|
||||
model_name: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
buffer = ""
|
||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||
punctuation_buffer = ""
|
||||
|
||||
async for content in client.get_stream_content(
|
||||
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
||||
):
|
||||
buffer += content
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence:
|
||||
# 如果句子看起来完整(即不只是等待更多内容),则发送
|
||||
if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
|
||||
# 检查是否只是一个标点符号
|
||||
if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||
punctuation_buffer += sentence
|
||||
else:
|
||||
# 发送之前累积的标点和当前句子
|
||||
to_yield = punctuation_buffer + sentence
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||
await asyncio.sleep(0) # 允许其他任务运行
|
||||
|
||||
last_match_end = match.end(0)
|
||||
|
||||
# 从缓冲区移除已发送的部分
|
||||
if last_match_end > 0:
|
||||
buffer = buffer[last_match_end:]
|
||||
|
||||
# 发送缓冲区中剩余的任何内容
|
||||
to_yield = (punctuation_buffer + buffer).strip()
|
||||
if to_yield:
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
if to_yield:
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
@@ -1,106 +0,0 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
"""
|
||||
视线管理系统使用说明:
|
||||
|
||||
1. 视线状态:
|
||||
- wandering: 随意看
|
||||
- danmu: 看弹幕
|
||||
- lens: 看镜头
|
||||
|
||||
2. 状态切换逻辑:
|
||||
- 收到消息时 → 切换为看弹幕,立即发送更新
|
||||
- 开始生成回复时 → 切换为看镜头或随意,立即发送更新
|
||||
- 生成完毕后 → 看弹幕1秒,然后回到看镜头直到有新消息,状态变化时立即发送更新
|
||||
|
||||
3. 使用方法:
|
||||
# 获取视线管理器
|
||||
watching = watching_manager.get_watching_by_chat_id(chat_id)
|
||||
|
||||
# 收到消息时调用
|
||||
await watching.on_message_received()
|
||||
|
||||
# 开始生成回复时调用
|
||||
await watching.on_reply_start()
|
||||
|
||||
# 生成回复完毕时调用
|
||||
await watching.on_reply_finished()
|
||||
|
||||
4. 自动更新系统:
|
||||
- 状态变化时立即发送type为"watching",data为状态值的websocket消息
|
||||
- 使用定时器自动处理状态转换(如看弹幕时间结束后自动切换到看镜头)
|
||||
- 无需定期检查,所有状态变化都是事件驱动的
|
||||
"""
|
||||
|
||||
logger = get_logger("watching")
|
||||
|
||||
HEAD_CODE = {
|
||||
"看向上方": "(0,0.5,0)",
|
||||
"看向下方": "(0,-0.5,0)",
|
||||
"看向左边": "(-1,0,0)",
|
||||
"看向右边": "(1,0,0)",
|
||||
"随意朝向": "random",
|
||||
"看向摄像机": "camera",
|
||||
"注视对方": "(0,0,0)",
|
||||
"看向正前方": "(0,0,0)",
|
||||
}
|
||||
|
||||
|
||||
class ChatWatching:
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id: str = chat_id
|
||||
|
||||
async def on_reply_start(self):
|
||||
"""开始生成回复时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_reply_finished(self):
|
||||
"""生成回复完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="finish_reply", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_thinking_finished(self):
|
||||
"""思考完毕时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="finish_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_message_received(self):
|
||||
"""收到消息时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_viewing", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
async def on_internal_message_start(self):
|
||||
"""收到消息时调用"""
|
||||
await send_api.custom_to_stream(
|
||||
message_type="state", content="start_internal_thinking", stream_id=self.chat_id, storage_message=False
|
||||
)
|
||||
|
||||
|
||||
class WatchingManager:
|
||||
def __init__(self):
|
||||
self.watching_list: list[ChatWatching] = []
|
||||
"""当前视线状态列表"""
|
||||
self.task_started: bool = False
|
||||
|
||||
def get_watching_by_chat_id(self, chat_id: str) -> ChatWatching:
|
||||
"""获取或创建聊天对应的视线管理器"""
|
||||
for watching in self.watching_list:
|
||||
if watching.chat_id == chat_id:
|
||||
return watching
|
||||
|
||||
new_watching = ChatWatching(chat_id)
|
||||
self.watching_list.append(new_watching)
|
||||
logger.info(f"为chat {chat_id}创建新的视线管理器")
|
||||
|
||||
return new_watching
|
||||
|
||||
|
||||
# 全局视线管理器实例
|
||||
watching_manager = WatchingManager()
|
||||
"""全局视线管理器"""
|
||||
@@ -1,15 +0,0 @@
|
||||
class ScreenManager:
|
||||
def __init__(self):
|
||||
self.now_screen = ""
|
||||
|
||||
def set_screen(self, screen_str: str):
|
||||
self.now_screen = screen_str
|
||||
|
||||
def get_screen(self):
|
||||
return self.now_screen
|
||||
|
||||
def get_screen_str(self):
|
||||
return f"你可以看见面前的屏幕,目前屏幕的内容是:现在千石可乐在和你一起直播,这是他正在操作的屏幕内容:{self.now_screen}"
|
||||
|
||||
|
||||
screen_manager = ScreenManager()
|
||||
@@ -1,304 +0,0 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 全局SuperChat管理器实例
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("super_chat_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuperChatRecord:
|
||||
"""SuperChat记录数据类"""
|
||||
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
platform: str
|
||||
chat_id: str
|
||||
price: float
|
||||
message_text: str
|
||||
timestamp: float
|
||||
expire_time: float
|
||||
group_name: str | None = None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""检查SuperChat是否已过期"""
|
||||
return time.time() > self.expire_time
|
||||
|
||||
def remaining_time(self) -> float:
|
||||
"""获取剩余时间(秒)"""
|
||||
return max(0, self.expire_time - time.time())
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"user_nickname": self.user_nickname,
|
||||
"platform": self.platform,
|
||||
"chat_id": self.chat_id,
|
||||
"price": self.price,
|
||||
"message_text": self.message_text,
|
||||
"timestamp": self.timestamp,
|
||||
"expire_time": self.expire_time,
|
||||
"group_name": self.group_name,
|
||||
"remaining_time": self.remaining_time(),
|
||||
}
|
||||
|
||||
|
||||
class SuperChatManager:
|
||||
"""SuperChat管理器,负责管理和跟踪SuperChat消息"""
|
||||
|
||||
def __init__(self):
|
||||
self.super_chats: dict[str, list[SuperChatRecord]] = {} # chat_id -> SuperChat列表
|
||||
self._cleanup_task: asyncio.Task | None = None
|
||||
self._is_initialized = False
|
||||
logger.info("SuperChat管理器已初始化")
|
||||
|
||||
def _ensure_cleanup_task_started(self):
|
||||
"""确保清理任务已启动(延迟启动)"""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._cleanup_task = loop.create_task(self._cleanup_expired_superchats())
|
||||
self._is_initialized = True
|
||||
logger.info("SuperChat清理任务已启动")
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,稍后再启动
|
||||
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动清理任务(已弃用,保留向后兼容)"""
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
async def _cleanup_expired_superchats(self):
|
||||
"""定期清理过期的SuperChat"""
|
||||
while True:
|
||||
try:
|
||||
total_removed = 0
|
||||
|
||||
for chat_id in list(self.super_chats.keys()):
|
||||
original_count = len(self.super_chats[chat_id])
|
||||
# 移除过期的SuperChat
|
||||
self.super_chats[chat_id] = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
|
||||
removed_count = original_count - len(self.super_chats[chat_id])
|
||||
total_removed += removed_count
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
|
||||
|
||||
# 如果列表为空,删除该聊天的记录
|
||||
if not self.super_chats[chat_id]:
|
||||
del self.super_chats[chat_id]
|
||||
|
||||
if total_removed > 0:
|
||||
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
|
||||
|
||||
# 每30秒检查一次
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(60) # 出错时等待更长时间
|
||||
|
||||
@staticmethod
|
||||
def _calculate_expire_time(price: float) -> float:
|
||||
"""根据SuperChat金额计算过期时间"""
|
||||
current_time = time.time()
|
||||
|
||||
# 根据金额阶梯设置不同的存活时间
|
||||
if price >= 500:
|
||||
# 500元以上:保持4小时
|
||||
duration = 4 * 3600
|
||||
elif price >= 200:
|
||||
# 200-499元:保持2小时
|
||||
duration = 2 * 3600
|
||||
elif price >= 100:
|
||||
# 100-199元:保持1小时
|
||||
duration = 1 * 3600
|
||||
elif price >= 50:
|
||||
# 50-99元:保持30分钟
|
||||
duration = 30 * 60
|
||||
elif price >= 20:
|
||||
# 20-49元:保持15分钟
|
||||
duration = 15 * 60
|
||||
elif price >= 10:
|
||||
# 10-19元:保持10分钟
|
||||
duration = 10 * 60
|
||||
else:
|
||||
# 10元以下:保持5分钟
|
||||
duration = 5 * 60
|
||||
|
||||
return current_time + duration
|
||||
|
||||
async def add_superchat(self, message: MessageRecvS4U) -> None:
|
||||
"""添加新的SuperChat记录"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if not message.is_superchat or not message.superchat_price:
|
||||
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
|
||||
return
|
||||
|
||||
try:
|
||||
price = float(message.superchat_price)
|
||||
except (ValueError, TypeError):
|
||||
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
|
||||
return
|
||||
|
||||
user_info = message.message_info.user_info
|
||||
group_info = message.message_info.group_info
|
||||
chat_id = getattr(message, "chat_stream", None)
|
||||
if chat_id:
|
||||
chat_id = chat_id.stream_id
|
||||
else:
|
||||
# 生成chat_id的备用方法
|
||||
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
|
||||
if group_info:
|
||||
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
|
||||
|
||||
expire_time = self._calculate_expire_time(price)
|
||||
|
||||
record = SuperChatRecord(
|
||||
user_id=user_info.user_id,
|
||||
user_nickname=user_info.user_nickname,
|
||||
platform=message.message_info.platform,
|
||||
chat_id=chat_id,
|
||||
price=price,
|
||||
message_text=message.superchat_message_text or "",
|
||||
timestamp=message.message_info.time,
|
||||
expire_time=expire_time,
|
||||
group_name=group_info.group_name if group_info else None,
|
||||
)
|
||||
|
||||
# 添加到对应聊天的SuperChat列表
|
||||
if chat_id not in self.super_chats:
|
||||
self.super_chats[chat_id] = []
|
||||
|
||||
self.super_chats[chat_id].append(record)
|
||||
|
||||
# 按价格降序排序(价格高的在前)
|
||||
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
|
||||
|
||||
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
|
||||
|
||||
def get_superchats_by_chat(self, chat_id: str) -> list[SuperChatRecord]:
|
||||
"""获取指定聊天的所有有效SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if chat_id not in self.super_chats:
|
||||
return []
|
||||
|
||||
# 过滤掉过期的SuperChat
|
||||
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
return valid_superchats
|
||||
|
||||
def get_all_valid_superchats(self) -> dict[str, list[SuperChatRecord]]:
|
||||
"""获取所有有效的SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
result = {}
|
||||
for chat_id, superchats in self.super_chats.items():
|
||||
valid_superchats = [sc for sc in superchats if not sc.is_expired()]
|
||||
if valid_superchats:
|
||||
result[chat_id] = valid_superchats
|
||||
return result
|
||||
|
||||
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
|
||||
"""构建SuperChat显示字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return ""
|
||||
|
||||
# 限制显示数量
|
||||
display_superchats = superchats[:max_count]
|
||||
|
||||
lines = ["📢 当前有效超级弹幕:"]
|
||||
for i, sc in enumerate(display_superchats, 1):
|
||||
remaining_minutes = int(sc.remaining_time() / 60)
|
||||
remaining_seconds = int(sc.remaining_time() % 60)
|
||||
|
||||
time_display = (
|
||||
f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
|
||||
)
|
||||
|
||||
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
|
||||
if len(line) > 100: # 限制单行长度
|
||||
line = f"{line[:97]}..."
|
||||
line += f" (剩余{time_display})"
|
||||
lines.append(line)
|
||||
|
||||
if len(superchats) > max_count:
|
||||
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def build_superchat_summary_string(self, chat_id: str) -> str:
|
||||
"""构建SuperChat摘要字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return "当前没有有效的超级弹幕"
|
||||
lines = []
|
||||
for sc in superchats:
|
||||
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
|
||||
if len(single_sc_str) > 100:
|
||||
single_sc_str = f"{single_sc_str[:97]}..."
|
||||
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
|
||||
lines.append(single_sc_str)
|
||||
|
||||
total_amount = sum(sc.price for sc in superchats)
|
||||
count = len(superchats)
|
||||
highest_amount = max(sc.price for sc in superchats)
|
||||
|
||||
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元"
|
||||
if lines:
|
||||
final_str += "\n" + "\n".join(lines)
|
||||
return final_str
|
||||
|
||||
def get_superchat_statistics(self, chat_id: str) -> dict:
|
||||
"""获取SuperChat统计信息"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return {"count": 0, "total_amount": 0, "average_amount": 0, "highest_amount": 0, "lowest_amount": 0}
|
||||
|
||||
amounts = [sc.price for sc in superchats]
|
||||
|
||||
return {
|
||||
"count": len(superchats),
|
||||
"total_amount": sum(amounts),
|
||||
"average_amount": sum(amounts) / len(amounts),
|
||||
"highest_amount": max(amounts),
|
||||
"lowest_amount": min(amounts),
|
||||
}
|
||||
|
||||
async def shutdown(self): # sourcery skip: use-contextlib-suppress
|
||||
"""关闭管理器,清理资源"""
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("SuperChat管理器已关闭")
|
||||
|
||||
|
||||
# sourcery skip: assign-if-exp
|
||||
if ENABLE_S4U:
|
||||
super_chat_manager = SuperChatManager()
|
||||
else:
|
||||
super_chat_manager = None
|
||||
|
||||
|
||||
def get_super_chat_manager() -> SuperChatManager:
|
||||
"""获取全局SuperChat管理器实例"""
|
||||
|
||||
return super_chat_manager
|
||||
@@ -1,46 +0,0 @@
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.plugin_system.apis import send_api
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
|
||||
|
||||
|
||||
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
|
||||
prompt = f"""
|
||||
{chat_history}
|
||||
以上是对方的发言:
|
||||
|
||||
对这个发言,你的心情是:{emotion}
|
||||
对上面的发言,你的回复是:{text}
|
||||
请判断时是否要伴随回复做头部动作,你可以选择:
|
||||
|
||||
不做额外动作
|
||||
点头一次
|
||||
点头两次
|
||||
摇头
|
||||
歪脑袋
|
||||
低头望向一边
|
||||
|
||||
请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。"""
|
||||
model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
|
||||
|
||||
try:
|
||||
# logger.info(f"prompt: {prompt}")
|
||||
response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
|
||||
logger.info(f"response: {response}")
|
||||
|
||||
head_action = response if response in head_actions_list else "不做额外动作"
|
||||
await send_api.custom_to_stream(
|
||||
message_type="head_action",
|
||||
content=head_action,
|
||||
stream_id=chat_id,
|
||||
storage_message=False,
|
||||
show_log=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"yes_or_no_head error: {e}")
|
||||
return "不做额外动作"
|
||||
@@ -1,287 +0,0 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""聊天消息数据类"""
|
||||
|
||||
role: str
|
||||
content: str
|
||||
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
class AsyncOpenAIClient:
|
||||
"""异步OpenAI客户端,支持流式传输"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API密钥
|
||||
base_url: 可选的API基础URL,用于自定义端点
|
||||
"""
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=10.0, # 设置60秒的全局超时
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> ChatCompletion:
|
||||
"""
|
||||
非流式聊天完成
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
完整的聊天回复
|
||||
"""
|
||||
# 转换消息格式
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
formatted_messages.append(msg.to_dict())
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
extra_body = {}
|
||||
if kwargs.get("enable_thinking") is not None:
|
||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||
if kwargs.get("thinking_budget") is not None:
|
||||
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
|
||||
|
||||
response = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=formatted_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=False,
|
||||
extra_body=extra_body if extra_body else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||
"""
|
||||
流式聊天完成
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
ChatCompletionChunk: 流式响应块
|
||||
"""
|
||||
# 转换消息格式
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg, ChatMessage):
|
||||
formatted_messages.append(msg.to_dict())
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
extra_body = {}
|
||||
if kwargs.get("enable_thinking") is not None:
|
||||
extra_body["enable_thinking"] = kwargs.pop("enable_thinking")
|
||||
if kwargs.get("thinking_budget") is not None:
|
||||
extra_body["thinking_budget"] = kwargs.pop("thinking_budget")
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=formatted_messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
stream=True,
|
||||
extra_body=extra_body if extra_body else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
yield chunk
|
||||
|
||||
async def get_stream_content(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
获取流式内容(只返回文本内容)
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
str: 文本内容片段
|
||||
"""
|
||||
async for chunk in self.chat_completion_stream(
|
||||
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
|
||||
):
|
||||
if chunk.choices and chunk.choices[0].delta.content:
|
||||
yield chunk.choices[0].delta.content
|
||||
|
||||
async def collect_stream_response(
|
||||
self,
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
收集完整的流式响应
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
model: 模型名称
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
str: 完整的响应文本
|
||||
"""
|
||||
full_response = ""
|
||||
async for content in self.get_stream_content(
|
||||
messages=messages, model=model, temperature=temperature, max_tokens=max_tokens, **kwargs
|
||||
):
|
||||
full_response += content
|
||||
|
||||
return full_response
|
||||
|
||||
async def close(self):
|
||||
"""关闭客户端"""
|
||||
await self.client.close()
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器退出"""
|
||||
await self.close()
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""对话管理器,用于管理对话历史"""
|
||||
|
||||
def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None):
|
||||
"""
|
||||
初始化对话管理器
|
||||
|
||||
Args:
|
||||
client: OpenAI客户端实例
|
||||
system_prompt: 系统提示词
|
||||
"""
|
||||
self.client = client
|
||||
self.messages: list[ChatMessage] = []
|
||||
|
||||
if system_prompt:
|
||||
self.messages.append(ChatMessage(role="system", content=system_prompt))
|
||||
|
||||
def add_user_message(self, content: str):
|
||||
"""添加用户消息"""
|
||||
self.messages.append(ChatMessage(role="user", content=content))
|
||||
|
||||
def add_assistant_message(self, content: str):
|
||||
"""添加助手消息"""
|
||||
self.messages.append(ChatMessage(role="assistant", content=content))
|
||||
|
||||
async def send_message_stream(
|
||||
self, content: str, model: str = "gpt-3.5-turbo", **kwargs
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
发送消息并获取流式响应
|
||||
|
||||
Args:
|
||||
content: 用户消息内容
|
||||
model: 模型名称
|
||||
**kwargs: 其他参数
|
||||
|
||||
Yields:
|
||||
str: 响应内容片段
|
||||
"""
|
||||
self.add_user_message(content)
|
||||
|
||||
response_content = ""
|
||||
async for chunk in self.client.get_stream_content(messages=self.messages, model=model, **kwargs):
|
||||
response_content += chunk
|
||||
yield chunk
|
||||
|
||||
self.add_assistant_message(response_content)
|
||||
|
||||
async def send_message(self, content: str, model: str = "gpt-3.5-turbo", **kwargs) -> str:
|
||||
"""
|
||||
发送消息并获取完整响应
|
||||
|
||||
Args:
|
||||
content: 用户消息内容
|
||||
model: 模型名称
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
str: 完整响应
|
||||
"""
|
||||
self.add_user_message(content)
|
||||
|
||||
response = await self.client.chat_completion(messages=self.messages, model=model, **kwargs)
|
||||
|
||||
response_content = response.choices[0].message.content
|
||||
self.add_assistant_message(response_content)
|
||||
|
||||
return response_content
|
||||
|
||||
def clear_history(self, keep_system: bool = True):
|
||||
"""
|
||||
清除对话历史
|
||||
|
||||
Args:
|
||||
keep_system: 是否保留系统消息
|
||||
"""
|
||||
if keep_system and self.messages and self.messages[0].role == "system":
|
||||
self.messages = [self.messages[0]]
|
||||
else:
|
||||
self.messages = []
|
||||
|
||||
def get_message_count(self) -> int:
|
||||
"""获取消息数量"""
|
||||
return len(self.messages)
|
||||
|
||||
def get_conversation_history(self) -> list[dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return [msg.to_dict() for msg in self.messages]
|
||||
@@ -1,373 +0,0 @@
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import MISSING, dataclass, field, fields
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TypeVar, get_args, get_origin
|
||||
|
||||
import tomlkit
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from typing_extensions import Self
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.mais4u.constant_s4u import ENABLE_S4U
|
||||
|
||||
logger = get_logger("s4u_config")
|
||||
|
||||
|
||||
# 新增:兼容dict和tomlkit Table
|
||||
def is_dict_like(obj):
|
||||
return isinstance(obj, dict | Table)
|
||||
|
||||
|
||||
# 新增:递归将Table转为dict
|
||||
def table_to_dict(obj):
|
||||
if isinstance(obj, Table):
|
||||
return {k: table_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, dict):
|
||||
return {k: table_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [table_to_dict(i) for i in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
# 获取mais4u模块目录
|
||||
MAIS4U_ROOT = os.path.dirname(__file__)
|
||||
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
|
||||
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
|
||||
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")
|
||||
|
||||
# S4U配置版本
|
||||
S4U_VERSION = "1.1.0"
|
||||
|
||||
T = TypeVar("T", bound="S4UConfigBase")
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfigBase:
|
||||
"""S4U配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> Self:
|
||||
"""从字典加载配置字段"""
|
||||
data = table_to_dict(data) # 递归转dict,兼容tomlkit Table
|
||||
if not is_dict_like(data):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
field_type = f.type
|
||||
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
|
||||
except TypeError as e:
|
||||
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
|
||||
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: type[Any]) -> Any:
|
||||
"""转换字段值为指定类型"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
|
||||
if not is_dict_like(value):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
return field_type.from_dict(value)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_type_args = get_args(field_type)
|
||||
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
if (
|
||||
field_type_args
|
||||
and isinstance(field_type_args[0], type)
|
||||
and issubclass(field_type_args[0], S4UConfigBase)
|
||||
):
|
||||
return [field_type_args[0].from_dict(item) for item in value]
|
||||
return [cls._convert_field(item, field_type_args[0]) for item in value]
|
||||
elif field_origin_type is set:
|
||||
return {cls._convert_field(item, field_type_args[0]) for item in value}
|
||||
elif field_origin_type is tuple:
|
||||
if len(value) != len(field_type_args):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
|
||||
|
||||
if field_origin_type is dict:
|
||||
if not is_dict_like(value):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if len(field_type_args) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_type_args
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理基础类型,例如 int, str 等
|
||||
if field_origin_type is type(None) and value is None: # 处理Optional类型
|
||||
return None
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal or get_origin(field_type) is Literal:
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
if field_type is Any or isinstance(value, field_type):
|
||||
return value
|
||||
|
||||
# 其他类型,尝试直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UModelConfig(S4UConfigBase):
|
||||
"""S4U模型配置类"""
|
||||
|
||||
# 主要对话模型配置
|
||||
chat: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""主要对话模型配置"""
|
||||
|
||||
# 规划模型配置(原model_motion)
|
||||
motion: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""规划模型配置"""
|
||||
|
||||
# 情感分析模型配置
|
||||
emotion: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""情感分析模型配置"""
|
||||
|
||||
# 记忆模型配置
|
||||
memory: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""记忆模型配置"""
|
||||
|
||||
# 工具使用模型配置
|
||||
tool_use: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""工具使用模型配置"""
|
||||
|
||||
# 嵌入模型配置
|
||||
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""嵌入模型配置"""
|
||||
|
||||
# 视觉语言模型配置
|
||||
vlm: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""视觉语言模型配置"""
|
||||
|
||||
# 知识库模型配置
|
||||
knowledge: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""知识库模型配置"""
|
||||
|
||||
# 实体提取模型配置
|
||||
entity_extract: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""实体提取模型配置"""
|
||||
|
||||
# 问答模型配置
|
||||
qa: dict[str, Any] = field(default_factory=lambda: {})
|
||||
"""问答模型配置"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfig(S4UConfigBase):
|
||||
"""S4U聊天系统配置类"""
|
||||
|
||||
message_timeout_seconds: int = 120
|
||||
"""普通消息存活时间(秒),超过此时间的消息将被丢弃"""
|
||||
|
||||
at_bot_priority_bonus: float = 100.0
|
||||
"""@机器人时的优先级加成分数"""
|
||||
|
||||
recent_message_keep_count: int = 6
|
||||
"""保留最近N条消息,超出范围的普通消息将被移除"""
|
||||
|
||||
typing_delay: float = 0.1
|
||||
"""打字延迟时间(秒),模拟真实打字速度"""
|
||||
|
||||
chars_per_second: float = 15.0
|
||||
"""每秒字符数,用于计算动态打字延迟"""
|
||||
|
||||
min_typing_delay: float = 0.2
|
||||
"""最小打字延迟(秒)"""
|
||||
|
||||
max_typing_delay: float = 2.0
|
||||
"""最大打字延迟(秒)"""
|
||||
|
||||
enable_dynamic_typing_delay: bool = False
|
||||
"""是否启用基于文本长度的动态打字延迟"""
|
||||
|
||||
vip_queue_priority: bool = True
|
||||
"""是否启用VIP队列优先级系统"""
|
||||
|
||||
enable_message_interruption: bool = True
|
||||
"""是否允许高优先级消息中断当前回复"""
|
||||
|
||||
enable_old_message_cleanup: bool = True
|
||||
"""是否自动清理过旧的普通消息"""
|
||||
|
||||
enable_streaming_output: bool = True
|
||||
"""是否启用流式输出,false时全部生成后一次性发送"""
|
||||
|
||||
max_context_message_length: int = 20
|
||||
"""上下文消息最大长度"""
|
||||
|
||||
max_core_message_length: int = 30
|
||||
"""核心消息最大长度"""
|
||||
|
||||
# 模型配置
|
||||
models: S4UModelConfig = field(default_factory=S4UModelConfig)
|
||||
"""S4U模型配置"""
|
||||
|
||||
# 兼容性字段,保持向后兼容
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UGlobalConfig(S4UConfigBase):
|
||||
"""S4U总配置类"""
|
||||
|
||||
s4u: S4UConfig
|
||||
S4U_VERSION: str = S4U_VERSION
|
||||
|
||||
|
||||
def update_s4u_config():
|
||||
"""更新S4U配置文件"""
|
||||
# 创建配置目录(如果不存在)
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not os.path.exists(TEMPLATE_PATH):
|
||||
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
logger.error("请确保模板文件存在后重新运行")
|
||||
raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(CONFIG_PATH):
|
||||
logger.info("S4U配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
|
||||
return
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(CONFIG_PATH, encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(TEMPLATE_PATH, encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("S4U配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建备份目录
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
os.makedirs(old_config_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(CONFIG_PATH, old_backup_path)
|
||||
logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, dict | Table):
|
||||
update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并S4U新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
|
||||
logger.info("S4U配置文件更新完成")
|
||||
|
||||
|
||||
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
|
||||
"""
|
||||
加载S4U配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: S4UGlobalConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建S4UGlobalConfig对象
|
||||
try:
|
||||
return S4UGlobalConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("S4U配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
if not ENABLE_S4U:
|
||||
s4u_config = None
|
||||
s4u_config_main = None
|
||||
else:
|
||||
# 初始化S4U配置
|
||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||
update_s4u_config()
|
||||
|
||||
logger.info("正在加载S4U配置文件...")
|
||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||
logger.info("S4U配置文件加载完成!")
|
||||
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
Reference in New Issue
Block a user