feat:添加新的message类为s4u服务,添加s4u config,添加sc和gift的解析,修复关系构建的一些问题
This commit is contained in:
@@ -205,7 +205,6 @@ class MongoToSQLiteMigrator:
|
||||
"user_info.user_nickname": "user_nickname",
|
||||
"user_info.user_cardname": "user_cardname",
|
||||
"processed_plain_text": "processed_plain_text",
|
||||
"detailed_plain_text": "detailed_plain_text",
|
||||
"memorized_times": "memorized_times",
|
||||
},
|
||||
enable_validation=False, # 禁用数据验证
|
||||
|
||||
@@ -9,7 +9,7 @@ from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
@@ -141,6 +141,29 @@ class ChatBot:
|
||||
logger.error(f"处理命令时出错: {e}")
|
||||
return False, None, True # 出错时继续处理消息
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
|
||||
|
||||
get_chat_manager().register_message(message)
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=message.message_info.platform, # type: ignore
|
||||
user_info=user_info, # type: ignore
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 处理消息内容
|
||||
await message.process()
|
||||
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
|
||||
return
|
||||
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息
|
||||
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
|
||||
@@ -159,6 +182,10 @@ class ChatBot:
|
||||
# 确保所有任务已启动
|
||||
await self._ensure_started()
|
||||
|
||||
if ENABLE_S4U_CHAT:
|
||||
await self.do_s4u(message_data)
|
||||
return
|
||||
|
||||
if message_data["message_info"].get("group_info") is not None:
|
||||
message_data["message_info"]["group_info"]["group_id"] = str(
|
||||
message_data["message_info"]["group_info"]["group_id"]
|
||||
@@ -221,11 +248,6 @@ class ChatBot:
|
||||
template_group_name = None
|
||||
|
||||
async def preprocess():
|
||||
if ENABLE_S4U_CHAT:
|
||||
logger.info("进入S4U流程")
|
||||
await self.s4u_message_processor.process_message(message)
|
||||
return
|
||||
|
||||
await self.heartflow_message_receiver.process_message(message)
|
||||
|
||||
if template_group_name:
|
||||
|
||||
@@ -38,7 +38,6 @@ class Message(MessageBase):
|
||||
message_segment: Optional[Seg] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
detailed_plain_text: str = "",
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
# 使用传入的时间戳或当前时间
|
||||
@@ -58,7 +57,6 @@ class Message(MessageBase):
|
||||
self.chat_stream = chat_stream
|
||||
# 文本处理相关属性
|
||||
self.processed_plain_text = processed_plain_text
|
||||
self.detailed_plain_text = detailed_plain_text
|
||||
|
||||
# 回复消息
|
||||
self.reply = reply
|
||||
@@ -104,7 +102,6 @@ class MessageRecv(Message):
|
||||
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
|
||||
self.raw_message = message_dict.get("raw_message")
|
||||
self.processed_plain_text = message_dict.get("processed_plain_text", "")
|
||||
self.detailed_plain_text = message_dict.get("detailed_plain_text", "")
|
||||
self.is_emoji = False
|
||||
self.has_emoji = False
|
||||
self.is_picid = False
|
||||
@@ -123,7 +120,6 @@ class MessageRecv(Message):
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
self.detailed_plain_text = self._generate_detailed_text()
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
@@ -182,12 +178,97 @@ class MessageRecv(Message):
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
timestamp = self.message_info.time
|
||||
user_info = self.message_info.user_info
|
||||
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
|
||||
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
|
||||
@dataclass
|
||||
class MessageRecvS4U(MessageRecv):
|
||||
def __init__(self, message_dict: dict[str, Any]):
|
||||
super().__init__(message_dict)
|
||||
self.is_gift = False
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
|
||||
async def process(self) -> None:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
|
||||
async def _process_single_segment(self, segment: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
segment: 消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if segment.type == "text":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
return segment.data # type: ignore
|
||||
elif segment.type == "image":
|
||||
# 如果是base64图片数据
|
||||
if isinstance(segment.data, str):
|
||||
self.has_picid = True
|
||||
self.is_picid = True
|
||||
self.is_emoji = False
|
||||
image_manager = get_image_manager()
|
||||
# print(f"segment.data: {segment.data}")
|
||||
_, processed_text = await image_manager.process_image(segment.data)
|
||||
return processed_text
|
||||
return "[发了一张图片,网卡了加载不出来]"
|
||||
elif segment.type == "emoji":
|
||||
self.has_emoji = True
|
||||
self.is_emoji = True
|
||||
self.is_picid = False
|
||||
if isinstance(segment.data, str):
|
||||
return await get_image_manager().get_emoji_description(segment.data)
|
||||
return "[发了一个表情包,网卡了加载不出来]"
|
||||
elif segment.type == "mention_bot":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
self.is_mentioned = float(segment.data) # type: ignore
|
||||
return ""
|
||||
elif segment.type == "priority_info":
|
||||
self.is_picid = False
|
||||
self.is_emoji = False
|
||||
if isinstance(segment.data, dict):
|
||||
# 处理优先级信息
|
||||
self.priority_mode = "priority"
|
||||
self.priority_info = segment.data
|
||||
"""
|
||||
{
|
||||
'message_type': 'vip', # vip or normal
|
||||
'message_priority': 1.0, # 优先级,大为优先,float
|
||||
}
|
||||
"""
|
||||
return ""
|
||||
elif segment.type == "gift":
|
||||
self.is_gift = True
|
||||
# 解析gift_info,格式为"名称:数量"
|
||||
name, count = segment.data.split(":", 1)
|
||||
self.gift_info = segment.data
|
||||
self.gift_name = name.strip()
|
||||
self.gift_count = int(count.strip())
|
||||
return ""
|
||||
elif segment.type == "superchat":
|
||||
self.is_superchat = True
|
||||
self.superchat_info = segment.data
|
||||
price,message_text = segment.data.split(":", 1)
|
||||
self.superchat_price = price.strip()
|
||||
self.superchat_message_text = message_text.strip()
|
||||
|
||||
self.processed_plain_text = str(self.superchat_message_text)
|
||||
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
|
||||
|
||||
return self.processed_plain_text
|
||||
else:
|
||||
return ""
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -472,7 +553,6 @@ def message_from_db_dict(db_dict: dict) -> MessageRecv:
|
||||
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
|
||||
"raw_message": None, # 数据库中未存储原始消息
|
||||
"processed_plain_text": processed_text,
|
||||
"detailed_plain_text": db_dict.get("detailed_plain_text", ""),
|
||||
}
|
||||
|
||||
# 创建 MessageRecv 实例
|
||||
|
||||
@@ -121,27 +121,6 @@ async def get_embedding(text, request_type="embedding"):
|
||||
return embedding
|
||||
|
||||
|
||||
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
|
||||
filter_query = {"chat_id": chat_stream_id}
|
||||
sort_order = [("time", -1)]
|
||||
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
# 反转消息列表,使最新的消息在最后
|
||||
recent_messages.reverse()
|
||||
|
||||
if combine:
|
||||
return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
|
||||
|
||||
message_detailed_plain_text_list = []
|
||||
|
||||
for msg_db_data in recent_messages:
|
||||
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
|
||||
return message_detailed_plain_text_list
|
||||
|
||||
|
||||
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
||||
# 获取当前群聊记录内发言的人
|
||||
filter_query = {"chat_id": chat_stream_id}
|
||||
|
||||
@@ -153,7 +153,6 @@ class Messages(BaseModel):
|
||||
|
||||
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
|
||||
display_message = TextField(null=True) # 显示的消息
|
||||
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
|
||||
memorized_times = IntegerField(default=0) # 被记忆的次数
|
||||
|
||||
priority_mode = TextField(null=True)
|
||||
|
||||
@@ -403,6 +403,10 @@ MODULE_COLORS = {
|
||||
"model_utils": "\033[38;5;164m", # 紫红色
|
||||
"relationship_fetcher": "\033[38;5;170m", # 浅紫色
|
||||
"relationship_builder": "\033[38;5;93m", # 浅蓝色
|
||||
|
||||
#s4u
|
||||
"context_web_api": "\033[38;5;240m", # 深灰色
|
||||
"S4U_chat": "\033[92m", # 深灰色
|
||||
}
|
||||
|
||||
RESET_COLOR = "\033[0m"
|
||||
|
||||
36
src/mais4u/config/old/s4u_config_20250715_141713.toml
Normal file
36
src/mais4u/config/old/s4u_config_20250715_141713.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
[inner]
|
||||
version = "1.0.0"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
38
src/mais4u/config/s4u_config.toml
Normal file
38
src/mais4u/config/s4u_config.toml
Normal file
@@ -0,0 +1,38 @@
|
||||
[inner]
|
||||
version = "1.0.1"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
max_context_message_length = 20
|
||||
max_core_message_length = 30
|
||||
38
src/mais4u/config/s4u_config_template.toml
Normal file
38
src/mais4u/config/s4u_config_template.toml
Normal file
@@ -0,0 +1,38 @@
|
||||
[inner]
|
||||
version = "1.0.1"
|
||||
|
||||
#----以下是S4U聊天系统配置文件----
|
||||
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
|
||||
# 支持优先级队列、消息中断、VIP用户等高级功能
|
||||
#
|
||||
# 如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||
# 如果新增项目,请参考src/mais4u/s4u_config.py中的S4UConfig类
|
||||
#
|
||||
# 版本格式:主版本号.次版本号.修订号
|
||||
#----S4U配置说明结束----
|
||||
|
||||
[s4u]
|
||||
# 消息管理配置
|
||||
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
|
||||
recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
# 优先级系统配置
|
||||
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
|
||||
vip_queue_priority = true # 是否启用VIP队列优先级系统
|
||||
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
|
||||
|
||||
# 打字效果配置
|
||||
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
|
||||
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
|
||||
|
||||
# 动态打字延迟参数(仅在enable_dynamic_typing_delay=true时生效)
|
||||
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
|
||||
min_typing_delay = 0.2 # 最小打字延迟(秒)
|
||||
max_typing_delay = 2.0 # 最大打字延迟(秒)
|
||||
|
||||
# 系统功能开关
|
||||
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
|
||||
enable_loading_indicator = true # 是否显示加载提示
|
||||
|
||||
max_context_message_length = 20
|
||||
max_core_message_length = 30
|
||||
1
src/mais4u/mais4u_chat/SUPERCHAT_MANAGER_README.md
Normal file
1
src/mais4u/mais4u_chat/SUPERCHAT_MANAGER_README.md
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -24,13 +24,32 @@ class ContextMessage:
|
||||
self.timestamp = datetime.now()
|
||||
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
|
||||
|
||||
# 识别消息类型
|
||||
self.is_gift = getattr(message, 'is_gift', False)
|
||||
self.is_superchat = getattr(message, 'is_superchat', False)
|
||||
|
||||
# 添加礼物和SC相关信息
|
||||
if self.is_gift:
|
||||
self.gift_name = getattr(message, 'gift_name', '')
|
||||
self.gift_count = getattr(message, 'gift_count', '1')
|
||||
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
|
||||
elif self.is_superchat:
|
||||
self.superchat_price = getattr(message, 'superchat_price', '0')
|
||||
self.superchat_message = getattr(message, 'superchat_message_text', '')
|
||||
if self.superchat_message:
|
||||
self.content = f"[¥{self.superchat_price}] {self.superchat_message}"
|
||||
else:
|
||||
self.content = f"[¥{self.superchat_price}] {self.content}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"user_name": self.user_name,
|
||||
"user_id": self.user_id,
|
||||
"content": self.content,
|
||||
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
|
||||
"group_name": self.group_name
|
||||
"group_name": self.group_name,
|
||||
"is_gift": self.is_gift,
|
||||
"is_superchat": self.is_superchat
|
||||
}
|
||||
|
||||
|
||||
@@ -155,6 +174,44 @@ class ContextWebManager:
|
||||
transform: translateX(5px);
|
||||
transition: all 0.3s ease;
|
||||
}
|
||||
.message.gift {
|
||||
border-left: 4px solid #ff8800;
|
||||
background: rgba(255, 136, 0, 0.2);
|
||||
}
|
||||
.message.gift:hover {
|
||||
background: rgba(255, 136, 0, 0.3);
|
||||
}
|
||||
.message.gift .username {
|
||||
color: #ff8800;
|
||||
}
|
||||
.message.superchat {
|
||||
border-left: 4px solid #ff6b6b;
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
|
||||
background-size: 200% 200%;
|
||||
animation: rainbow 3s ease infinite;
|
||||
}
|
||||
.message.superchat:hover {
|
||||
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
|
||||
background-size: 200% 200%;
|
||||
}
|
||||
.message.superchat .username {
|
||||
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
|
||||
background-size: 300% 300%;
|
||||
animation: rainbow-text 2s ease infinite;
|
||||
-webkit-background-clip: text;
|
||||
-webkit-text-fill-color: transparent;
|
||||
background-clip: text;
|
||||
}
|
||||
@keyframes rainbow {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
@keyframes rainbow-text {
|
||||
0% { background-position: 0% 50%; }
|
||||
50% { background-position: 100% 50%; }
|
||||
100% { background-position: 0% 50%; }
|
||||
}
|
||||
.message-line {
|
||||
line-height: 1.4;
|
||||
word-wrap: break-word;
|
||||
@@ -373,7 +430,20 @@ class ContextWebManager:
|
||||
|
||||
function createMessageElement(msg, isNew = false) {
|
||||
const messageDiv = document.createElement('div');
|
||||
messageDiv.className = 'message' + (isNew ? ' new-message' : '');
|
||||
let className = 'message';
|
||||
|
||||
// 根据消息类型添加对应的CSS类
|
||||
if (msg.is_gift) {
|
||||
className += ' gift';
|
||||
} else if (msg.is_superchat) {
|
||||
className += ' superchat';
|
||||
}
|
||||
|
||||
if (isNew) {
|
||||
className += ' new-message';
|
||||
}
|
||||
|
||||
messageDiv.className = className;
|
||||
messageDiv.innerHTML = `
|
||||
<div class="message-line">
|
||||
<span class="username">${escapeHtml(msg.user_name)}:</span><span class="content">${escapeHtml(msg.content)}</span>
|
||||
|
||||
155
src/mais4u/mais4u_chat/gift_manager.py
Normal file
155
src/mais4u/mais4u_chat/gift_manager.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import asyncio
|
||||
from typing import Dict, Tuple, Callable, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("gift_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingGift:
|
||||
"""等待中的礼物消息"""
|
||||
message: MessageRecvS4U
|
||||
total_count: int
|
||||
timer_task: asyncio.Task
|
||||
callback: Callable[[MessageRecvS4U], None]
|
||||
|
||||
|
||||
class GiftManager:
|
||||
"""礼物管理器,提供防抖功能"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化礼物管理器"""
|
||||
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
|
||||
self.debounce_timeout = 3.0 # 3秒防抖时间
|
||||
|
||||
async def handle_gift(self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None) -> bool:
|
||||
"""处理礼物消息,返回是否应该立即处理
|
||||
|
||||
Args:
|
||||
message: 礼物消息
|
||||
callback: 防抖完成后的回调函数
|
||||
|
||||
Returns:
|
||||
bool: False表示消息被暂存等待防抖,True表示应该立即处理
|
||||
"""
|
||||
if not message.is_gift:
|
||||
return True
|
||||
|
||||
# 构建礼物的唯一键:(发送人ID, 礼物名称)
|
||||
gift_key = (message.message_info.user_info.user_id, message.gift_name)
|
||||
|
||||
# 如果已经有相同的礼物在等待中,则合并
|
||||
if gift_key in self.pending_gifts:
|
||||
await self._merge_gift(gift_key, message)
|
||||
return False
|
||||
|
||||
# 创建新的等待礼物
|
||||
await self._create_pending_gift(gift_key, message, callback)
|
||||
return False
|
||||
|
||||
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
|
||||
"""合并礼物消息"""
|
||||
pending_gift = self.pending_gifts[gift_key]
|
||||
|
||||
# 取消之前的定时器
|
||||
if not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
|
||||
# 累加礼物数量
|
||||
try:
|
||||
new_count = int(new_message.gift_count)
|
||||
pending_gift.total_count += new_count
|
||||
|
||||
# 更新消息为最新的(保留最新的消息,但累加数量)
|
||||
pending_gift.message = new_message
|
||||
pending_gift.message.gift_count = str(pending_gift.total_count)
|
||||
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
|
||||
|
||||
except ValueError:
|
||||
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
|
||||
# 如果无法解析数量,保持原有数量不变
|
||||
|
||||
# 重新创建定时器
|
||||
pending_gift.timer_task = asyncio.create_task(
|
||||
self._gift_timeout(gift_key)
|
||||
)
|
||||
|
||||
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
|
||||
|
||||
async def _create_pending_gift(
|
||||
self,
|
||||
gift_key: Tuple[str, str],
|
||||
message: MessageRecvS4U,
|
||||
callback: Optional[Callable[[MessageRecvS4U], None]]
|
||||
) -> None:
|
||||
"""创建新的等待礼物"""
|
||||
try:
|
||||
initial_count = int(message.gift_count)
|
||||
except ValueError:
|
||||
initial_count = 1
|
||||
logger.warning(f"无法解析礼物数量: {message.gift_count},默认设为1")
|
||||
|
||||
# 创建定时器任务
|
||||
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
|
||||
|
||||
# 创建等待礼物对象
|
||||
pending_gift = PendingGift(
|
||||
message=message,
|
||||
total_count=initial_count,
|
||||
timer_task=timer_task,
|
||||
callback=callback
|
||||
)
|
||||
|
||||
self.pending_gifts[gift_key] = pending_gift
|
||||
|
||||
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
|
||||
|
||||
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
|
||||
"""礼物防抖超时处理"""
|
||||
try:
|
||||
# 等待防抖时间
|
||||
await asyncio.sleep(self.debounce_timeout)
|
||||
|
||||
# 获取等待中的礼物
|
||||
if gift_key not in self.pending_gifts:
|
||||
return
|
||||
|
||||
pending_gift = self.pending_gifts.pop(gift_key)
|
||||
|
||||
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
|
||||
|
||||
message = pending_gift.message
|
||||
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
|
||||
|
||||
# 执行回调
|
||||
if pending_gift.callback:
|
||||
try:
|
||||
pending_gift.callback(message)
|
||||
except Exception as e:
|
||||
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 定时器被取消,不需要处理
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
|
||||
|
||||
def get_pending_count(self) -> int:
|
||||
"""获取当前等待中的礼物数量"""
|
||||
return len(self.pending_gifts)
|
||||
|
||||
async def flush_all(self) -> None:
|
||||
"""立即处理所有等待中的礼物"""
|
||||
for gift_key in list(self.pending_gifts.keys()):
|
||||
pending_gift = self.pending_gifts.get(gift_key)
|
||||
if pending_gift and not pending_gift.timer_task.cancelled():
|
||||
pending_gift.timer_task.cancel()
|
||||
await self._gift_timeout(gift_key)
|
||||
|
||||
|
||||
# 创建全局礼物管理器实例
|
||||
gift_manager = GiftManager()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
import time
|
||||
import random
|
||||
from typing import Optional, Dict, Tuple # 导入类型提示
|
||||
@@ -6,7 +7,7 @@ from maim_message import UserInfo, Seg
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from .s4u_stream_generator import S4UStreamGenerator
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
|
||||
from src.config.config import global_config
|
||||
from src.common.message.api import get_global_api
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
@@ -14,6 +15,9 @@ from .s4u_watching_manager import watching_manager
|
||||
import json
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from .loading import send_loading, send_unloading
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
from src.person_info.person_info import PersonInfoManager
|
||||
from .super_chat_manager import get_super_chat_manager
|
||||
|
||||
logger = get_logger("S4U_chat")
|
||||
|
||||
@@ -49,9 +53,9 @@ class MessageSenderContainer:
|
||||
|
||||
def _calculate_typing_delay(self, text: str) -> float:
|
||||
"""根据文本长度计算模拟打字延迟。"""
|
||||
chars_per_second = 15.0
|
||||
min_delay = 0.2
|
||||
max_delay = 2.0
|
||||
chars_per_second = s4u_config.chars_per_second
|
||||
min_delay = s4u_config.min_typing_delay
|
||||
max_delay = s4u_config.max_typing_delay
|
||||
|
||||
delay = len(text) / chars_per_second
|
||||
return max(min_delay, min(delay, max_delay))
|
||||
@@ -73,8 +77,11 @@ class MessageSenderContainer:
|
||||
# Check for pause signal *after* getting an item.
|
||||
await self._paused_event.wait()
|
||||
|
||||
# delay = self._calculate_typing_delay(chunk)
|
||||
delay = 0.1
|
||||
# 根据配置选择延迟模式
|
||||
if s4u_config.enable_dynamic_typing_delay:
|
||||
delay = self._calculate_typing_delay(chunk)
|
||||
else:
|
||||
delay = s4u_config.typing_delay
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
current_time = time.time()
|
||||
@@ -144,8 +151,6 @@ def get_s4u_chat_manager() -> S4UChatManager:
|
||||
|
||||
|
||||
class S4UChat:
|
||||
_MESSAGE_TIMEOUT_SECONDS = 120 # 普通消息存活时间(秒)
|
||||
|
||||
def __init__(self, chat_stream: ChatStream):
|
||||
"""初始化 S4UChat 实例。"""
|
||||
|
||||
@@ -169,8 +174,7 @@ class S4UChat:
|
||||
self._is_replying = False
|
||||
self.gpt = S4UStreamGenerator()
|
||||
self.interest_dict: Dict[str, float] = {} # 用户兴趣分
|
||||
self.at_bot_priority_bonus = 100.0 # @机器人的优先级加成
|
||||
self.recent_message_keep_count = 6 # 保留最近N条消息,超出范围的普通消息将被移除
|
||||
|
||||
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
|
||||
|
||||
def _get_priority_info(self, message: MessageRecv) -> dict:
|
||||
@@ -194,16 +198,13 @@ class S4UChat:
|
||||
"""获取用户的兴趣分,默认为1.0"""
|
||||
return self.interest_dict.get(user_id, 1.0)
|
||||
|
||||
|
||||
|
||||
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
|
||||
"""
|
||||
为消息计算基础优先级分数。分数越高,优先级越高。
|
||||
"""
|
||||
score = 0.0
|
||||
# 如果消息 @ 了机器人,则增加一个很大的分数
|
||||
# if f"@{global_config.bot.nickname}" in message.processed_plain_text or any(
|
||||
# f"@{alias}" in message.processed_plain_text for alias in global_config.bot.alias_names
|
||||
# ):
|
||||
# score += self.at_bot_priority_bonus
|
||||
|
||||
# 加上消息自带的优先级
|
||||
score += priority_info.get("message_priority", 0.0)
|
||||
@@ -212,17 +213,55 @@ class S4UChat:
|
||||
score += self._get_interest_score(message.message_info.user_info.user_id)
|
||||
return score
|
||||
|
||||
async def add_message(self, message: MessageRecv) -> None:
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
def decay_interest_score(self,message: MessageRecvS4U|MessageRecv):
|
||||
for person_id, score in self.interest_dict.items():
|
||||
if score > 0:
|
||||
self.interest_dict[person_id] = score * 0.95
|
||||
else:
|
||||
self.interest_dict[person_id] = 0
|
||||
|
||||
await self.relationship_builder.build_relation()
|
||||
async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None:
|
||||
|
||||
self.decay_interest_score(message)
|
||||
|
||||
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
|
||||
user_id = message.message_info.user_info.user_id
|
||||
platform = message.message_info.platform
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
|
||||
try:
|
||||
is_gift = message.is_gift
|
||||
is_superchat = message.is_superchat
|
||||
print(is_gift)
|
||||
print(is_superchat)
|
||||
if is_gift:
|
||||
await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
current_score = self.interest_dict.get(person_id, 1.0)
|
||||
self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
|
||||
elif is_superchat:
|
||||
await self.relationship_builder.build_relation(immediate_build=person_id)
|
||||
# 安全地增加兴趣分,如果person_id不存在则先初始化为1.0
|
||||
current_score = self.interest_dict.get(person_id, 1.0)
|
||||
self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
|
||||
|
||||
# 添加SuperChat到管理器
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
await super_chat_manager.add_superchat(message)
|
||||
else:
|
||||
await self.relationship_builder.build_relation(20)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
|
||||
|
||||
priority_info = self._get_priority_info(message)
|
||||
is_vip = self._is_vip(priority_info)
|
||||
new_priority_score = self._calculate_base_priority_score(message, priority_info)
|
||||
|
||||
should_interrupt = False
|
||||
if self._current_generation_task and not self._current_generation_task.done():
|
||||
if (s4u_config.enable_message_interruption and
|
||||
self._current_generation_task and not self._current_generation_task.done()):
|
||||
if self._current_message_being_replied:
|
||||
current_queue, current_priority, _, current_msg = self._current_message_being_replied
|
||||
|
||||
@@ -260,7 +299,7 @@ class S4UChat:
|
||||
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
|
||||
item = (-new_priority_score, self._entry_counter, time.time(), message)
|
||||
|
||||
if is_vip:
|
||||
if is_vip and s4u_config.vip_queue_priority:
|
||||
await self._vip_queue.put(item)
|
||||
logger.info(f"[{self.stream_name}] VIP message added to queue.")
|
||||
else:
|
||||
@@ -271,11 +310,11 @@ class S4UChat:
|
||||
|
||||
def _cleanup_old_normal_messages(self):
|
||||
"""清理普通队列中不在最近N条消息范围内的消息"""
|
||||
if self._normal_queue.empty():
|
||||
if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
|
||||
return
|
||||
|
||||
# 计算阈值:保留最近 recent_message_keep_count 条消息
|
||||
cutoff_counter = max(0, self._entry_counter - self.recent_message_keep_count)
|
||||
cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
|
||||
|
||||
# 临时存储需要保留的消息
|
||||
temp_messages = []
|
||||
@@ -302,7 +341,7 @@ class S4UChat:
|
||||
self._normal_queue.put_nowait(item)
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range.")
|
||||
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.")
|
||||
|
||||
async def _message_processor(self):
|
||||
"""调度器:优先处理VIP队列,然后处理普通队列。"""
|
||||
@@ -325,7 +364,7 @@ class S4UChat:
|
||||
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
|
||||
priority = -neg_priority
|
||||
# 检查普通消息是否超时
|
||||
if time.time() - timestamp > self._MESSAGE_TIMEOUT_SECONDS:
|
||||
if time.time() - timestamp > s4u_config.message_timeout_seconds:
|
||||
logger.info(
|
||||
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
|
||||
)
|
||||
@@ -369,18 +408,24 @@ class S4UChat:
|
||||
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def delay_change_watching_state(self):
|
||||
random_delay = random.randint(1, 3)
|
||||
await asyncio.sleep(random_delay)
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
await chat_watching.on_message_received()
|
||||
|
||||
async def _generate_and_send(self, message: MessageRecv):
|
||||
"""为单个消息生成文本回复。整个过程可以被中断。"""
|
||||
self._is_replying = True
|
||||
total_chars_sent = 0 # 跟踪发送的总字符数
|
||||
|
||||
if s4u_config.enable_loading_indicator:
|
||||
await send_loading(self.stream_id, "......")
|
||||
|
||||
# 视线管理:开始生成回复时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
|
||||
await chat_watching.on_reply_start()
|
||||
asyncio.create_task(self.delay_change_watching_state())
|
||||
|
||||
# 回复生成实时展示:开始生成
|
||||
user_name = message.message_info.user_info.user_nickname
|
||||
|
||||
sender_container = MessageSenderContainer(self.chat_stream, message)
|
||||
sender_container.start()
|
||||
@@ -395,12 +440,18 @@ class S4UChat:
|
||||
|
||||
# a. 发送文本块
|
||||
await sender_container.add_message(chunk)
|
||||
total_chars_sent += len(chunk) # 累计字符数
|
||||
|
||||
|
||||
# 等待所有文本消息发送完成
|
||||
await sender_container.close()
|
||||
await sender_container.join()
|
||||
|
||||
# 回复完成后延迟,每个字延迟0.4秒
|
||||
if total_chars_sent > 0:
|
||||
delay_time = total_chars_sent * 0.4
|
||||
logger.info(f"[{self.stream_name}] 回复完成,共发送 {total_chars_sent} 个字符,等待 {delay_time:.1f} 秒后继续处理下一个消息。")
|
||||
await asyncio.sleep(delay_time)
|
||||
|
||||
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
|
||||
|
||||
@@ -408,11 +459,13 @@ class S4UChat:
|
||||
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
|
||||
raise # 将取消异常向上传播
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
|
||||
# 回复生成实时展示:清空内容(出错时)
|
||||
finally:
|
||||
self._is_replying = False
|
||||
|
||||
if s4u_config.enable_loading_indicator:
|
||||
await send_unloading(self.stream_id)
|
||||
|
||||
# 视线管理:回复结束时切换视线状态
|
||||
@@ -442,3 +495,8 @@ class S4UChat:
|
||||
await self._processing_task
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"处理任务已成功取消: {self.stream_name}")
|
||||
|
||||
# 注意:SuperChat管理器是全局的,不需要在单个S4UChat关闭时关闭
|
||||
# 如果需要关闭SuperChat管理器,应该在应用程序关闭时调用
|
||||
# super_chat_manager = get_super_chat_manager()
|
||||
# await super_chat_manager.shutdown()
|
||||
|
||||
@@ -214,7 +214,7 @@ class ChatMood:
|
||||
sorrow=self.mood_values["sorrow"],
|
||||
fear=self.mood_values["fear"],
|
||||
)
|
||||
logger.info(f"numerical mood prompt: {prompt}")
|
||||
logger.debug(f"numerical mood prompt: {prompt}")
|
||||
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
@@ -3,7 +3,7 @@ import math
|
||||
from typing import Tuple
|
||||
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
@@ -14,6 +14,7 @@ from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
|
||||
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
|
||||
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
|
||||
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
|
||||
from src.mais4u.mais4u_chat.gift_manager import gift_manager
|
||||
|
||||
from .s4u_chat import get_s4u_chat_manager
|
||||
|
||||
@@ -66,7 +67,7 @@ class S4UMessageProcessor:
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecv) -> None:
|
||||
async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
@@ -80,8 +81,6 @@ class S4UMessageProcessor:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
|
||||
target_user_id_list = ["1026294844", "964959351"]
|
||||
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
@@ -93,25 +92,29 @@ class S4UMessageProcessor:
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
# 处理礼物消息,如果消息被暂存则停止当前处理流程
|
||||
if not skip_gift_debounce and not await self.handle_if_gift(message):
|
||||
return
|
||||
|
||||
await self.check_if_fake_gift(message)
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
|
||||
|
||||
if userinfo.user_id in target_user_id_list:
|
||||
await s4u_chat.add_message(message)
|
||||
else:
|
||||
await s4u_chat.add_message(message)
|
||||
|
||||
interested_rate, _ = await _calculate_interest(message)
|
||||
_interested_rate, _ = await _calculate_interest(message)
|
||||
|
||||
await mood_manager.start()
|
||||
|
||||
|
||||
|
||||
# 一系列llm驱动的前处理
|
||||
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_mood.update_mood_by_message(message))
|
||||
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_action.update_action_by_message(message))
|
||||
# asyncio.create_task(chat_action.update_facial_expression_by_message(message, interested_rate))
|
||||
|
||||
# 视线管理:收到消息时切换视线状态
|
||||
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
|
||||
asyncio.create_task(chat_watching.on_message_received())
|
||||
@@ -119,9 +122,44 @@ class S4UMessageProcessor:
|
||||
# 上下文网页管理:启动独立task处理消息上下文
|
||||
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
|
||||
|
||||
# 7. 日志记录
|
||||
# 日志记录
|
||||
if message.is_gift:
|
||||
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
|
||||
else:
|
||||
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
|
||||
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""检查消息是否为假礼物"""
|
||||
if message.is_gift:
|
||||
return False
|
||||
|
||||
gift_keywords = ["送出了礼物", "礼物", "送出了"]
|
||||
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
|
||||
message.processed_plain_text += "(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分)"
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
|
||||
"""处理礼物消息
|
||||
|
||||
Returns:
|
||||
bool: True表示应该继续处理消息,False表示消息已被暂存不需要继续处理
|
||||
"""
|
||||
if message.is_gift:
|
||||
# 定义防抖完成后的回调函数
|
||||
def gift_callback(merged_message: MessageRecvS4U):
|
||||
"""礼物防抖完成后的回调"""
|
||||
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
|
||||
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
|
||||
|
||||
# 交给礼物管理器处理,并传入回调函数
|
||||
# 对于礼物消息,handle_gift 总是返回 False(消息被暂存)
|
||||
await gift_manager.handle_gift(message, gift_callback)
|
||||
return False # 消息被暂存,不继续处理
|
||||
|
||||
return True # 非礼物消息,继续正常处理
|
||||
|
||||
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
|
||||
"""处理上下文网页更新的独立task
|
||||
|
||||
|
||||
@@ -8,10 +8,13 @@ from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
import random
|
||||
from datetime import datetime
|
||||
import asyncio
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
import ast
|
||||
|
||||
from src.chat.message_receive.message import MessageSending, MessageRecvS4U
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
|
||||
|
||||
logger = get_logger("prompt")
|
||||
|
||||
@@ -22,13 +25,19 @@ def init_prompt():
|
||||
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
|
||||
|
||||
Prompt(
|
||||
"""{identity_block}
|
||||
"""
|
||||
你的名字是麦麦, 是千石可乐开发的程序,可以在QQ,微信等平台发言,你现在正在哔哩哔哩作为虚拟主播进行直播
|
||||
虽然是个AI,但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
|
||||
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。
|
||||
你可以看见用户发送的弹幕,礼物和superchat
|
||||
你可以看见面前的屏幕,
|
||||
|
||||
{relation_info_block}
|
||||
{memory_block}
|
||||
|
||||
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
|
||||
|
||||
{sc_info}
|
||||
|
||||
{background_dialogue_prompt}
|
||||
--------------------------------
|
||||
@@ -37,6 +46,7 @@ def init_prompt():
|
||||
{core_dialogue_prompt}
|
||||
|
||||
对方最新发送的内容:{message_txt}
|
||||
{gift_info}
|
||||
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
|
||||
不要输出多余内容(包括前后缀,冒号和引号,括号(),表情包,at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
|
||||
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}。
|
||||
@@ -117,14 +127,14 @@ class PromptBuilder:
|
||||
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
|
||||
return ""
|
||||
|
||||
def build_chat_history_prompts(self, chat_stream, message) -> (str, str):
|
||||
def build_chat_history_prompts(self, chat_stream: ChatStream, message: MessageRecvS4U):
|
||||
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
|
||||
chat_id=chat_stream.stream_id,
|
||||
timestamp=time.time(),
|
||||
limit=100,
|
||||
limit=200,
|
||||
)
|
||||
|
||||
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id
|
||||
talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id)
|
||||
|
||||
core_dialogue_list = []
|
||||
background_dialogue_list = []
|
||||
@@ -148,10 +158,9 @@ class PromptBuilder:
|
||||
|
||||
background_dialogue_prompt = ""
|
||||
if background_dialogue_list:
|
||||
latest_25_msgs = background_dialogue_list[-25:]
|
||||
context_msgs = background_dialogue_list[-s4u_config.max_context_message_length:]
|
||||
background_dialogue_prompt_str = build_readable_messages(
|
||||
latest_25_msgs,
|
||||
merge_messages=True,
|
||||
context_msgs,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
show_pic=False,
|
||||
)
|
||||
@@ -159,7 +168,7 @@ class PromptBuilder:
|
||||
|
||||
core_msg_str = ""
|
||||
if core_dialogue_list:
|
||||
core_dialogue_list = core_dialogue_list[-50:]
|
||||
core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length:]
|
||||
|
||||
first_msg = core_dialogue_list[0]
|
||||
start_speaking_user_id = first_msg.get("user_id")
|
||||
@@ -196,10 +205,19 @@ class PromptBuilder:
|
||||
|
||||
return core_msg_str, background_dialogue_prompt
|
||||
|
||||
def build_gift_info(self, message: MessageRecvS4U):
|
||||
if message.is_gift:
|
||||
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
|
||||
return ""
|
||||
|
||||
def build_sc_info(self, message: MessageRecvS4U):
|
||||
super_chat_manager = get_super_chat_manager()
|
||||
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
|
||||
|
||||
async def build_prompt_normal(
|
||||
self,
|
||||
message,
|
||||
chat_stream,
|
||||
message: MessageRecvS4U,
|
||||
chat_stream: ChatStream,
|
||||
message_txt: str,
|
||||
sender_name: str = "某人",
|
||||
) -> str:
|
||||
@@ -209,6 +227,10 @@ class PromptBuilder:
|
||||
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
|
||||
|
||||
gift_info = self.build_gift_info(message)
|
||||
|
||||
sc_info = self.build_sc_info(message)
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
template_name = "s4u_prompt"
|
||||
@@ -219,12 +241,16 @@ class PromptBuilder:
|
||||
time_block=time_block,
|
||||
relation_info_block=relation_info_block,
|
||||
memory_block=memory_block,
|
||||
gift_info=gift_info,
|
||||
sc_info=sc_info,
|
||||
sender_name=sender_name,
|
||||
core_dialogue_prompt=core_dialogue_prompt,
|
||||
background_dialogue_prompt=background_dialogue_prompt,
|
||||
message_txt=message_txt,
|
||||
)
|
||||
|
||||
print(prompt)
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
|
||||
307
src/mais4u/mais4u_chat/super_chat_manager.py
Normal file
307
src/mais4u/mais4u_chat/super_chat_manager.py
Normal file
@@ -0,0 +1,307 @@
|
||||
import asyncio
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecvS4U, MessageRecv
|
||||
from src.mais4u.s4u_config import s4u_config
|
||||
|
||||
logger = get_logger("super_chat_manager")
|
||||
|
||||
|
||||
@dataclass
|
||||
class SuperChatRecord:
|
||||
"""SuperChat记录数据类"""
|
||||
|
||||
user_id: str
|
||||
user_nickname: str
|
||||
platform: str
|
||||
chat_id: str
|
||||
price: float
|
||||
message_text: str
|
||||
timestamp: float
|
||||
expire_time: float
|
||||
group_name: Optional[str] = None
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""检查SuperChat是否已过期"""
|
||||
return time.time() > self.expire_time
|
||||
|
||||
def remaining_time(self) -> float:
|
||||
"""获取剩余时间(秒)"""
|
||||
return max(0, self.expire_time - time.time())
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"user_nickname": self.user_nickname,
|
||||
"platform": self.platform,
|
||||
"chat_id": self.chat_id,
|
||||
"price": self.price,
|
||||
"message_text": self.message_text,
|
||||
"timestamp": self.timestamp,
|
||||
"expire_time": self.expire_time,
|
||||
"group_name": self.group_name,
|
||||
"remaining_time": self.remaining_time()
|
||||
}
|
||||
|
||||
|
||||
class SuperChatManager:
|
||||
"""SuperChat管理器,负责管理和跟踪SuperChat消息"""
|
||||
|
||||
def __init__(self):
|
||||
self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表
|
||||
self._cleanup_task: Optional[asyncio.Task] = None
|
||||
self._is_initialized = False
|
||||
logger.info("SuperChat管理器已初始化")
|
||||
|
||||
def _ensure_cleanup_task_started(self):
|
||||
"""确保清理任务已启动(延迟启动)"""
|
||||
if self._cleanup_task is None or self._cleanup_task.done():
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
self._cleanup_task = loop.create_task(self._cleanup_expired_superchats())
|
||||
self._is_initialized = True
|
||||
logger.info("SuperChat清理任务已启动")
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,稍后再启动
|
||||
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
|
||||
|
||||
def _start_cleanup_task(self):
|
||||
"""启动清理任务(已弃用,保留向后兼容)"""
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
async def _cleanup_expired_superchats(self):
|
||||
"""定期清理过期的SuperChat"""
|
||||
while True:
|
||||
try:
|
||||
current_time = time.time()
|
||||
total_removed = 0
|
||||
|
||||
for chat_id in list(self.super_chats.keys()):
|
||||
original_count = len(self.super_chats[chat_id])
|
||||
# 移除过期的SuperChat
|
||||
self.super_chats[chat_id] = [
|
||||
sc for sc in self.super_chats[chat_id]
|
||||
if not sc.is_expired()
|
||||
]
|
||||
|
||||
removed_count = original_count - len(self.super_chats[chat_id])
|
||||
total_removed += removed_count
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
|
||||
|
||||
# 如果列表为空,删除该聊天的记录
|
||||
if not self.super_chats[chat_id]:
|
||||
del self.super_chats[chat_id]
|
||||
|
||||
if total_removed > 0:
|
||||
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
|
||||
|
||||
# 每30秒检查一次
|
||||
await asyncio.sleep(30)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
|
||||
await asyncio.sleep(60) # 出错时等待更长时间
|
||||
|
||||
def _calculate_expire_time(self, price: float) -> float:
|
||||
"""根据SuperChat金额计算过期时间"""
|
||||
current_time = time.time()
|
||||
|
||||
# 根据金额阶梯设置不同的存活时间
|
||||
if price >= 500:
|
||||
# 500元以上:保持4小时
|
||||
duration = 4 * 3600
|
||||
elif price >= 200:
|
||||
# 200-499元:保持2小时
|
||||
duration = 2 * 3600
|
||||
elif price >= 100:
|
||||
# 100-199元:保持1小时
|
||||
duration = 1 * 3600
|
||||
elif price >= 50:
|
||||
# 50-99元:保持30分钟
|
||||
duration = 30 * 60
|
||||
elif price >= 20:
|
||||
# 20-49元:保持15分钟
|
||||
duration = 15 * 60
|
||||
elif price >= 10:
|
||||
# 10-19元:保持10分钟
|
||||
duration = 10 * 60
|
||||
else:
|
||||
# 10元以下:保持5分钟
|
||||
duration = 5 * 60
|
||||
|
||||
return current_time + duration
|
||||
|
||||
async def add_superchat(self, message: MessageRecvS4U) -> None:
|
||||
"""添加新的SuperChat记录"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if not message.is_superchat or not message.superchat_price:
|
||||
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
|
||||
return
|
||||
|
||||
try:
|
||||
price = float(message.superchat_price)
|
||||
except (ValueError, TypeError):
|
||||
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
|
||||
return
|
||||
|
||||
user_info = message.message_info.user_info
|
||||
group_info = message.message_info.group_info
|
||||
chat_id = getattr(message, 'chat_stream', None)
|
||||
if chat_id:
|
||||
chat_id = chat_id.stream_id
|
||||
else:
|
||||
# 生成chat_id的备用方法
|
||||
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
|
||||
if group_info:
|
||||
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
|
||||
|
||||
expire_time = self._calculate_expire_time(price)
|
||||
|
||||
record = SuperChatRecord(
|
||||
user_id=user_info.user_id,
|
||||
user_nickname=user_info.user_nickname,
|
||||
platform=message.message_info.platform,
|
||||
chat_id=chat_id,
|
||||
price=price,
|
||||
message_text=message.superchat_message_text or "",
|
||||
timestamp=message.message_info.time,
|
||||
expire_time=expire_time,
|
||||
group_name=group_info.group_name if group_info else None
|
||||
)
|
||||
|
||||
# 添加到对应聊天的SuperChat列表
|
||||
if chat_id not in self.super_chats:
|
||||
self.super_chats[chat_id] = []
|
||||
|
||||
self.super_chats[chat_id].append(record)
|
||||
|
||||
# 按价格降序排序(价格高的在前)
|
||||
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
|
||||
|
||||
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
|
||||
|
||||
def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]:
|
||||
"""获取指定聊天的所有有效SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
if chat_id not in self.super_chats:
|
||||
return []
|
||||
|
||||
# 过滤掉过期的SuperChat
|
||||
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
|
||||
return valid_superchats
|
||||
|
||||
def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]:
|
||||
"""获取所有有效的SuperChat"""
|
||||
# 确保清理任务已启动
|
||||
self._ensure_cleanup_task_started()
|
||||
|
||||
result = {}
|
||||
for chat_id, superchats in self.super_chats.items():
|
||||
valid_superchats = [sc for sc in superchats if not sc.is_expired()]
|
||||
if valid_superchats:
|
||||
result[chat_id] = valid_superchats
|
||||
return result
|
||||
|
||||
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
|
||||
"""构建SuperChat显示字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return ""
|
||||
|
||||
# 限制显示数量
|
||||
display_superchats = superchats[:max_count]
|
||||
|
||||
lines = []
|
||||
lines.append("📢 当前有效超级弹幕:")
|
||||
|
||||
for i, sc in enumerate(display_superchats, 1):
|
||||
remaining_minutes = int(sc.remaining_time() / 60)
|
||||
remaining_seconds = int(sc.remaining_time() % 60)
|
||||
|
||||
time_display = f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
|
||||
|
||||
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
|
||||
if len(line) > 100: # 限制单行长度
|
||||
line = line[:97] + "..."
|
||||
line += f" (剩余{time_display})"
|
||||
lines.append(line)
|
||||
|
||||
if len(superchats) > max_count:
|
||||
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def build_superchat_summary_string(self, chat_id: str) -> str:
|
||||
"""构建SuperChat摘要字符串"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return "当前没有有效的超级弹幕"
|
||||
lines = []
|
||||
for sc in superchats:
|
||||
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
|
||||
if len(single_sc_str) > 100:
|
||||
single_sc_str = single_sc_str[:97] + "..."
|
||||
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
|
||||
lines.append(single_sc_str)
|
||||
|
||||
total_amount = sum(sc.price for sc in superchats)
|
||||
count = len(superchats)
|
||||
highest_amount = max(sc.price for sc in superchats)
|
||||
|
||||
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元"
|
||||
if lines:
|
||||
final_str += "\n" + "\n".join(lines)
|
||||
return final_str
|
||||
|
||||
def get_superchat_statistics(self, chat_id: str) -> dict:
|
||||
"""获取SuperChat统计信息"""
|
||||
superchats = self.get_superchats_by_chat(chat_id)
|
||||
|
||||
if not superchats:
|
||||
return {
|
||||
"count": 0,
|
||||
"total_amount": 0,
|
||||
"average_amount": 0,
|
||||
"highest_amount": 0,
|
||||
"lowest_amount": 0
|
||||
}
|
||||
|
||||
amounts = [sc.price for sc in superchats]
|
||||
|
||||
return {
|
||||
"count": len(superchats),
|
||||
"total_amount": sum(amounts),
|
||||
"average_amount": sum(amounts) / len(amounts),
|
||||
"highest_amount": max(amounts),
|
||||
"lowest_amount": min(amounts)
|
||||
}
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭管理器,清理资源"""
|
||||
if self._cleanup_task and not self._cleanup_task.done():
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
logger.info("SuperChat管理器已关闭")
|
||||
|
||||
|
||||
# 全局SuperChat管理器实例
|
||||
super_chat_manager = SuperChatManager()
|
||||
|
||||
|
||||
def get_super_chat_manager() -> SuperChatManager:
|
||||
"""获取全局SuperChat管理器实例"""
|
||||
return super_chat_manager
|
||||
296
src/mais4u/s4u_config.py
Normal file
296
src/mais4u/s4u_config.py
Normal file
@@ -0,0 +1,296 @@
|
||||
import os
|
||||
import tomlkit
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("s4u_config")
|
||||
|
||||
# 获取mais4u模块目录
|
||||
MAIS4U_ROOT = os.path.dirname(__file__)
|
||||
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
|
||||
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
|
||||
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")
|
||||
|
||||
# S4U配置版本
|
||||
S4U_VERSION = "1.0.0"
|
||||
|
||||
T = TypeVar("T", bound="S4UConfigBase")
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfigBase:
|
||||
"""S4U配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
|
||||
"""从字典加载配置字段"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
|
||||
init_args: dict[str, Any] = {}
|
||||
|
||||
for f in fields(cls):
|
||||
field_name = f.name
|
||||
|
||||
if field_name.startswith("_"):
|
||||
# 跳过以 _ 开头的字段
|
||||
continue
|
||||
|
||||
if field_name not in data:
|
||||
if f.default is not MISSING or f.default_factory is not MISSING:
|
||||
# 跳过未提供且有默认值/默认构造方法的字段
|
||||
continue
|
||||
else:
|
||||
raise ValueError(f"Missing required field: '{field_name}'")
|
||||
|
||||
value = data[field_name]
|
||||
field_type = f.type
|
||||
|
||||
try:
|
||||
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
|
||||
except TypeError as e:
|
||||
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
|
||||
|
||||
return cls(**init_args)
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
"""转换字段值为指定类型"""
|
||||
# 如果是嵌套的 dataclass,递归调用 from_dict 方法
|
||||
if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
return field_type.from_dict(value)
|
||||
|
||||
# 处理泛型集合类型(list, set, tuple)
|
||||
field_origin_type = get_origin(field_type)
|
||||
field_type_args = get_args(field_type)
|
||||
|
||||
if field_origin_type in {list, set, tuple}:
|
||||
if not isinstance(value, list):
|
||||
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if field_origin_type is list:
|
||||
if (
|
||||
field_type_args
|
||||
and isinstance(field_type_args[0], type)
|
||||
and issubclass(field_type_args[0], S4UConfigBase)
|
||||
):
|
||||
return [field_type_args[0].from_dict(item) for item in value]
|
||||
return [cls._convert_field(item, field_type_args[0]) for item in value]
|
||||
elif field_origin_type is set:
|
||||
return {cls._convert_field(item, field_type_args[0]) for item in value}
|
||||
elif field_origin_type is tuple:
|
||||
if len(value) != len(field_type_args):
|
||||
raise TypeError(
|
||||
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
|
||||
)
|
||||
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
|
||||
|
||||
if field_origin_type is dict:
|
||||
if not isinstance(value, dict):
|
||||
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
|
||||
|
||||
if len(field_type_args) != 2:
|
||||
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
|
||||
key_type, value_type = field_type_args
|
||||
|
||||
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
|
||||
|
||||
# 处理基础类型,例如 int, str 等
|
||||
if field_origin_type is type(None) and value is None: # 处理Optional类型
|
||||
return None
|
||||
|
||||
# 处理Literal类型
|
||||
if field_origin_type is Literal or get_origin(field_type) is Literal:
|
||||
allowed_values = get_args(field_type)
|
||||
if value in allowed_values:
|
||||
return value
|
||||
else:
|
||||
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
|
||||
|
||||
if field_type is Any or isinstance(value, field_type):
|
||||
return value
|
||||
|
||||
# 其他类型,尝试直接转换
|
||||
try:
|
||||
return field_type(value)
|
||||
except (ValueError, TypeError) as e:
|
||||
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UConfig(S4UConfigBase):
|
||||
"""S4U聊天系统配置类"""
|
||||
|
||||
message_timeout_seconds: int = 120
|
||||
"""普通消息存活时间(秒),超过此时间的消息将被丢弃"""
|
||||
|
||||
at_bot_priority_bonus: float = 100.0
|
||||
"""@机器人时的优先级加成分数"""
|
||||
|
||||
recent_message_keep_count: int = 6
|
||||
"""保留最近N条消息,超出范围的普通消息将被移除"""
|
||||
|
||||
typing_delay: float = 0.1
|
||||
"""打字延迟时间(秒),模拟真实打字速度"""
|
||||
|
||||
chars_per_second: float = 15.0
|
||||
"""每秒字符数,用于计算动态打字延迟"""
|
||||
|
||||
min_typing_delay: float = 0.2
|
||||
"""最小打字延迟(秒)"""
|
||||
|
||||
max_typing_delay: float = 2.0
|
||||
"""最大打字延迟(秒)"""
|
||||
|
||||
enable_dynamic_typing_delay: bool = False
|
||||
"""是否启用基于文本长度的动态打字延迟"""
|
||||
|
||||
vip_queue_priority: bool = True
|
||||
"""是否启用VIP队列优先级系统"""
|
||||
|
||||
enable_message_interruption: bool = True
|
||||
"""是否允许高优先级消息中断当前回复"""
|
||||
|
||||
enable_old_message_cleanup: bool = True
|
||||
"""是否自动清理过旧的普通消息"""
|
||||
|
||||
enable_loading_indicator: bool = True
|
||||
"""是否显示加载提示"""
|
||||
|
||||
max_context_message_length: int = 20
|
||||
"""上下文消息最大长度"""
|
||||
|
||||
max_core_message_length: int = 30
|
||||
"""核心消息最大长度"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class S4UGlobalConfig(S4UConfigBase):
|
||||
"""S4U总配置类"""
|
||||
|
||||
s4u: S4UConfig
|
||||
S4U_VERSION: str = S4U_VERSION
|
||||
|
||||
|
||||
def update_s4u_config():
|
||||
"""更新S4U配置文件"""
|
||||
# 创建配置目录(如果不存在)
|
||||
os.makedirs(CONFIG_DIR, exist_ok=True)
|
||||
|
||||
# 检查模板文件是否存在
|
||||
if not os.path.exists(TEMPLATE_PATH):
|
||||
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
logger.error("请确保模板文件存在后重新运行")
|
||||
raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
|
||||
|
||||
# 检查配置文件是否存在
|
||||
if not os.path.exists(CONFIG_PATH):
|
||||
logger.info("S4U配置文件不存在,从模板创建新配置")
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
|
||||
return
|
||||
|
||||
# 读取旧配置文件和模板文件
|
||||
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查version是否相同
|
||||
if old_config and "inner" in old_config and "inner" in new_config:
|
||||
old_version = old_config["inner"].get("version") # type: ignore
|
||||
new_version = new_config["inner"].get("version") # type: ignore
|
||||
if old_version and new_version and old_version == new_version:
|
||||
logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
|
||||
return
|
||||
else:
|
||||
logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
|
||||
else:
|
||||
logger.info("S4U配置文件未检测到版本号,可能是旧版本。将进行更新")
|
||||
|
||||
# 创建备份目录
|
||||
old_config_dir = os.path.join(CONFIG_DIR, "old")
|
||||
os.makedirs(old_config_dir, exist_ok=True)
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
|
||||
|
||||
# 移动旧配置文件到old目录
|
||||
shutil.move(CONFIG_PATH, old_backup_path)
|
||||
logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
|
||||
|
||||
# 复制模板文件到配置目录
|
||||
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
|
||||
logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")
|
||||
|
||||
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
"""
|
||||
将source字典的值更新到target字典中(如果target中存在相同的键)
|
||||
"""
|
||||
for key, value in source.items():
|
||||
# 跳过version字段的更新
|
||||
if key == "version":
|
||||
continue
|
||||
if key in target:
|
||||
target_value = target[key]
|
||||
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
|
||||
update_dict(target_value, value)
|
||||
else:
|
||||
try:
|
||||
# 对数组类型进行特殊处理
|
||||
if isinstance(value, list):
|
||||
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
|
||||
else:
|
||||
# 其他类型使用item方法创建新值
|
||||
target[key] = tomlkit.item(value)
|
||||
except (TypeError, ValueError):
|
||||
# 如果转换失败,直接赋值
|
||||
target[key] = value
|
||||
|
||||
# 将旧配置的值更新到新配置中
|
||||
logger.info("开始合并S4U新旧配置...")
|
||||
update_dict(new_config, old_config)
|
||||
|
||||
# 保存更新后的配置(保留注释和格式)
|
||||
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
|
||||
f.write(tomlkit.dumps(new_config))
|
||||
|
||||
logger.info("S4U配置文件更新完成")
|
||||
|
||||
|
||||
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
|
||||
"""
|
||||
加载S4U配置文件
|
||||
:param config_path: 配置文件路径
|
||||
:return: S4UGlobalConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建S4UGlobalConfig对象
|
||||
try:
|
||||
return S4UGlobalConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.critical("S4U配置文件解析失败")
|
||||
raise e
|
||||
|
||||
|
||||
# 初始化S4U配置
|
||||
logger.info(f"S4U当前版本: {S4U_VERSION}")
|
||||
update_s4u_config()
|
||||
|
||||
logger.info("正在加载S4U配置文件...")
|
||||
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
|
||||
logger.info("S4U配置文件加载完成!")
|
||||
|
||||
s4u_config: S4UConfig = s4u_config_main.s4u
|
||||
@@ -161,6 +161,60 @@ class PersonInfoManager:
|
||||
|
||||
await asyncio.to_thread(_db_create_sync, final_data)
|
||||
|
||||
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
|
||||
"""安全地创建用户信息,处理竞态条件"""
|
||||
if not person_id:
|
||||
logger.debug("创建失败,person_id不存在")
|
||||
return
|
||||
|
||||
_person_info_default = copy.deepcopy(person_info_default)
|
||||
model_fields = PersonInfo._meta.fields.keys() # type: ignore
|
||||
|
||||
final_data = {"person_id": person_id}
|
||||
|
||||
# Start with defaults for all model fields
|
||||
for key, default_value in _person_info_default.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = default_value
|
||||
|
||||
# Override with provided data
|
||||
if data:
|
||||
for key, value in data.items():
|
||||
if key in model_fields:
|
||||
final_data[key] = value
|
||||
|
||||
# Ensure person_id is correctly set from the argument
|
||||
final_data["person_id"] = person_id
|
||||
|
||||
# Serialize JSON fields
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in final_data:
|
||||
if isinstance(final_data[key], (list, dict)):
|
||||
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
|
||||
elif final_data[key] is None: # Default for lists is [], store as "[]"
|
||||
final_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
def _db_safe_create_sync(p_data: dict):
|
||||
try:
|
||||
# 首先检查是否已存在
|
||||
existing = PersonInfo.get_or_none(PersonInfo.person_id == p_data["person_id"])
|
||||
if existing:
|
||||
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
|
||||
return True
|
||||
|
||||
# 尝试创建
|
||||
PersonInfo.create(**p_data)
|
||||
return True
|
||||
except Exception as e:
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
|
||||
return True # 其他协程已创建,视为成功
|
||||
else:
|
||||
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
|
||||
return False
|
||||
|
||||
await asyncio.to_thread(_db_safe_create_sync, final_data)
|
||||
|
||||
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
|
||||
"""更新某一个字段,会补全"""
|
||||
if field_name not in PersonInfo._meta.fields: # type: ignore
|
||||
@@ -221,7 +275,8 @@ class PersonInfoManager:
|
||||
if data and "user_id" in data:
|
||||
creation_data["user_id"] = data["user_id"]
|
||||
|
||||
await self.create_person_info(person_id, creation_data)
|
||||
# 使用安全的创建方法,处理竞态条件
|
||||
await self._safe_create_person_info(person_id, creation_data)
|
||||
|
||||
@staticmethod
|
||||
async def has_one_field(person_id: str, field_name: str):
|
||||
@@ -529,16 +584,31 @@ class PersonInfoManager:
|
||||
"""
|
||||
根据 platform 和 user_id 获取 person_id。
|
||||
如果对应的用户不存在,则使用提供的可选信息创建新用户。
|
||||
使用try-except处理竞态条件,避免重复创建错误。
|
||||
"""
|
||||
person_id = self.get_person_id(platform, user_id)
|
||||
|
||||
def _db_check_exists_sync(p_id: str):
|
||||
return PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
def _db_get_or_create_sync(p_id: str, init_data: dict):
|
||||
"""原子性的获取或创建操作"""
|
||||
# 首先尝试获取现有记录
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
if record:
|
||||
return record, False # 记录存在,未创建
|
||||
|
||||
record = await asyncio.to_thread(_db_check_exists_sync, person_id)
|
||||
# 记录不存在,尝试创建
|
||||
try:
|
||||
PersonInfo.create(**init_data)
|
||||
return PersonInfo.get(PersonInfo.person_id == p_id), True # 创建成功
|
||||
except Exception as e:
|
||||
# 如果创建失败(可能是因为竞态条件),再次尝试获取
|
||||
if "UNIQUE constraint failed" in str(e):
|
||||
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
|
||||
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
|
||||
if record:
|
||||
return record, False # 其他协程已创建,返回现有记录
|
||||
# 如果仍然失败,重新抛出异常
|
||||
raise e
|
||||
|
||||
if record is None:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
|
||||
unique_nickname = await self._generate_unique_person_name(nickname)
|
||||
initial_data = {
|
||||
"person_id": person_id,
|
||||
@@ -554,11 +624,25 @@ class PersonInfoManager:
|
||||
"points": [],
|
||||
"forgotten_points": [],
|
||||
}
|
||||
|
||||
# 序列化JSON字段
|
||||
for key in JSON_SERIALIZED_FIELDS:
|
||||
if key in initial_data:
|
||||
if isinstance(initial_data[key], (list, dict)):
|
||||
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
|
||||
elif initial_data[key] is None:
|
||||
initial_data[key] = json.dumps([], ensure_ascii=False)
|
||||
|
||||
model_fields = PersonInfo._meta.fields.keys() # type: ignore
|
||||
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
|
||||
|
||||
await self.create_person_info(person_id, data=filtered_initial_data)
|
||||
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
|
||||
|
||||
if was_created:
|
||||
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
|
||||
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
|
||||
else:
|
||||
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
|
||||
|
||||
return person_id
|
||||
|
||||
|
||||
@@ -60,9 +60,9 @@ class RelationshipBuilder:
|
||||
# 获取聊天名称用于日志
|
||||
try:
|
||||
chat_name = get_chat_manager().get_stream_name(self.chat_id)
|
||||
self.log_prefix = f"[{chat_name}] 关系构建"
|
||||
self.log_prefix = f"[{chat_name}]"
|
||||
except Exception:
|
||||
self.log_prefix = f"[{self.chat_id}] 关系构建"
|
||||
self.log_prefix = f"[{self.chat_id}]"
|
||||
|
||||
# 加载持久化的缓存
|
||||
self._load_cache()
|
||||
@@ -349,11 +349,14 @@ class RelationshipBuilder:
|
||||
# 统筹各模块协作、对外提供服务接口
|
||||
# ================================
|
||||
|
||||
async def build_relation(self):
|
||||
"""构建关系"""
|
||||
async def build_relation(self,immediate_build: str = "",max_build_threshold: int = MAX_MESSAGE_COUNT):
|
||||
"""构建关系
|
||||
immediate_build: 立即构建关系,可选值为"all"或person_id
|
||||
"""
|
||||
self._cleanup_old_segments()
|
||||
current_time = time.time()
|
||||
|
||||
|
||||
if latest_messages := get_raw_msg_by_timestamp_with_chat(
|
||||
self.chat_id,
|
||||
self.last_processed_message_time,
|
||||
@@ -374,7 +377,7 @@ class RelationshipBuilder:
|
||||
):
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
self._update_message_segments(person_id, msg_time)
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"{self.log_prefix} 更新用户 {person_id} 的消息段,消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(msg_time))}"
|
||||
)
|
||||
self.last_processed_message_time = max(self.last_processed_message_time, msg_time)
|
||||
@@ -383,15 +386,17 @@ class RelationshipBuilder:
|
||||
users_to_build_relationship = []
|
||||
for person_id, segments in self.person_engaged_cache.items():
|
||||
total_message_count = self._get_total_message_count(person_id)
|
||||
if total_message_count >= MAX_MESSAGE_COUNT:
|
||||
person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id
|
||||
|
||||
if total_message_count >= max_build_threshold or (total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all")):
|
||||
users_to_build_relationship.append(person_id)
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 用户 {person_id} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}"
|
||||
logger.info(
|
||||
f"{self.log_prefix} 用户 {person_name} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}"
|
||||
)
|
||||
elif total_message_count > 0:
|
||||
# 记录进度信息
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 用户 {person_id} 进度:{total_message_count}60 条消息,{len(segments)} 个消息段"
|
||||
logger.info(
|
||||
f"{self.log_prefix} 用户 {person_name} 进度:{total_message_count}/60 条消息,{len(segments)} 个消息段"
|
||||
)
|
||||
|
||||
# 2. 为满足条件的用户构建关系
|
||||
@@ -405,6 +410,7 @@ class RelationshipBuilder:
|
||||
del self.person_engaged_cache[person_id]
|
||||
self._save_cache()
|
||||
|
||||
|
||||
# ================================
|
||||
# 关系构建模块
|
||||
# 负责触发关系构建、整合消息段、更新用户印象
|
||||
@@ -413,7 +419,7 @@ class RelationshipBuilder:
|
||||
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
|
||||
"""基于消息段更新用户印象"""
|
||||
original_segment_count = len(segments)
|
||||
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
|
||||
logger.info(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
|
||||
try:
|
||||
# 筛选要处理的消息段,每个消息段有10%的概率被丢弃
|
||||
segments_to_process = [s for s in segments if random.random() >= 0.1]
|
||||
|
||||
@@ -44,8 +44,8 @@ class RelationshipManager:
|
||||
"konw_time": int(time.time()),
|
||||
"person_name": unique_nickname, # 使用唯一的 person_name
|
||||
}
|
||||
# 先创建用户基本信息
|
||||
await person_info_manager.create_person_info(person_id=person_id, data=data)
|
||||
# 先创建用户基本信息,使用安全创建方法避免竞态条件
|
||||
await person_info_manager._safe_create_person_info(person_id=person_id, data=data)
|
||||
# 更新昵称
|
||||
await person_info_manager.update_one_field(
|
||||
person_id=person_id, field_name="nickname", value=user_nickname, data=data
|
||||
|
||||
@@ -250,7 +250,6 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
|
||||
message_dict = {
|
||||
"message_info": message_info,
|
||||
"raw_message": find_msg.get("processed_plain_text"),
|
||||
"detailed_plain_text": find_msg.get("processed_plain_text"),
|
||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user