feat:添加新的message类为s4u服务,添加s4u config,添加sc和gift的解析,修复关系构建的一些问题

This commit is contained in:
SengokuCola
2025-07-15 17:04:30 +08:00
parent 4ebcf4e056
commit 5ec0d42cde
22 changed files with 1371 additions and 136 deletions

View File

@@ -205,7 +205,6 @@ class MongoToSQLiteMigrator:
"user_info.user_nickname": "user_nickname", "user_info.user_nickname": "user_nickname",
"user_info.user_cardname": "user_cardname", "user_info.user_cardname": "user_cardname",
"processed_plain_text": "processed_plain_text", "processed_plain_text": "processed_plain_text",
"detailed_plain_text": "detailed_plain_text",
"memorized_times": "memorized_times", "memorized_times": "memorized_times",
}, },
enable_validation=False, # 禁用数据验证 enable_validation=False, # 禁用数据验证

View File

@@ -9,7 +9,7 @@ from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器 from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
@@ -141,6 +141,29 @@ class ChatBot:
logger.error(f"处理命令时出错: {e}") logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息 return False, None, True # 出错时继续处理消息
async def do_s4u(self, message_data: Dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
message.update_chat_stream(chat)
# 处理消息内容
await message.process()
await self.s4u_message_processor.process_message(message)
return
async def message_process(self, message_data: Dict[str, Any]) -> None: async def message_process(self, message_data: Dict[str, Any]) -> None:
"""处理转化后的统一格式消息 """处理转化后的统一格式消息
这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中 这个函数本质是预处理一些数据,根据配置信息和消息内容,预处理消息,并分发到合适的消息处理器中
@@ -159,6 +182,10 @@ class ChatBot:
# 确保所有任务已启动 # 确保所有任务已启动
await self._ensure_started() await self._ensure_started()
if ENABLE_S4U_CHAT:
await self.do_s4u(message_data)
return
if message_data["message_info"].get("group_info") is not None: if message_data["message_info"].get("group_info") is not None:
message_data["message_info"]["group_info"]["group_id"] = str( message_data["message_info"]["group_info"]["group_id"] = str(
message_data["message_info"]["group_info"]["group_id"] message_data["message_info"]["group_info"]["group_id"]
@@ -221,11 +248,6 @@ class ChatBot:
template_group_name = None template_group_name = None
async def preprocess(): async def preprocess():
if ENABLE_S4U_CHAT:
logger.info("进入S4U流程")
await self.s4u_message_processor.process_message(message)
return
await self.heartflow_message_receiver.process_message(message) await self.heartflow_message_receiver.process_message(message)
if template_group_name: if template_group_name:

View File

@@ -38,7 +38,6 @@ class Message(MessageBase):
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None, timestamp: Optional[float] = None,
reply: Optional["MessageRecv"] = None, reply: Optional["MessageRecv"] = None,
detailed_plain_text: str = "",
processed_plain_text: str = "", processed_plain_text: str = "",
): ):
# 使用传入的时间戳或当前时间 # 使用传入的时间戳或当前时间
@@ -58,7 +57,6 @@ class Message(MessageBase):
self.chat_stream = chat_stream self.chat_stream = chat_stream
# 文本处理相关属性 # 文本处理相关属性
self.processed_plain_text = processed_plain_text self.processed_plain_text = processed_plain_text
self.detailed_plain_text = detailed_plain_text
# 回复消息 # 回复消息
self.reply = reply self.reply = reply
@@ -104,7 +102,6 @@ class MessageRecv(Message):
self.message_segment = Seg.from_dict(message_dict.get("message_segment", {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get("raw_message") self.raw_message = message_dict.get("raw_message")
self.processed_plain_text = message_dict.get("processed_plain_text", "") self.processed_plain_text = message_dict.get("processed_plain_text", "")
self.detailed_plain_text = message_dict.get("detailed_plain_text", "")
self.is_emoji = False self.is_emoji = False
self.has_emoji = False self.has_emoji = False
self.is_picid = False self.is_picid = False
@@ -123,7 +120,6 @@ class MessageRecv(Message):
这个方法必须在创建实例后显式调用,因为它包含异步操作。 这个方法必须在创建实例后显式调用,因为它包含异步操作。
""" """
self.processed_plain_text = await self._process_message_segments(self.message_segment) self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_single_segment(self, segment: Seg) -> str: async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段 """处理单个消息段
@@ -182,12 +178,97 @@ class MessageRecv(Message):
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}") logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]" return f"[处理失败的{segment.type}消息]"
def _generate_detailed_text(self) -> str: @dataclass
"""生成详细文本,包含时间和用户信息""" class MessageRecvS4U(MessageRecv):
timestamp = self.message_info.time def __init__(self, message_dict: dict[str, Any]):
user_info = self.message_info.user_info super().__init__(message_dict)
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore self.is_gift = False
return f"[{timestamp}] {name}: {self.processed_plain_text}\n" self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
async def _process_single_segment(self, segment: Seg) -> str:
"""处理单个消息段
Args:
segment: 消息段
Returns:
str: 处理后的文本
"""
try:
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
self.has_picid = True
self.is_picid = True
self.is_emoji = False
image_manager = get_image_manager()
# print(f"segment.data: {segment.data}")
_, processed_text = await image_manager.process_image(segment.data)
return processed_text
return "[发了一张图片,网卡了加载不出来]"
elif segment.type == "emoji":
self.has_emoji = True
self.is_emoji = True
self.is_picid = False
if isinstance(segment.data, str):
return await get_image_manager().get_emoji_description(segment.data)
return "[发了一个表情包,网卡了加载不出来]"
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
self.is_emoji = False
if isinstance(segment.data, dict):
# 处理优先级信息
self.priority_mode = "priority"
self.priority_info = segment.data
"""
{
'message_type': 'vip', # vip or normal
'message_priority': 1.0, # 优先级大为优先float
}
"""
return ""
elif segment.type == "gift":
self.is_gift = True
# 解析gift_info格式为"名称:数量"
name, count = segment.data.split(":", 1)
self.gift_info = segment.data
self.gift_name = name.strip()
self.gift_count = int(count.strip())
return ""
elif segment.type == "superchat":
self.is_superchat = True
self.superchat_info = segment.data
price,message_text = segment.data.split(":", 1)
self.superchat_price = price.strip()
self.superchat_message_text = message_text.strip()
self.processed_plain_text = str(self.superchat_message_text)
self.processed_plain_text += f"(注意:这是一条超级弹幕信息,价值{self.superchat_price}元,请你认真回复)"
return self.processed_plain_text
else:
return ""
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@dataclass @dataclass
@@ -472,7 +553,6 @@ def message_from_db_dict(db_dict: dict) -> MessageRecv:
"message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段 "message_segment": {"type": "text", "data": processed_text}, # 从纯文本重建消息段
"raw_message": None, # 数据库中未存储原始消息 "raw_message": None, # 数据库中未存储原始消息
"processed_plain_text": processed_text, "processed_plain_text": processed_text,
"detailed_plain_text": db_dict.get("detailed_plain_text", ""),
} }
# 创建 MessageRecv 实例 # 创建 MessageRecv 实例

View File

@@ -121,27 +121,6 @@ async def get_embedding(text, request_type="embedding"):
return embedding return embedding
def get_recent_group_detailed_plain_text(chat_stream_id: str, limit: int = 12, combine=False):
filter_query = {"chat_id": chat_stream_id}
sort_order = [("time", -1)]
recent_messages = find_messages(message_filter=filter_query, sort=sort_order, limit=limit)
if not recent_messages:
return []
# 反转消息列表,使最新的消息在最后
recent_messages.reverse()
if combine:
return "".join(str(msg_db_data["detailed_plain_text"]) for msg_db_data in recent_messages)
message_detailed_plain_text_list = []
for msg_db_data in recent_messages:
message_detailed_plain_text_list.append(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text_list
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list: def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
# 获取当前群聊记录内发言的人 # 获取当前群聊记录内发言的人
filter_query = {"chat_id": chat_stream_id} filter_query = {"chat_id": chat_stream_id}

View File

@@ -153,7 +153,6 @@ class Messages(BaseModel):
processed_plain_text = TextField(null=True) # 处理后的纯文本消息 processed_plain_text = TextField(null=True) # 处理后的纯文本消息
display_message = TextField(null=True) # 显示的消息 display_message = TextField(null=True) # 显示的消息
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
memorized_times = IntegerField(default=0) # 被记忆的次数 memorized_times = IntegerField(default=0) # 被记忆的次数
priority_mode = TextField(null=True) priority_mode = TextField(null=True)

View File

@@ -403,6 +403,10 @@ MODULE_COLORS = {
"model_utils": "\033[38;5;164m", # 紫红色 "model_utils": "\033[38;5;164m", # 紫红色
"relationship_fetcher": "\033[38;5;170m", # 浅紫色 "relationship_fetcher": "\033[38;5;170m", # 浅紫色
"relationship_builder": "\033[38;5;93m", # 浅蓝色 "relationship_builder": "\033[38;5;93m", # 浅蓝色
#s4u
"context_web_api": "\033[38;5;240m", # 深灰色
"S4U_chat": "\033[92m", # 深灰色
} }
RESET_COLOR = "\033[0m" RESET_COLOR = "\033[0m"

View File

@@ -0,0 +1,36 @@
[inner]
version = "1.0.0"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示

View File

@@ -0,0 +1,38 @@
[inner]
version = "1.0.1"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示
max_context_message_length = 20
max_core_message_length = 30

View File

@@ -0,0 +1,38 @@
[inner]
version = "1.0.1"
#----以下是S4U聊天系统配置文件----
# S4U (Smart 4 U) 聊天系统是MaiBot的核心对话模块
# 支持优先级队列、消息中断、VIP用户等高级功能
#
# 如果你想要修改配置文件请在修改后将version的值进行变更
# 如果新增项目请参考src/mais4u/s4u_config.py中的S4UConfig类
#
# 版本格式:主版本号.次版本号.修订号
#----S4U配置说明结束----
[s4u]
# 消息管理配置
message_timeout_seconds = 120 # 普通消息存活时间(秒),超过此时间的消息将被丢弃
recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
# 优先级系统配置
at_bot_priority_bonus = 100.0 # @机器人时的优先级加成分数
vip_queue_priority = true # 是否启用VIP队列优先级系统
enable_message_interruption = true # 是否允许高优先级消息中断当前回复
# 打字效果配置
typing_delay = 0.1 # 打字延迟时间(秒),模拟真实打字速度
enable_dynamic_typing_delay = false # 是否启用基于文本长度的动态打字延迟
# 动态打字延迟参数仅在enable_dynamic_typing_delay=true时生效
chars_per_second = 15.0 # 每秒字符数,用于计算动态打字延迟
min_typing_delay = 0.2 # 最小打字延迟(秒)
max_typing_delay = 2.0 # 最大打字延迟(秒)
# 系统功能开关
enable_old_message_cleanup = true # 是否自动清理过旧的普通消息
enable_loading_indicator = true # 是否显示加载提示
max_context_message_length = 20
max_core_message_length = 30

View File

@@ -0,0 +1 @@

View File

@@ -24,13 +24,32 @@ class ContextMessage:
self.timestamp = datetime.now() self.timestamp = datetime.now()
self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊" self.group_name = message.message_info.group_info.group_name if message.message_info.group_info else "私聊"
# 识别消息类型
self.is_gift = getattr(message, 'is_gift', False)
self.is_superchat = getattr(message, 'is_superchat', False)
# 添加礼物和SC相关信息
if self.is_gift:
self.gift_name = getattr(message, 'gift_name', '')
self.gift_count = getattr(message, 'gift_count', '1')
self.content = f"送出了 {self.gift_name} x{self.gift_count}"
elif self.is_superchat:
self.superchat_price = getattr(message, 'superchat_price', '0')
self.superchat_message = getattr(message, 'superchat_message_text', '')
if self.superchat_message:
self.content = f"{self.superchat_price}] {self.superchat_message}"
else:
self.content = f"{self.superchat_price}] {self.content}"
def to_dict(self): def to_dict(self):
return { return {
"user_name": self.user_name, "user_name": self.user_name,
"user_id": self.user_id, "user_id": self.user_id,
"content": self.content, "content": self.content,
"timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"), "timestamp": self.timestamp.strftime("%m-%d %H:%M:%S"),
"group_name": self.group_name "group_name": self.group_name,
"is_gift": self.is_gift,
"is_superchat": self.is_superchat
} }
@@ -155,6 +174,44 @@ class ContextWebManager:
transform: translateX(5px); transform: translateX(5px);
transition: all 0.3s ease; transition: all 0.3s ease;
} }
.message.gift {
border-left: 4px solid #ff8800;
background: rgba(255, 136, 0, 0.2);
}
.message.gift:hover {
background: rgba(255, 136, 0, 0.3);
}
.message.gift .username {
color: #ff8800;
}
.message.superchat {
border-left: 4px solid #ff6b6b;
background: linear-gradient(135deg, rgba(255, 107, 107, 0.2), rgba(107, 255, 107, 0.2), rgba(107, 107, 255, 0.2));
background-size: 200% 200%;
animation: rainbow 3s ease infinite;
}
.message.superchat:hover {
background: linear-gradient(135deg, rgba(255, 107, 107, 0.4), rgba(107, 255, 107, 0.4), rgba(107, 107, 255, 0.4));
background-size: 200% 200%;
}
.message.superchat .username {
background: linear-gradient(45deg, #ff6b6b, #4ecdc4, #45b7d1, #96ceb4, #feca57);
background-size: 300% 300%;
animation: rainbow-text 2s ease infinite;
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
}
@keyframes rainbow {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
@keyframes rainbow-text {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.message-line { .message-line {
line-height: 1.4; line-height: 1.4;
word-wrap: break-word; word-wrap: break-word;
@@ -373,7 +430,20 @@ class ContextWebManager:
function createMessageElement(msg, isNew = false) { function createMessageElement(msg, isNew = false) {
const messageDiv = document.createElement('div'); const messageDiv = document.createElement('div');
messageDiv.className = 'message' + (isNew ? ' new-message' : ''); let className = 'message';
// 根据消息类型添加对应的CSS类
if (msg.is_gift) {
className += ' gift';
} else if (msg.is_superchat) {
className += ' superchat';
}
if (isNew) {
className += ' new-message';
}
messageDiv.className = className;
messageDiv.innerHTML = ` messageDiv.innerHTML = `
<div class="message-line"> <div class="message-line">
<span class="username">${escapeHtml(msg.user_name)}</span><span class="content">${escapeHtml(msg.content)}</span> <span class="username">${escapeHtml(msg.user_name)}</span><span class="content">${escapeHtml(msg.content)}</span>

View File

@@ -0,0 +1,155 @@
import asyncio
from typing import Dict, Tuple, Callable, Optional
from dataclasses import dataclass
from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger
logger = get_logger("gift_manager")
@dataclass
class PendingGift:
"""等待中的礼物消息"""
message: MessageRecvS4U
total_count: int
timer_task: asyncio.Task
callback: Callable[[MessageRecvS4U], None]
class GiftManager:
"""礼物管理器,提供防抖功能"""
def __init__(self):
"""初始化礼物管理器"""
self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
self.debounce_timeout = 3.0 # 3秒防抖时间
async def handle_gift(self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None) -> bool:
"""处理礼物消息,返回是否应该立即处理
Args:
message: 礼物消息
callback: 防抖完成后的回调函数
Returns:
bool: False表示消息被暂存等待防抖True表示应该立即处理
"""
if not message.is_gift:
return True
# 构建礼物的唯一键:(发送人ID, 礼物名称)
gift_key = (message.message_info.user_info.user_id, message.gift_name)
# 如果已经有相同的礼物在等待中,则合并
if gift_key in self.pending_gifts:
await self._merge_gift(gift_key, message)
return False
# 创建新的等待礼物
await self._create_pending_gift(gift_key, message, callback)
return False
async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
"""合并礼物消息"""
pending_gift = self.pending_gifts[gift_key]
# 取消之前的定时器
if not pending_gift.timer_task.cancelled():
pending_gift.timer_task.cancel()
# 累加礼物数量
try:
new_count = int(new_message.gift_count)
pending_gift.total_count += new_count
# 更新消息为最新的(保留最新的消息,但累加数量)
pending_gift.message = new_message
pending_gift.message.gift_count = str(pending_gift.total_count)
pending_gift.message.gift_info = f"{pending_gift.message.gift_name}:{pending_gift.total_count}"
except ValueError:
logger.warning(f"无法解析礼物数量: {new_message.gift_count}")
# 如果无法解析数量,保持原有数量不变
# 重新创建定时器
pending_gift.timer_task = asyncio.create_task(
self._gift_timeout(gift_key)
)
logger.debug(f"合并礼物: {gift_key}, 总数量: {pending_gift.total_count}")
async def _create_pending_gift(
self,
gift_key: Tuple[str, str],
message: MessageRecvS4U,
callback: Optional[Callable[[MessageRecvS4U], None]]
) -> None:
"""创建新的等待礼物"""
try:
initial_count = int(message.gift_count)
except ValueError:
initial_count = 1
logger.warning(f"无法解析礼物数量: {message.gift_count}默认设为1")
# 创建定时器任务
timer_task = asyncio.create_task(self._gift_timeout(gift_key))
# 创建等待礼物对象
pending_gift = PendingGift(
message=message,
total_count=initial_count,
timer_task=timer_task,
callback=callback
)
self.pending_gifts[gift_key] = pending_gift
logger.debug(f"创建等待礼物: {gift_key}, 初始数量: {initial_count}")
async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
"""礼物防抖超时处理"""
try:
# 等待防抖时间
await asyncio.sleep(self.debounce_timeout)
# 获取等待中的礼物
if gift_key not in self.pending_gifts:
return
pending_gift = self.pending_gifts.pop(gift_key)
logger.info(f"礼物防抖完成: {gift_key}, 最终数量: {pending_gift.total_count}")
message = pending_gift.message
message.processed_plain_text = f"用户{message.message_info.user_info.user_nickname}送出了礼物{message.gift_name} x{pending_gift.total_count}"
# 执行回调
if pending_gift.callback:
try:
pending_gift.callback(message)
except Exception as e:
logger.error(f"礼物回调执行失败: {e}", exc_info=True)
except asyncio.CancelledError:
# 定时器被取消,不需要处理
pass
except Exception as e:
logger.error(f"礼物防抖处理异常: {e}", exc_info=True)
def get_pending_count(self) -> int:
"""获取当前等待中的礼物数量"""
return len(self.pending_gifts)
async def flush_all(self) -> None:
"""立即处理所有等待中的礼物"""
for gift_key in list(self.pending_gifts.keys()):
pending_gift = self.pending_gifts.get(gift_key)
if pending_gift and not pending_gift.timer_task.cancelled():
pending_gift.timer_task.cancel()
await self._gift_timeout(gift_key)
# 创建全局礼物管理器实例
gift_manager = GiftManager()

View File

@@ -1,4 +1,5 @@
import asyncio import asyncio
import traceback
import time import time
import random import random
from typing import Optional, Dict, Tuple # 导入类型提示 from typing import Optional, Dict, Tuple # 导入类型提示
@@ -6,7 +7,7 @@ from maim_message import UserInfo, Seg
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from .s4u_stream_generator import S4UStreamGenerator from .s4u_stream_generator import S4UStreamGenerator
from src.chat.message_receive.message import MessageSending, MessageRecv from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
from src.config.config import global_config from src.config.config import global_config
from src.common.message.api import get_global_api from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
@@ -14,6 +15,9 @@ from .s4u_watching_manager import watching_manager
import json import json
from src.person_info.relationship_builder_manager import relationship_builder_manager from src.person_info.relationship_builder_manager import relationship_builder_manager
from .loading import send_loading, send_unloading from .loading import send_loading, send_unloading
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import PersonInfoManager
from .super_chat_manager import get_super_chat_manager
logger = get_logger("S4U_chat") logger = get_logger("S4U_chat")
@@ -49,9 +53,9 @@ class MessageSenderContainer:
def _calculate_typing_delay(self, text: str) -> float: def _calculate_typing_delay(self, text: str) -> float:
"""根据文本长度计算模拟打字延迟。""" """根据文本长度计算模拟打字延迟。"""
chars_per_second = 15.0 chars_per_second = s4u_config.chars_per_second
min_delay = 0.2 min_delay = s4u_config.min_typing_delay
max_delay = 2.0 max_delay = s4u_config.max_typing_delay
delay = len(text) / chars_per_second delay = len(text) / chars_per_second
return max(min_delay, min(delay, max_delay)) return max(min_delay, min(delay, max_delay))
@@ -73,8 +77,11 @@ class MessageSenderContainer:
# Check for pause signal *after* getting an item. # Check for pause signal *after* getting an item.
await self._paused_event.wait() await self._paused_event.wait()
# delay = self._calculate_typing_delay(chunk) # 根据配置选择延迟模式
delay = 0.1 if s4u_config.enable_dynamic_typing_delay:
delay = self._calculate_typing_delay(chunk)
else:
delay = s4u_config.typing_delay
await asyncio.sleep(delay) await asyncio.sleep(delay)
current_time = time.time() current_time = time.time()
@@ -144,8 +151,6 @@ def get_s4u_chat_manager() -> S4UChatManager:
class S4UChat: class S4UChat:
_MESSAGE_TIMEOUT_SECONDS = 120 # 普通消息存活时间(秒)
def __init__(self, chat_stream: ChatStream): def __init__(self, chat_stream: ChatStream):
"""初始化 S4UChat 实例。""" """初始化 S4UChat 实例。"""
@@ -169,8 +174,7 @@ class S4UChat:
self._is_replying = False self._is_replying = False
self.gpt = S4UStreamGenerator() self.gpt = S4UStreamGenerator()
self.interest_dict: Dict[str, float] = {} # 用户兴趣分 self.interest_dict: Dict[str, float] = {} # 用户兴趣分
self.at_bot_priority_bonus = 100.0 # @机器人的优先级加成
self.recent_message_keep_count = 6 # 保留最近N条消息超出范围的普通消息将被移除
logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.") logger.info(f"[{self.stream_name}] S4UChat with two-queue system initialized.")
def _get_priority_info(self, message: MessageRecv) -> dict: def _get_priority_info(self, message: MessageRecv) -> dict:
@@ -194,16 +198,13 @@ class S4UChat:
"""获取用户的兴趣分默认为1.0""" """获取用户的兴趣分默认为1.0"""
return self.interest_dict.get(user_id, 1.0) return self.interest_dict.get(user_id, 1.0)
def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float: def _calculate_base_priority_score(self, message: MessageRecv, priority_info: dict) -> float:
""" """
为消息计算基础优先级分数。分数越高,优先级越高。 为消息计算基础优先级分数。分数越高,优先级越高。
""" """
score = 0.0 score = 0.0
# 如果消息 @ 了机器人,则增加一个很大的分数
# if f"@{global_config.bot.nickname}" in message.processed_plain_text or any(
# f"@{alias}" in message.processed_plain_text for alias in global_config.bot.alias_names
# ):
# score += self.at_bot_priority_bonus
# 加上消息自带的优先级 # 加上消息自带的优先级
score += priority_info.get("message_priority", 0.0) score += priority_info.get("message_priority", 0.0)
@@ -212,17 +213,55 @@ class S4UChat:
score += self._get_interest_score(message.message_info.user_info.user_id) score += self._get_interest_score(message.message_info.user_info.user_id)
return score return score
async def add_message(self, message: MessageRecv) -> None: def decay_interest_score(self,message: MessageRecvS4U|MessageRecv):
"""根据VIP状态和中断逻辑将消息放入相应队列。""" for person_id, score in self.interest_dict.items():
if score > 0:
self.interest_dict[person_id] = score * 0.95
else:
self.interest_dict[person_id] = 0
await self.relationship_builder.build_relation() async def add_message(self, message: MessageRecvS4U|MessageRecv) -> None:
self.decay_interest_score(message)
"""根据VIP状态和中断逻辑将消息放入相应队列。"""
user_id = message.message_info.user_info.user_id
platform = message.message_info.platform
person_id = PersonInfoManager.get_person_id(platform, user_id)
try:
is_gift = message.is_gift
is_superchat = message.is_superchat
print(is_gift)
print(is_superchat)
if is_gift:
await self.relationship_builder.build_relation(immediate_build=person_id)
# 安全地增加兴趣分如果person_id不存在则先初始化为1.0
current_score = self.interest_dict.get(person_id, 1.0)
self.interest_dict[person_id] = current_score + 0.1 * message.gift_count
elif is_superchat:
await self.relationship_builder.build_relation(immediate_build=person_id)
# 安全地增加兴趣分如果person_id不存在则先初始化为1.0
current_score = self.interest_dict.get(person_id, 1.0)
self.interest_dict[person_id] = current_score + 0.1 * float(message.superchat_price)
# 添加SuperChat到管理器
super_chat_manager = get_super_chat_manager()
await super_chat_manager.add_superchat(message)
else:
await self.relationship_builder.build_relation(20)
except Exception as e:
traceback.print_exc()
logger.info(f"[{self.stream_name}] 消息处理完毕,消息内容:{message.processed_plain_text}")
priority_info = self._get_priority_info(message) priority_info = self._get_priority_info(message)
is_vip = self._is_vip(priority_info) is_vip = self._is_vip(priority_info)
new_priority_score = self._calculate_base_priority_score(message, priority_info) new_priority_score = self._calculate_base_priority_score(message, priority_info)
should_interrupt = False should_interrupt = False
if self._current_generation_task and not self._current_generation_task.done(): if (s4u_config.enable_message_interruption and
self._current_generation_task and not self._current_generation_task.done()):
if self._current_message_being_replied: if self._current_message_being_replied:
current_queue, current_priority, _, current_msg = self._current_message_being_replied current_queue, current_priority, _, current_msg = self._current_message_being_replied
@@ -260,7 +299,7 @@ class S4UChat:
# 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前 # 这样,原始分数越高的消息,在队列中的优先级数字越小,越靠前
item = (-new_priority_score, self._entry_counter, time.time(), message) item = (-new_priority_score, self._entry_counter, time.time(), message)
if is_vip: if is_vip and s4u_config.vip_queue_priority:
await self._vip_queue.put(item) await self._vip_queue.put(item)
logger.info(f"[{self.stream_name}] VIP message added to queue.") logger.info(f"[{self.stream_name}] VIP message added to queue.")
else: else:
@@ -271,11 +310,11 @@ class S4UChat:
def _cleanup_old_normal_messages(self): def _cleanup_old_normal_messages(self):
"""清理普通队列中不在最近N条消息范围内的消息""" """清理普通队列中不在最近N条消息范围内的消息"""
if self._normal_queue.empty(): if not s4u_config.enable_old_message_cleanup or self._normal_queue.empty():
return return
# 计算阈值:保留最近 recent_message_keep_count 条消息 # 计算阈值:保留最近 recent_message_keep_count 条消息
cutoff_counter = max(0, self._entry_counter - self.recent_message_keep_count) cutoff_counter = max(0, self._entry_counter - s4u_config.recent_message_keep_count)
# 临时存储需要保留的消息 # 临时存储需要保留的消息
temp_messages = [] temp_messages = []
@@ -302,7 +341,7 @@ class S4UChat:
self._normal_queue.put_nowait(item) self._normal_queue.put_nowait(item)
if removed_count > 0: if removed_count > 0:
logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {self.recent_message_keep_count} range.") logger.info(f"[{self.stream_name}] Cleaned up {removed_count} old normal messages outside recent {s4u_config.recent_message_keep_count} range.")
async def _message_processor(self): async def _message_processor(self):
"""调度器优先处理VIP队列然后处理普通队列。""" """调度器优先处理VIP队列然后处理普通队列。"""
@@ -325,7 +364,7 @@ class S4UChat:
neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait() neg_priority, entry_count, timestamp, message = self._normal_queue.get_nowait()
priority = -neg_priority priority = -neg_priority
# 检查普通消息是否超时 # 检查普通消息是否超时
if time.time() - timestamp > self._MESSAGE_TIMEOUT_SECONDS: if time.time() - timestamp > s4u_config.message_timeout_seconds:
logger.info( logger.info(
f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..." f"[{self.stream_name}] Discarding stale normal message: {message.processed_plain_text[:20]}..."
) )
@@ -369,18 +408,24 @@ class S4UChat:
logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True) logger.error(f"[{self.stream_name}] Message processor main loop error: {e}", exc_info=True)
await asyncio.sleep(1) await asyncio.sleep(1)
async def delay_change_watching_state(self):
random_delay = random.randint(1, 3)
await asyncio.sleep(random_delay)
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
await chat_watching.on_message_received()
async def _generate_and_send(self, message: MessageRecv): async def _generate_and_send(self, message: MessageRecv):
"""为单个消息生成文本回复。整个过程可以被中断。""" """为单个消息生成文本回复。整个过程可以被中断。"""
self._is_replying = True self._is_replying = True
total_chars_sent = 0 # 跟踪发送的总字符数
if s4u_config.enable_loading_indicator:
await send_loading(self.stream_id, "......") await send_loading(self.stream_id, "......")
# 视线管理:开始生成回复时切换视线状态 # 视线管理:开始生成回复时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id) chat_watching = watching_manager.get_watching_by_chat_id(self.stream_id)
await chat_watching.on_reply_start() asyncio.create_task(self.delay_change_watching_state())
# 回复生成实时展示:开始生成
user_name = message.message_info.user_info.user_nickname
sender_container = MessageSenderContainer(self.chat_stream, message) sender_container = MessageSenderContainer(self.chat_stream, message)
sender_container.start() sender_container.start()
@@ -395,12 +440,18 @@ class S4UChat:
# a. 发送文本块 # a. 发送文本块
await sender_container.add_message(chunk) await sender_container.add_message(chunk)
total_chars_sent += len(chunk) # 累计字符数
# 等待所有文本消息发送完成 # 等待所有文本消息发送完成
await sender_container.close() await sender_container.close()
await sender_container.join() await sender_container.join()
# 回复完成后延迟每个字延迟0.4秒
if total_chars_sent > 0:
delay_time = total_chars_sent * 0.4
logger.info(f"[{self.stream_name}] 回复完成,共发送 {total_chars_sent} 个字符,等待 {delay_time:.1f} 秒后继续处理下一个消息。")
await asyncio.sleep(delay_time)
logger.info(f"[{self.stream_name}] 所有文本块处理完毕。") logger.info(f"[{self.stream_name}] 所有文本块处理完毕。")
@@ -408,11 +459,13 @@ class S4UChat:
logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。") logger.info(f"[{self.stream_name}] 回复流程(文本)被中断。")
raise # 将取消异常向上传播 raise # 将取消异常向上传播
except Exception as e: except Exception as e:
traceback.print_exc()
logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True) logger.error(f"[{self.stream_name}] 回复生成过程中出现错误: {e}", exc_info=True)
# 回复生成实时展示:清空内容(出错时) # 回复生成实时展示:清空内容(出错时)
finally: finally:
self._is_replying = False self._is_replying = False
if s4u_config.enable_loading_indicator:
await send_unloading(self.stream_id) await send_unloading(self.stream_id)
# 视线管理:回复结束时切换视线状态 # 视线管理:回复结束时切换视线状态
@@ -442,3 +495,8 @@ class S4UChat:
await self._processing_task await self._processing_task
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info(f"处理任务已成功取消: {self.stream_name}") logger.info(f"处理任务已成功取消: {self.stream_name}")
# 注意SuperChat管理器是全局的不需要在单个S4UChat关闭时关闭
# 如果需要关闭SuperChat管理器应该在应用程序关闭时调用
# super_chat_manager = get_super_chat_manager()
# await super_chat_manager.shutdown()

View File

@@ -214,7 +214,7 @@ class ChatMood:
sorrow=self.mood_values["sorrow"], sorrow=self.mood_values["sorrow"],
fear=self.mood_values["fear"], fear=self.mood_values["fear"],
) )
logger.info(f"numerical mood prompt: {prompt}") logger.debug(f"numerical mood prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async( response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
prompt=prompt prompt=prompt
) )

View File

@@ -3,7 +3,7 @@ import math
from typing import Tuple from typing import Tuple
from src.chat.memory_system.Hippocampus import hippocampus_manager from src.chat.memory_system.Hippocampus import hippocampus_manager
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.timer_calculator import Timer from src.chat.utils.timer_calculator import Timer
@@ -14,6 +14,7 @@ from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
from src.mais4u.mais4u_chat.gift_manager import gift_manager
from .s4u_chat import get_s4u_chat_manager from .s4u_chat import get_s4u_chat_manager
@@ -66,7 +67,7 @@ class S4UMessageProcessor:
"""初始化心流处理器,创建消息存储实例""" """初始化心流处理器,创建消息存储实例"""
self.storage = MessageStorage() self.storage = MessageStorage()
async def process_message(self, message: MessageRecv) -> None: async def process_message(self, message: MessageRecvS4U, skip_gift_debounce: bool = False) -> None:
"""处理接收到的原始消息数据 """处理接收到的原始消息数据
主要流程: 主要流程:
@@ -80,8 +81,6 @@ class S4UMessageProcessor:
message_data: 原始消息字符串 message_data: 原始消息字符串
""" """
target_user_id_list = ["1026294844", "964959351"]
# 1. 消息解析与初始化 # 1. 消息解析与初始化
groupinfo = message.message_info.group_info groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info userinfo = message.message_info.user_info
@@ -93,25 +92,29 @@ class S4UMessageProcessor:
group_info=groupinfo, group_info=groupinfo,
) )
# 处理礼物消息,如果消息被暂存则停止当前处理流程
if not skip_gift_debounce and not await self.handle_if_gift(message):
return
await self.check_if_fake_gift(message)
await self.storage.store_message(message, chat) await self.storage.store_message(message, chat)
s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat) s4u_chat = get_s4u_chat_manager().get_or_create_chat(chat)
if userinfo.user_id in target_user_id_list:
await s4u_chat.add_message(message)
else:
await s4u_chat.add_message(message) await s4u_chat.add_message(message)
interested_rate, _ = await _calculate_interest(message) _interested_rate, _ = await _calculate_interest(message)
await mood_manager.start() await mood_manager.start()
# 一系列llm驱动的前处理
chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id) chat_mood = mood_manager.get_mood_by_chat_id(chat.stream_id)
asyncio.create_task(chat_mood.update_mood_by_message(message)) asyncio.create_task(chat_mood.update_mood_by_message(message))
chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id) chat_action = action_manager.get_action_state_by_chat_id(chat.stream_id)
asyncio.create_task(chat_action.update_action_by_message(message)) asyncio.create_task(chat_action.update_action_by_message(message))
# asyncio.create_task(chat_action.update_facial_expression_by_message(message, interested_rate))
# 视线管理:收到消息时切换视线状态 # 视线管理:收到消息时切换视线状态
chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id) chat_watching = watching_manager.get_watching_by_chat_id(chat.stream_id)
asyncio.create_task(chat_watching.on_message_received()) asyncio.create_task(chat_watching.on_message_received())
@@ -119,9 +122,44 @@ class S4UMessageProcessor:
# 上下文网页管理启动独立task处理消息上下文 # 上下文网页管理启动独立task处理消息上下文
asyncio.create_task(self._handle_context_web_update(chat.stream_id, message)) asyncio.create_task(self._handle_context_web_update(chat.stream_id, message))
# 7. 日志记录 # 日志记录
if message.is_gift:
logger.info(f"[S4U-礼物] {userinfo.user_nickname} 送出了 {message.gift_name} x{message.gift_count}")
else:
logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}") logger.info(f"[S4U]{userinfo.user_nickname}:{message.processed_plain_text}")
async def check_if_fake_gift(self, message: MessageRecvS4U) -> bool:
"""检查消息是否为假礼物"""
if message.is_gift:
return False
gift_keywords = ["送出了礼物", "礼物", "送出了"]
if any(keyword in message.processed_plain_text for keyword in gift_keywords):
message.processed_plain_text += "(注意:这是一条普通弹幕信息,对方没有真的发送礼物,不是礼物信息,注意区分)"
return True
return False
async def handle_if_gift(self, message: MessageRecvS4U) -> bool:
"""处理礼物消息
Returns:
bool: True表示应该继续处理消息False表示消息已被暂存不需要继续处理
"""
if message.is_gift:
# 定义防抖完成后的回调函数
def gift_callback(merged_message: MessageRecvS4U):
"""礼物防抖完成后的回调"""
# 创建异步任务来处理合并后的礼物消息,跳过防抖处理
asyncio.create_task(self.process_message(merged_message, skip_gift_debounce=True))
# 交给礼物管理器处理,并传入回调函数
# 对于礼物消息handle_gift 总是返回 False消息被暂存
await gift_manager.handle_gift(message, gift_callback)
return False # 消息被暂存,不继续处理
return True # 非礼物消息,继续正常处理
async def _handle_context_web_update(self, chat_id: str, message: MessageRecv): async def _handle_context_web_update(self, chat_id: str, message: MessageRecv):
"""处理上下文网页更新的独立task """处理上下文网页更新的独立task

View File

@@ -8,10 +8,13 @@ from src.chat.memory_system.Hippocampus import hippocampus_manager
import random import random
from datetime import datetime from datetime import datetime
import asyncio import asyncio
from src.mais4u.s4u_config import s4u_config
import ast import ast
from src.chat.message_receive.message import MessageSending, MessageRecvS4U
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
from src.person_info.relationship_manager import get_relationship_manager from src.person_info.relationship_manager import get_relationship_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
logger = get_logger("prompt") logger = get_logger("prompt")
@@ -22,13 +25,19 @@ def init_prompt():
Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt") Prompt("你回想起了一些事情:\n{memory_info}\n", "memory_prompt")
Prompt( Prompt(
"""{identity_block} """
你的名字是麦麦, 是千石可乐开发的程序可以在QQ微信等平台发言你现在正在哔哩哔哩作为虚拟主播进行直播
虽然是个AI但你在努力扮演真人, 人设的配色是橙色和紫色和绿色
你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。 你有一头法式橙色卷发,你很可爱,穿戴英式侦探风格学院风裙子和帽子,你正在一个教室场景里进行虚拟线上直播。
你可以看见用户发送的弹幕礼物和superchat
你可以看见面前的屏幕,
{relation_info_block} {relation_info_block}
{memory_block} {memory_block}
你现在的主要任务是和 {sender_name} 聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。 你现在的主要任务是和 {sender_name} 发送的弹幕聊天。同时,也有其他用户会参与你们的聊天,你可以参考他们的回复内容,但是你主要还是关注你和{sender_name}的聊天内容。
{sc_info}
{background_dialogue_prompt} {background_dialogue_prompt}
-------------------------------- --------------------------------
@@ -37,6 +46,7 @@ def init_prompt():
{core_dialogue_prompt} {core_dialogue_prompt}
对方最新发送的内容:{message_txt} 对方最新发送的内容:{message_txt}
{gift_info}
回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。 回复可以简短一些。可以参考贴吧,知乎和微博的回复风格,回复不要浮夸,不要用夸张修辞,平淡一些。
不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。 不要输出多余内容(包括前后缀,冒号和引号,括号()表情包at或 @等 )。只输出回复内容,现在{sender_name}正在等待你的回复。
你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name} 你的回复风格不要浮夸,有逻辑和条理,请你继续回复{sender_name}
@@ -117,14 +127,14 @@ class PromptBuilder:
return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info) return await global_prompt_manager.format_prompt("memory_prompt", memory_info=related_memory_info)
return "" return ""
def build_chat_history_prompts(self, chat_stream, message) -> (str, str): def build_chat_history_prompts(self, chat_stream: ChatStream, message: MessageRecvS4U):
message_list_before_now = get_raw_msg_before_timestamp_with_chat( message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
limit=100, limit=200,
) )
talk_type = message.message_info.platform + ":" + message.chat_stream.user_info.user_id talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id)
core_dialogue_list = [] core_dialogue_list = []
background_dialogue_list = [] background_dialogue_list = []
@@ -148,10 +158,9 @@ class PromptBuilder:
background_dialogue_prompt = "" background_dialogue_prompt = ""
if background_dialogue_list: if background_dialogue_list:
latest_25_msgs = background_dialogue_list[-25:] context_msgs = background_dialogue_list[-s4u_config.max_context_message_length:]
background_dialogue_prompt_str = build_readable_messages( background_dialogue_prompt_str = build_readable_messages(
latest_25_msgs, context_msgs,
merge_messages=True,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
show_pic=False, show_pic=False,
) )
@@ -159,7 +168,7 @@ class PromptBuilder:
core_msg_str = "" core_msg_str = ""
if core_dialogue_list: if core_dialogue_list:
core_dialogue_list = core_dialogue_list[-50:] core_dialogue_list = core_dialogue_list[-s4u_config.max_core_message_length:]
first_msg = core_dialogue_list[0] first_msg = core_dialogue_list[0]
start_speaking_user_id = first_msg.get("user_id") start_speaking_user_id = first_msg.get("user_id")
@@ -196,10 +205,19 @@ class PromptBuilder:
return core_msg_str, background_dialogue_prompt return core_msg_str, background_dialogue_prompt
def build_gift_info(self, message: MessageRecvS4U):
if message.is_gift:
return f"这是一条礼物信息,{message.gift_name} x{message.gift_count},请注意这位用户"
return ""
def build_sc_info(self, message: MessageRecvS4U):
super_chat_manager = get_super_chat_manager()
return super_chat_manager.build_superchat_summary_string(message.chat_stream.stream_id)
async def build_prompt_normal( async def build_prompt_normal(
self, self,
message, message: MessageRecvS4U,
chat_stream, chat_stream: ChatStream,
message_txt: str, message_txt: str,
sender_name: str = "某人", sender_name: str = "某人",
) -> str: ) -> str:
@@ -209,6 +227,10 @@ class PromptBuilder:
core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message) core_dialogue_prompt, background_dialogue_prompt = self.build_chat_history_prompts(chat_stream, message)
gift_info = self.build_gift_info(message)
sc_info = self.build_sc_info(message)
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}" time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
template_name = "s4u_prompt" template_name = "s4u_prompt"
@@ -219,12 +241,16 @@ class PromptBuilder:
time_block=time_block, time_block=time_block,
relation_info_block=relation_info_block, relation_info_block=relation_info_block,
memory_block=memory_block, memory_block=memory_block,
gift_info=gift_info,
sc_info=sc_info,
sender_name=sender_name, sender_name=sender_name,
core_dialogue_prompt=core_dialogue_prompt, core_dialogue_prompt=core_dialogue_prompt,
background_dialogue_prompt=background_dialogue_prompt, background_dialogue_prompt=background_dialogue_prompt,
message_txt=message_txt, message_txt=message_txt,
) )
print(prompt)
return prompt return prompt

View File

@@ -0,0 +1,307 @@
import asyncio
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageRecvS4U, MessageRecv
from src.mais4u.s4u_config import s4u_config
logger = get_logger("super_chat_manager")
@dataclass
class SuperChatRecord:
"""SuperChat记录数据类"""
user_id: str
user_nickname: str
platform: str
chat_id: str
price: float
message_text: str
timestamp: float
expire_time: float
group_name: Optional[str] = None
def is_expired(self) -> bool:
"""检查SuperChat是否已过期"""
return time.time() > self.expire_time
def remaining_time(self) -> float:
"""获取剩余时间(秒)"""
return max(0, self.expire_time - time.time())
def to_dict(self) -> dict:
"""转换为字典格式"""
return {
"user_id": self.user_id,
"user_nickname": self.user_nickname,
"platform": self.platform,
"chat_id": self.chat_id,
"price": self.price,
"message_text": self.message_text,
"timestamp": self.timestamp,
"expire_time": self.expire_time,
"group_name": self.group_name,
"remaining_time": self.remaining_time()
}
class SuperChatManager:
"""SuperChat管理器负责管理和跟踪SuperChat消息"""
def __init__(self):
self.super_chats: Dict[str, List[SuperChatRecord]] = {} # chat_id -> SuperChat列表
self._cleanup_task: Optional[asyncio.Task] = None
self._is_initialized = False
logger.info("SuperChat管理器已初始化")
def _ensure_cleanup_task_started(self):
"""确保清理任务已启动(延迟启动)"""
if self._cleanup_task is None or self._cleanup_task.done():
try:
loop = asyncio.get_running_loop()
self._cleanup_task = loop.create_task(self._cleanup_expired_superchats())
self._is_initialized = True
logger.info("SuperChat清理任务已启动")
except RuntimeError:
# 没有运行的事件循环,稍后再启动
logger.debug("当前没有运行的事件循环,将在需要时启动清理任务")
def _start_cleanup_task(self):
"""启动清理任务(已弃用,保留向后兼容)"""
self._ensure_cleanup_task_started()
async def _cleanup_expired_superchats(self):
"""定期清理过期的SuperChat"""
while True:
try:
current_time = time.time()
total_removed = 0
for chat_id in list(self.super_chats.keys()):
original_count = len(self.super_chats[chat_id])
# 移除过期的SuperChat
self.super_chats[chat_id] = [
sc for sc in self.super_chats[chat_id]
if not sc.is_expired()
]
removed_count = original_count - len(self.super_chats[chat_id])
total_removed += removed_count
if removed_count > 0:
logger.info(f"从聊天 {chat_id} 中清理了 {removed_count} 个过期的SuperChat")
# 如果列表为空,删除该聊天的记录
if not self.super_chats[chat_id]:
del self.super_chats[chat_id]
if total_removed > 0:
logger.info(f"总共清理了 {total_removed} 个过期的SuperChat")
# 每30秒检查一次
await asyncio.sleep(30)
except Exception as e:
logger.error(f"清理过期SuperChat时出错: {e}", exc_info=True)
await asyncio.sleep(60) # 出错时等待更长时间
def _calculate_expire_time(self, price: float) -> float:
"""根据SuperChat金额计算过期时间"""
current_time = time.time()
# 根据金额阶梯设置不同的存活时间
if price >= 500:
# 500元以上保持4小时
duration = 4 * 3600
elif price >= 200:
# 200-499元保持2小时
duration = 2 * 3600
elif price >= 100:
# 100-199元保持1小时
duration = 1 * 3600
elif price >= 50:
# 50-99元保持30分钟
duration = 30 * 60
elif price >= 20:
# 20-49元保持15分钟
duration = 15 * 60
elif price >= 10:
# 10-19元保持10分钟
duration = 10 * 60
else:
# 10元以下保持5分钟
duration = 5 * 60
return current_time + duration
async def add_superchat(self, message: MessageRecvS4U) -> None:
"""添加新的SuperChat记录"""
# 确保清理任务已启动
self._ensure_cleanup_task_started()
if not message.is_superchat or not message.superchat_price:
logger.warning("尝试添加非SuperChat消息到SuperChat管理器")
return
try:
price = float(message.superchat_price)
except (ValueError, TypeError):
logger.error(f"无效的SuperChat价格: {message.superchat_price}")
return
user_info = message.message_info.user_info
group_info = message.message_info.group_info
chat_id = getattr(message, 'chat_stream', None)
if chat_id:
chat_id = chat_id.stream_id
else:
# 生成chat_id的备用方法
chat_id = f"{message.message_info.platform}_{user_info.user_id}"
if group_info:
chat_id = f"{message.message_info.platform}_{group_info.group_id}"
expire_time = self._calculate_expire_time(price)
record = SuperChatRecord(
user_id=user_info.user_id,
user_nickname=user_info.user_nickname,
platform=message.message_info.platform,
chat_id=chat_id,
price=price,
message_text=message.superchat_message_text or "",
timestamp=message.message_info.time,
expire_time=expire_time,
group_name=group_info.group_name if group_info else None
)
# 添加到对应聊天的SuperChat列表
if chat_id not in self.super_chats:
self.super_chats[chat_id] = []
self.super_chats[chat_id].append(record)
# 按价格降序排序(价格高的在前)
self.super_chats[chat_id].sort(key=lambda x: x.price, reverse=True)
logger.info(f"添加SuperChat记录: {user_info.user_nickname} - {price}元 - {message.superchat_message_text}")
def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]:
"""获取指定聊天的所有有效SuperChat"""
# 确保清理任务已启动
self._ensure_cleanup_task_started()
if chat_id not in self.super_chats:
return []
# 过滤掉过期的SuperChat
valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
return valid_superchats
def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]:
"""获取所有有效的SuperChat"""
# 确保清理任务已启动
self._ensure_cleanup_task_started()
result = {}
for chat_id, superchats in self.super_chats.items():
valid_superchats = [sc for sc in superchats if not sc.is_expired()]
if valid_superchats:
result[chat_id] = valid_superchats
return result
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
"""构建SuperChat显示字符串"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return ""
# 限制显示数量
display_superchats = superchats[:max_count]
lines = []
lines.append("📢 当前有效超级弹幕:")
for i, sc in enumerate(display_superchats, 1):
remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60)
time_display = f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度
line = line[:97] + "..."
line += f" (剩余{time_display})"
lines.append(line)
if len(superchats) > max_count:
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
return "\n".join(lines)
def build_superchat_summary_string(self, chat_id: str) -> str:
"""构建SuperChat摘要字符串"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return "当前没有有效的超级弹幕"
lines = []
for sc in superchats:
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
if len(single_sc_str) > 100:
single_sc_str = single_sc_str[:97] + "..."
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
lines.append(single_sc_str)
total_amount = sum(sc.price for sc in superchats)
count = len(superchats)
highest_amount = max(sc.price for sc in superchats)
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}"
if lines:
final_str += "\n" + "\n".join(lines)
return final_str
def get_superchat_statistics(self, chat_id: str) -> dict:
"""获取SuperChat统计信息"""
superchats = self.get_superchats_by_chat(chat_id)
if not superchats:
return {
"count": 0,
"total_amount": 0,
"average_amount": 0,
"highest_amount": 0,
"lowest_amount": 0
}
amounts = [sc.price for sc in superchats]
return {
"count": len(superchats),
"total_amount": sum(amounts),
"average_amount": sum(amounts) / len(amounts),
"highest_amount": max(amounts),
"lowest_amount": min(amounts)
}
async def shutdown(self):
"""关闭管理器,清理资源"""
if self._cleanup_task and not self._cleanup_task.done():
self._cleanup_task.cancel()
try:
await self._cleanup_task
except asyncio.CancelledError:
pass
logger.info("SuperChat管理器已关闭")
# 全局SuperChat管理器实例
super_chat_manager = SuperChatManager()
def get_super_chat_manager() -> SuperChatManager:
"""获取全局SuperChat管理器实例"""
return super_chat_manager

296
src/mais4u/s4u_config.py Normal file
View File

@@ -0,0 +1,296 @@
import os
import tomlkit
import shutil
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
from src.common.logger import get_logger
logger = get_logger("s4u_config")
# 获取mais4u模块目录
MAIS4U_ROOT = os.path.dirname(__file__)
CONFIG_DIR = os.path.join(MAIS4U_ROOT, "config")
TEMPLATE_PATH = os.path.join(CONFIG_DIR, "s4u_config_template.toml")
CONFIG_PATH = os.path.join(CONFIG_DIR, "s4u_config.toml")
# S4U配置版本
S4U_VERSION = "1.0.0"
T = TypeVar("T", bound="S4UConfigBase")
@dataclass
class S4UConfigBase:
"""S4U配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
init_args: dict[str, Any] = {}
for f in fields(cls):
field_name = f.name
if field_name.startswith("_"):
# 跳过以 _ 开头的字段
continue
if field_name not in data:
if f.default is not MISSING or f.default_factory is not MISSING:
# 跳过未提供且有默认值/默认构造方法的字段
continue
else:
raise ValueError(f"Missing required field: '{field_name}'")
value = data[field_name]
field_type = f.type
try:
init_args[field_name] = cls._convert_field(value, field_type) # type: ignore
except TypeError as e:
raise TypeError(f"Field '{field_name}' has a type error: {e}") from e
except Exception as e:
raise RuntimeError(f"Failed to convert field '{field_name}' to target type: {e}") from e
return cls(**init_args)
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
"""转换字段值为指定类型"""
# 如果是嵌套的 dataclass递归调用 from_dict 方法
if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
return field_type.from_dict(value)
# 处理泛型集合类型list, set, tuple
field_origin_type = get_origin(field_type)
field_type_args = get_args(field_type)
if field_origin_type in {list, set, tuple}:
if not isinstance(value, list):
raise TypeError(f"Expected an list for {field_type.__name__}, got {type(value).__name__}")
if field_origin_type is list:
if (
field_type_args
and isinstance(field_type_args[0], type)
and issubclass(field_type_args[0], S4UConfigBase)
):
return [field_type_args[0].from_dict(item) for item in value]
return [cls._convert_field(item, field_type_args[0]) for item in value]
elif field_origin_type is set:
return {cls._convert_field(item, field_type_args[0]) for item in value}
elif field_origin_type is tuple:
if len(value) != len(field_type_args):
raise TypeError(
f"Expected {len(field_type_args)} items for {field_type.__name__}, got {len(value)}"
)
return tuple(cls._convert_field(item, arg) for item, arg in zip(value, field_type_args, strict=False))
if field_origin_type is dict:
if not isinstance(value, dict):
raise TypeError(f"Expected a dictionary for {field_type.__name__}, got {type(value).__name__}")
if len(field_type_args) != 2:
raise TypeError(f"Expected a dictionary with two type arguments for {field_type.__name__}")
key_type, value_type = field_type_args
return {cls._convert_field(k, key_type): cls._convert_field(v, value_type) for k, v in value.items()}
# 处理基础类型,例如 int, str 等
if field_origin_type is type(None) and value is None: # 处理Optional类型
return None
# 处理Literal类型
if field_origin_type is Literal or get_origin(field_type) is Literal:
allowed_values = get_args(field_type)
if value in allowed_values:
return value
else:
raise TypeError(f"Value '{value}' is not in allowed values {allowed_values} for Literal type")
if field_type is Any or isinstance(value, field_type):
return value
# 其他类型,尝试直接转换
try:
return field_type(value)
except (ValueError, TypeError) as e:
raise TypeError(f"Cannot convert {type(value).__name__} to {field_type.__name__}") from e
@dataclass
class S4UConfig(S4UConfigBase):
"""S4U聊天系统配置类"""
message_timeout_seconds: int = 120
"""普通消息存活时间(秒),超过此时间的消息将被丢弃"""
at_bot_priority_bonus: float = 100.0
"""@机器人时的优先级加成分数"""
recent_message_keep_count: int = 6
"""保留最近N条消息超出范围的普通消息将被移除"""
typing_delay: float = 0.1
"""打字延迟时间(秒),模拟真实打字速度"""
chars_per_second: float = 15.0
"""每秒字符数,用于计算动态打字延迟"""
min_typing_delay: float = 0.2
"""最小打字延迟(秒)"""
max_typing_delay: float = 2.0
"""最大打字延迟(秒)"""
enable_dynamic_typing_delay: bool = False
"""是否启用基于文本长度的动态打字延迟"""
vip_queue_priority: bool = True
"""是否启用VIP队列优先级系统"""
enable_message_interruption: bool = True
"""是否允许高优先级消息中断当前回复"""
enable_old_message_cleanup: bool = True
"""是否自动清理过旧的普通消息"""
enable_loading_indicator: bool = True
"""是否显示加载提示"""
max_context_message_length: int = 20
"""上下文消息最大长度"""
max_core_message_length: int = 30
"""核心消息最大长度"""
@dataclass
class S4UGlobalConfig(S4UConfigBase):
"""S4U总配置类"""
s4u: S4UConfig
S4U_VERSION: str = S4U_VERSION
def update_s4u_config():
"""更新S4U配置文件"""
# 创建配置目录(如果不存在)
os.makedirs(CONFIG_DIR, exist_ok=True)
# 检查模板文件是否存在
if not os.path.exists(TEMPLATE_PATH):
logger.error(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
logger.error("请确保模板文件存在后重新运行")
raise FileNotFoundError(f"S4U配置模板文件不存在: {TEMPLATE_PATH}")
# 检查配置文件是否存在
if not os.path.exists(CONFIG_PATH):
logger.info("S4U配置文件不存在从模板创建新配置")
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
logger.info(f"已创建S4U配置文件: {CONFIG_PATH}")
return
# 读取旧配置文件和模板文件
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
logger.info(f"检测到S4U配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(f"检测到S4U配置版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
else:
logger.info("S4U配置文件未检测到版本号可能是旧版本。将进行更新")
# 创建备份目录
old_config_dir = os.path.join(CONFIG_DIR, "old")
os.makedirs(old_config_dir, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"s4u_config_{timestamp}.toml")
# 移动旧配置文件到old目录
shutil.move(CONFIG_PATH, old_backup_path)
logger.info(f"已备份旧S4U配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
shutil.copy2(TEMPLATE_PATH, CONFIG_PATH)
logger.info(f"已创建新S4U配置文件: {CONFIG_PATH}")
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
logger.info("开始合并S4U新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(CONFIG_PATH, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
logger.info("S4U配置文件更新完成")
def load_s4u_config(config_path: str) -> S4UGlobalConfig:
"""
加载S4U配置文件
:param config_path: 配置文件路径
:return: S4UGlobalConfig对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建S4UGlobalConfig对象
try:
return S4UGlobalConfig.from_dict(config_data)
except Exception as e:
logger.critical("S4U配置文件解析失败")
raise e
# 初始化S4U配置
logger.info(f"S4U当前版本: {S4U_VERSION}")
update_s4u_config()
logger.info("正在加载S4U配置文件...")
s4u_config_main = load_s4u_config(config_path=CONFIG_PATH)
logger.info("S4U配置文件加载完成")
s4u_config: S4UConfig = s4u_config_main.s4u

View File

@@ -161,6 +161,60 @@ class PersonInfoManager:
await asyncio.to_thread(_db_create_sync, final_data) await asyncio.to_thread(_db_create_sync, final_data)
async def _safe_create_person_info(self, person_id: str, data: Optional[dict] = None):
"""安全地创建用户信息,处理竞态条件"""
if not person_id:
logger.debug("创建失败person_id不存在")
return
_person_info_default = copy.deepcopy(person_info_default)
model_fields = PersonInfo._meta.fields.keys() # type: ignore
final_data = {"person_id": person_id}
# Start with defaults for all model fields
for key, default_value in _person_info_default.items():
if key in model_fields:
final_data[key] = default_value
# Override with provided data
if data:
for key, value in data.items():
if key in model_fields:
final_data[key] = value
# Ensure person_id is correctly set from the argument
final_data["person_id"] = person_id
# Serialize JSON fields
for key in JSON_SERIALIZED_FIELDS:
if key in final_data:
if isinstance(final_data[key], (list, dict)):
final_data[key] = json.dumps(final_data[key], ensure_ascii=False)
elif final_data[key] is None: # Default for lists is [], store as "[]"
final_data[key] = json.dumps([], ensure_ascii=False)
def _db_safe_create_sync(p_data: dict):
try:
# 首先检查是否已存在
existing = PersonInfo.get_or_none(PersonInfo.person_id == p_data["person_id"])
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
# 尝试创建
PersonInfo.create(**p_data)
return True
except Exception as e:
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_data.get('person_id')},跳过错误")
return True # 其他协程已创建,视为成功
else:
logger.error(f"创建 PersonInfo 记录 {p_data.get('person_id')} 失败 (Peewee): {e}")
return False
await asyncio.to_thread(_db_safe_create_sync, final_data)
async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None): async def update_one_field(self, person_id: str, field_name: str, value, data: Optional[Dict] = None):
"""更新某一个字段,会补全""" """更新某一个字段,会补全"""
if field_name not in PersonInfo._meta.fields: # type: ignore if field_name not in PersonInfo._meta.fields: # type: ignore
@@ -221,7 +275,8 @@ class PersonInfoManager:
if data and "user_id" in data: if data and "user_id" in data:
creation_data["user_id"] = data["user_id"] creation_data["user_id"] = data["user_id"]
await self.create_person_info(person_id, creation_data) # 使用安全的创建方法,处理竞态条件
await self._safe_create_person_info(person_id, creation_data)
@staticmethod @staticmethod
async def has_one_field(person_id: str, field_name: str): async def has_one_field(person_id: str, field_name: str):
@@ -529,16 +584,31 @@ class PersonInfoManager:
""" """
根据 platform 和 user_id 获取 person_id。 根据 platform 和 user_id 获取 person_id。
如果对应的用户不存在,则使用提供的可选信息创建新用户。 如果对应的用户不存在,则使用提供的可选信息创建新用户。
使用try-except处理竞态条件避免重复创建错误。
""" """
person_id = self.get_person_id(platform, user_id) person_id = self.get_person_id(platform, user_id)
def _db_check_exists_sync(p_id: str): def _db_get_or_create_sync(p_id: str, init_data: dict):
return PersonInfo.get_or_none(PersonInfo.person_id == p_id) """原子性的获取或创建操作"""
# 首先尝试获取现有记录
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return record, False # 记录存在,未创建
record = await asyncio.to_thread(_db_check_exists_sync, person_id) # 记录不存在,尝试创建
try:
PersonInfo.create(**init_data)
return PersonInfo.get(PersonInfo.person_id == p_id), True # 创建成功
except Exception as e:
# 如果创建失败(可能是因为竞态条件),再次尝试获取
if "UNIQUE constraint failed" in str(e):
logger.debug(f"检测到并发创建用户 {p_id},获取现有记录")
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return record, False # 其他协程已创建,返回现有记录
# 如果仍然失败,重新抛出异常
raise e
if record is None:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
unique_nickname = await self._generate_unique_person_name(nickname) unique_nickname = await self._generate_unique_person_name(nickname)
initial_data = { initial_data = {
"person_id": person_id, "person_id": person_id,
@@ -554,11 +624,25 @@ class PersonInfoManager:
"points": [], "points": [],
"forgotten_points": [], "forgotten_points": [],
} }
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
if isinstance(initial_data[key], (list, dict)):
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
model_fields = PersonInfo._meta.fields.keys() # type: ignore model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
await self.create_person_info(person_id, data=filtered_initial_data) record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
else:
logger.debug(f"用户 {platform}:{user_id} (person_id: {person_id}) 已存在,返回现有记录。")
return person_id return person_id

View File

@@ -60,9 +60,9 @@ class RelationshipBuilder:
# 获取聊天名称用于日志 # 获取聊天名称用于日志
try: try:
chat_name = get_chat_manager().get_stream_name(self.chat_id) chat_name = get_chat_manager().get_stream_name(self.chat_id)
self.log_prefix = f"[{chat_name}] 关系构建" self.log_prefix = f"[{chat_name}]"
except Exception: except Exception:
self.log_prefix = f"[{self.chat_id}] 关系构建" self.log_prefix = f"[{self.chat_id}]"
# 加载持久化的缓存 # 加载持久化的缓存
self._load_cache() self._load_cache()
@@ -349,11 +349,14 @@ class RelationshipBuilder:
# 统筹各模块协作、对外提供服务接口 # 统筹各模块协作、对外提供服务接口
# ================================ # ================================
async def build_relation(self): async def build_relation(self,immediate_build: str = "",max_build_threshold: int = MAX_MESSAGE_COUNT):
"""构建关系""" """构建关系
immediate_build: 立即构建关系,可选值为"all"或person_id
"""
self._cleanup_old_segments() self._cleanup_old_segments()
current_time = time.time() current_time = time.time()
if latest_messages := get_raw_msg_by_timestamp_with_chat( if latest_messages := get_raw_msg_by_timestamp_with_chat(
self.chat_id, self.chat_id,
self.last_processed_message_time, self.last_processed_message_time,
@@ -374,7 +377,7 @@ class RelationshipBuilder:
): ):
person_id = PersonInfoManager.get_person_id(platform, user_id) person_id = PersonInfoManager.get_person_id(platform, user_id)
self._update_message_segments(person_id, msg_time) self._update_message_segments(person_id, msg_time)
logger.debug( logger.info(
f"{self.log_prefix} 更新用户 {person_id} 的消息段,消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(msg_time))}" f"{self.log_prefix} 更新用户 {person_id} 的消息段,消息时间:{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(msg_time))}"
) )
self.last_processed_message_time = max(self.last_processed_message_time, msg_time) self.last_processed_message_time = max(self.last_processed_message_time, msg_time)
@@ -383,15 +386,17 @@ class RelationshipBuilder:
users_to_build_relationship = [] users_to_build_relationship = []
for person_id, segments in self.person_engaged_cache.items(): for person_id, segments in self.person_engaged_cache.items():
total_message_count = self._get_total_message_count(person_id) total_message_count = self._get_total_message_count(person_id)
if total_message_count >= MAX_MESSAGE_COUNT: person_name = get_person_info_manager().get_value_sync(person_id, "person_name") or person_id
if total_message_count >= max_build_threshold or (total_message_count >= 5 and (immediate_build == person_id or immediate_build == "all")):
users_to_build_relationship.append(person_id) users_to_build_relationship.append(person_id)
logger.debug( logger.info(
f"{self.log_prefix} 用户 {person_id} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}" f"{self.log_prefix} 用户 {person_name} 满足关系构建条件,总消息数:{total_message_count},消息段数:{len(segments)}"
) )
elif total_message_count > 0: elif total_message_count > 0:
# 记录进度信息 # 记录进度信息
logger.debug( logger.info(
f"{self.log_prefix} 用户 {person_id} 进度:{total_message_count}60 条消息,{len(segments)} 个消息段" f"{self.log_prefix} 用户 {person_name} 进度:{total_message_count}/60 条消息,{len(segments)} 个消息段"
) )
# 2. 为满足条件的用户构建关系 # 2. 为满足条件的用户构建关系
@@ -405,6 +410,7 @@ class RelationshipBuilder:
del self.person_engaged_cache[person_id] del self.person_engaged_cache[person_id]
self._save_cache() self._save_cache()
# ================================ # ================================
# 关系构建模块 # 关系构建模块
# 负责触发关系构建、整合消息段、更新用户印象 # 负责触发关系构建、整合消息段、更新用户印象
@@ -413,7 +419,7 @@ class RelationshipBuilder:
async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]): async def update_impression_on_segments(self, person_id: str, chat_id: str, segments: List[Dict[str, Any]]):
"""基于消息段更新用户印象""" """基于消息段更新用户印象"""
original_segment_count = len(segments) original_segment_count = len(segments)
logger.debug(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象") logger.info(f"开始为 {person_id} 基于 {original_segment_count} 个消息段更新印象")
try: try:
# 筛选要处理的消息段每个消息段有10%的概率被丢弃 # 筛选要处理的消息段每个消息段有10%的概率被丢弃
segments_to_process = [s for s in segments if random.random() >= 0.1] segments_to_process = [s for s in segments if random.random() >= 0.1]

View File

@@ -44,8 +44,8 @@ class RelationshipManager:
"konw_time": int(time.time()), "konw_time": int(time.time()),
"person_name": unique_nickname, # 使用唯一的 person_name "person_name": unique_nickname, # 使用唯一的 person_name
} }
# 先创建用户基本信息 # 先创建用户基本信息,使用安全创建方法避免竞态条件
await person_info_manager.create_person_info(person_id=person_id, data=data) await person_info_manager._safe_create_person_info(person_id=person_id, data=data)
# 更新昵称 # 更新昵称
await person_info_manager.update_one_field( await person_info_manager.update_one_field(
person_id=person_id, field_name="nickname", value=user_nickname, data=data person_id=person_id, field_name="nickname", value=user_nickname, data=data

View File

@@ -250,7 +250,6 @@ async def _find_reply_message(target_stream, reply_to: str) -> Optional[MessageR
message_dict = { message_dict = {
"message_info": message_info, "message_info": message_info,
"raw_message": find_msg.get("processed_plain_text"), "raw_message": find_msg.get("processed_plain_text"),
"detailed_plain_text": find_msg.get("processed_plain_text"),
"processed_plain_text": find_msg.get("processed_plain_text"), "processed_plain_text": find_msg.get("processed_plain_text"),
} }