fix:调整目录结构,优化hfc prompt,移除日程,移除动态和llm判断willing模式,

This commit is contained in:
SengokuCola
2025-05-13 18:37:55 +08:00
parent 6376da0682
commit fed71bccad
131 changed files with 422 additions and 1500 deletions

View File

@@ -0,0 +1,14 @@
from ..emoji_system.emoji_manager import emoji_manager
from ..person_info.relationship_manager import relationship_manager
from .chat_stream import chat_manager
from .message_sender import message_manager
from .storage import MessageStorage
# Names re-exported as this package's public API (all are imported above).
__all__ = [
"emoji_manager",
"relationship_manager",
"chat_manager",
"message_manager",
"MessageStorage",
]

View File

@@ -0,0 +1,153 @@
import traceback
from typing import Dict, Any
from src.common.logger_manager import get_logger
from src.manager.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import chat_manager
from src.chat.message_receive.message import MessageRecv
from src.experimental.only_message_process import MessageProcessor
from src.experimental.PFC.pfc_manager import PFCManager
from src.chat.focus_chat.heartflow_processor import HeartFCProcessor
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config
# 定义日志配置
# 配置主程序日志格式
logger = get_logger("chat")
class ChatBot:
    """Pre-filter incoming messages and dispatch them to a chat processor."""

    def __init__(self):
        self.bot = None  # reference to the bot instance
        self._started = False
        self.mood_manager = mood_manager  # imported mood manager object
        self.heartflow_processor = HeartFCProcessor()
        # NOTE(review): an earlier comment said PFC manager initialisation is
        # scheduled from _ensure_started; the visible _ensure_started only
        # flips the started flag.
        self.only_process_chat = MessageProcessor()
        self.pfc_manager = PFCManager.get_instance()

    async def _ensure_started(self):
        """Flip the started flag on first use; later calls do nothing."""
        if not self._started:
            logger.trace("确保ChatBot所有任务已启动")
            self._started = True

    async def _create_pfc_chat(self, message: MessageRecv):
        """Request a PFC conversation for the message's chat stream.

        Any exception is logged and swallowed so that a PFC failure does not
        abort message handling.
        """
        try:
            chat_id = str(message.chat_stream.stream_id)
            private_name = str(message.message_info.user_info.user_nickname)
            if global_config.enable_pfc_chatting:
                await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
        except Exception as e:
            logger.error(f"创建PFC聊天失败: {e}")

    async def message_process(self, message_data: Dict[str, Any]) -> None:
        """Pre-process a unified-format message and route it to a processor.

        Steps performed here:
          * normalise ``group_id`` / ``user_id`` to strings (mutates
            ``message_data`` in place);
          * drop messages from banned users, from private users that are not
            allowed, and from groups that are not allowed;
          * register any custom prompt templates carried by the message;
          * dispatch: private chats go to the PFC flow or the heart-flow
            processor depending on configuration, group chats always go to
            the heart-flow processor.

        Args:
            message_data: Message dictionary; must contain ``message_info``
                with a ``user_info`` entry.

        All exceptions are caught, logged, and the traceback is printed;
        nothing is re-raised.
        """
        try:
            # Make sure start-up bookkeeping has run.
            await self._ensure_started()

            message_info = message_data["message_info"]
            if message_info.get("group_info") is not None:
                message_info["group_info"]["group_id"] = str(message_info["group_info"]["group_id"])
            message_info["user_info"]["user_id"] = str(message_info["user_info"]["user_id"])

            logger.trace(f"处理消息:{str(message_data)[:120]}...")
            message = MessageRecv(message_data)
            groupinfo = message.message_info.group_info
            userinfo = message.message_info.user_info

            # Banned-user filter.
            if userinfo.user_id in global_config.ban_user_id:
                logger.debug(f"用户{userinfo.user_id}被禁止回复")
                return

            if groupinfo is None:
                logger.trace("检测到私聊消息,检查")
                # Private chats are only accepted from explicitly allowed users.
                if userinfo.user_id not in global_config.talk_allowed_private:
                    logger.debug(f"用户{userinfo.user_id}没有私聊权限")
                    return

            # Group chats are only accepted from explicitly allowed groups.
            if groupinfo is not None and groupinfo.group_id not in global_config.talk_allowed_groups:
                logger.trace(f"{groupinfo.group_id}被禁止回复")
                return

            # Register custom prompt templates if the message carries any.
            template_info = message.message_info.template_info
            if template_info and not template_info.template_default:
                template_group_name = template_info.template_name
                template_items = template_info.template_items
                async with global_prompt_manager.async_message_scope(template_group_name):
                    if isinstance(template_items, dict):
                        for name, template in template_items.items():
                            await Prompt.create_async(template, name)
                            # Goes through the logger instead of a bare print().
                            logger.debug(f"注册{template},{name}")
            else:
                template_group_name = None

            async def preprocess():
                logger.trace("开始预处理消息...")
                if groupinfo is None:
                    # Private chat.
                    logger.trace("检测到私聊消息")
                    # Only handled when friend chat is enabled in the config.
                    if global_config.enable_friend_chat:
                        logger.trace("私聊模式已启用")
                        if global_config.enable_pfc_chatting:
                            logger.trace("进入PFC私聊处理流程")
                            userinfo = message.message_info.user_info
                            messageinfo = message.message_info
                            # Create or fetch the chat stream.
                            logger.trace(f"{userinfo.user_id}创建/获取聊天流")
                            chat = await chat_manager.get_or_create_stream(
                                platform=messageinfo.platform,
                                user_info=userinfo,
                                group_info=groupinfo,
                            )
                            message.update_chat_stream(chat)
                            await self.only_process_chat.process_message(message)
                            await self._create_pfc_chat(message)
                        else:
                            # PFC disabled: use the regular heart-flow path.
                            logger.trace("进入普通心流私聊处理")
                            await self.heartflow_processor.process_message(message_data)
                else:
                    # Group chats always use the heart-flow path.
                    logger.trace(f"检测到群聊消息群ID: {groupinfo.group_id}")
                    await self.heartflow_processor.process_message(message_data)

            if template_group_name:
                async with global_prompt_manager.async_message_scope(template_group_name):
                    await preprocess()
            else:
                await preprocess()
        except Exception as e:
            logger.error(f"预处理消息失败: {e}")
            traceback.print_exc()
# Create the module-level ChatBot instance.
chat_bot = ChatBot()

View File

@@ -0,0 +1,232 @@
import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from ...common.database import db
from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_stream")
class ChatStream:
    """Chat stream object holding the context of one conversation."""

    def __init__(
        self,
        stream_id: str,
        platform: str,
        user_info: UserInfo,
        group_info: Optional[GroupInfo] = None,
        data: dict = None,
    ):
        # A missing or empty ``data`` means "no persisted timestamps".
        persisted = data if data else {}
        self.stream_id = stream_id
        self.platform = platform
        self.user_info = user_info
        self.group_info = group_info
        self.create_time = persisted.get("create_time", time.time())
        self.last_active_time = persisted.get("last_active_time", self.create_time)
        self.saved = False

    def to_dict(self) -> dict:
        """Return the stream as a plain dictionary."""
        user_part = self.user_info.to_dict() if self.user_info else None
        group_part = self.group_info.to_dict() if self.group_info else None
        return {
            "stream_id": self.stream_id,
            "platform": self.platform,
            "user_info": user_part,
            "group_info": group_part,
            "create_time": self.create_time,
            "last_active_time": self.last_active_time,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "ChatStream":
        """Build an instance from a dictionary such as the one ``to_dict`` returns."""
        raw_user = data.get("user_info")
        raw_group = data.get("group_info")
        return cls(
            stream_id=data["stream_id"],
            platform=data["platform"],
            user_info=UserInfo.from_dict(raw_user) if raw_user else None,
            group_info=GroupInfo.from_dict(raw_group) if raw_group else None,
            data=data,
        )

    def update_active_time(self):
        """Stamp the stream as active now and flag it as unsaved."""
        self.saved = False
        self.last_active_time = time.time()
class ChatManager:
    """Chat manager that keeps track of every chat stream."""

    _instance = None
    _initialized = False

    def __new__(cls):
        # Reuse the single shared instance, creating it on first use.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        if self._initialized:
            return
        self.streams: Dict[str, ChatStream] = {}  # stream_id -> ChatStream
        self._ensure_collection()
        self._initialized = True
        # Start-up hooks that are currently disabled:
        # asyncio.create_task(self._initialize())
        # asyncio.create_task(self._auto_save_task())

    async def _initialize(self):
        """Load all streams from the database, logging success or failure."""
        try:
            await self.load_all_streams()
        except Exception as e:
            logger.error(f"聊天管理器启动失败: {str(e)}")
        else:
            logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")

    async def _auto_save_task(self):
        """Save every stream once per 300 seconds, forever."""
        while True:
            await asyncio.sleep(300)
            try:
                await self._save_all_streams()
            except Exception as e:
                logger.error(f"聊天流自动保存失败: {str(e)}")
            else:
                logger.info("聊天流自动保存完成")

    @staticmethod
    def _ensure_collection():
        """Create the ``chat_streams`` collection and its indexes when missing."""
        if "chat_streams" in db.list_collection_names():
            return
        db.create_collection("chat_streams")
        db.chat_streams.create_index([("stream_id", 1)], unique=True)
        db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])

    @staticmethod
    def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
        """Derive the stream id: MD5 hex digest of the joined key parts."""
        if group_info:
            parts = [platform, str(group_info.group_id)]
        else:
            parts = [platform, str(user_info.user_id), "private"]
        return hashlib.md5("_".join(parts).encode()).hexdigest()

    async def get_or_create_stream(
        self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
    ) -> ChatStream:
        """Fetch the chat stream for the given identity, creating it if needed.

        Args:
            platform: Platform identifier.
            user_info: User information.
            group_info: Group information (optional).

        Returns:
            ChatStream: A deep copy of the stream, carrying the supplied
            user/group info.
        """
        try:
            stream_id = self._generate_stream_id(platform, user_info, group_info)

            # In-memory hit: touch the cached stream and hand back a copy.
            if stream_id in self.streams:
                cached = self.streams[stream_id]
                cached.update_active_time()
                snapshot = copy.deepcopy(cached)
                snapshot.user_info = user_info
                if group_info:
                    snapshot.group_info = group_info
                return snapshot

            # Otherwise look in the database, falling back to a new stream.
            record = db.chat_streams.find_one({"stream_id": stream_id})
            if not record:
                stream = ChatStream(
                    stream_id=stream_id,
                    platform=platform,
                    user_info=user_info,
                    group_info=group_info,
                )
            else:
                stream = ChatStream.from_dict(record)
                stream.user_info = user_info
                if group_info:
                    stream.group_info = group_info
                stream.update_active_time()
        except Exception as e:
            logger.error(f"创建聊天流失败: {e}")
            raise e

        # Remember it in memory and persist it.
        self.streams[stream_id] = stream
        await self._save_stream(stream)
        return copy.deepcopy(stream)

    def get_stream(self, stream_id: str) -> Optional[ChatStream]:
        """Return the in-memory stream for ``stream_id``, or None."""
        return self.streams.get(stream_id)

    def get_stream_by_info(
        self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
    ) -> Optional[ChatStream]:
        """Return the in-memory stream matching the given identity, or None."""
        return self.streams.get(self._generate_stream_id(platform, user_info, group_info))

    def get_stream_name(self, stream_id: str) -> Optional[str]:
        """Return a display name for the stream, or None when unavailable."""
        stream = self.get_stream(stream_id)
        if not stream:
            return None
        group = stream.group_info
        if group and group.group_name:
            return group.group_name
        user = stream.user_info
        if user and user.user_nickname:
            return f"{stream.user_info.user_nickname}的私聊"
        # Neither a group name nor a user nickname is available.
        return None

    @staticmethod
    async def _save_stream(stream: ChatStream):
        """Upsert the stream into the database unless it is already saved."""
        if stream.saved:
            return
        db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)
        stream.saved = True

    async def _save_all_streams(self):
        """Save every in-memory stream."""
        for entry in self.streams.values():
            await self._save_stream(entry)

    async def load_all_streams(self):
        """Load every stream document from the database into memory."""
        for record in db.chat_streams.find({}):
            loaded = ChatStream.from_dict(record)
            self.streams[loaded.stream_id] = loaded
# Create the module-level ChatManager singleton.
chat_manager = ChatManager()

View File

@@ -0,0 +1,405 @@
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any
import urllib3
from src.common.logger_manager import get_logger
from .chat_stream import ChatStream
from ..utils.utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_message")
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
# 这个类是消息数据类,用于存储和管理消息数据。
# 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
# 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class Message(MessageBase):
    """Base message carrying a chat stream, an optional reply and text forms."""

    chat_stream: ChatStream = None
    reply: Optional["Message"] = None
    detailed_plain_text: str = ""
    processed_plain_text: str = ""
    memorized_times: int = 0

    def __init__(
        self,
        message_id: str,
        chat_stream: ChatStream,
        user_info: UserInfo,
        message_segment: Optional[Seg] = None,
        timestamp: Optional[float] = None,
        reply: Optional["MessageRecv"] = None,
        detailed_plain_text: str = "",
        processed_plain_text: str = "",
    ):
        # Fall back to the current time (millisecond precision) when no
        # timestamp is supplied.
        if timestamp is None:
            effective_time = round(time.time(), 3)
        else:
            effective_time = timestamp

        # Platform and group come from the chat stream.
        info = BaseMessageInfo(
            platform=chat_stream.platform,
            message_id=message_id,
            time=effective_time,
            group_info=chat_stream.group_info,
            user_info=user_info,
        )
        super().__init__(message_info=info, message_segment=message_segment, raw_message=None)

        self.chat_stream = chat_stream
        self.reply = reply
        self.detailed_plain_text = detailed_plain_text
        self.processed_plain_text = processed_plain_text

    async def _process_message_segments(self, segment: Seg) -> str:
        """Recursively turn a message segment into descriptive text.

        Args:
            segment: Segment to process.

        Returns:
            str: The processed text; for a ``seglist`` the non-empty child
            texts joined with single spaces.
        """
        if segment.type != "seglist":
            # Leaf segment: delegate to the subclass hook.
            return await self._process_single_segment(segment)

        pieces = []
        for child in segment.data:
            text = await self._process_message_segments(child)
            if text:
                pieces.append(text)
        return " ".join(pieces)

    @abstractmethod
    async def _process_single_segment(self, segment):
        """Convert one non-list segment to text; implemented by subclasses."""
        pass
@dataclass
class MessageRecv(Message):
    """Received message built from a serialized MessageCQ dictionary."""

    def __init__(self, message_dict: dict[str, Any]):
        """Initialise from a serialized MessageCQ dictionary.

        Args:
            message_dict: Dictionary produced by serializing a MessageCQ.
        """
        self.raw_message = message_dict.get("raw_message")
        self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
        self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))

        # Text forms start empty and are filled in by ``process``.
        self.detailed_plain_text = ""
        self.processed_plain_text = ""
        self.is_emoji = False

    def update_chat_stream(self, chat_stream: ChatStream):
        """Attach the chat stream this message belongs to."""
        self.chat_stream = chat_stream

    async def process(self) -> None:
        """Fill in the plain and detailed text of the message.

        Must be called explicitly after construction because it awaits.
        """
        self.processed_plain_text = await self._process_message_segments(self.message_segment)
        self.detailed_plain_text = self._generate_detailed_text()

    async def _process_single_segment(self, seg: Seg) -> str:
        """Convert a single segment to text.

        Args:
            seg: Segment to process.

        Returns:
            str: The processed text, or a failure placeholder on error.
        """
        try:
            kind = seg.type
            if kind == "text":
                return seg.data
            if kind == "image":
                # String data is treated as base64 image data.
                if isinstance(seg.data, str):
                    return await image_manager.get_image_description(seg.data)
                return "[发了一张图片,网卡了加载不出来]"
            if kind == "emoji":
                self.is_emoji = True
                if isinstance(seg.data, str):
                    return await image_manager.get_emoji_description(seg.data)
                return "[发了一个表情包,网卡了加载不出来]"
            return f"[{seg.type}:{str(seg.data)}]"
        except Exception as e:
            logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
            return f"[处理失败的{seg.type}消息]"

    def _generate_detailed_text(self) -> str:
        """Build the detailed text line: timestamp, sender tag and plain text."""
        info = self.message_info
        user_info = info.user_info
        timestamp = info.time
        name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
        return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageProcessBase(Message):
    """Base class for messages that are being processed or sent."""

    def __init__(
        self,
        message_id: str,
        chat_stream: ChatStream,
        bot_user_info: UserInfo,
        message_segment: Optional[Seg] = None,
        reply: Optional["MessageRecv"] = None,
        thinking_start_time: float = 0,
        timestamp: Optional[float] = None,
    ):
        # The bot's user info is stored as the message's user.
        super().__init__(
            message_id=message_id,
            chat_stream=chat_stream,
            user_info=bot_user_info,
            message_segment=message_segment,
            timestamp=timestamp,
            reply=reply,
        )

        # Processing-state attributes.
        self.thinking_time = 0
        self.thinking_start_time = thinking_start_time

    def update_thinking_time(self) -> float:
        """Recompute and return the seconds elapsed since thinking started."""
        elapsed = round(time.time() - self.thinking_start_time, 2)
        self.thinking_time = elapsed
        return elapsed

    async def _process_single_segment(self, seg: Seg) -> str | None:
        """Convert a single segment to text.

        Args:
            seg: Segment to process.

        Returns:
            The processed text, None for a ``reply`` segment without usable
            reply text, or a failure placeholder on error.
        """
        try:
            kind = seg.type
            if kind == "text":
                return seg.data
            if kind == "image":
                # String data is treated as base64 image data.
                if isinstance(seg.data, str):
                    return await image_manager.get_image_description(seg.data)
                return "[图片,网卡了加载不出来]"
            if kind == "emoji":
                if isinstance(seg.data, str):
                    return await image_manager.get_emoji_description(seg.data)
                return "[表情,网卡了加载不出来]"
            if kind == "at":
                return f"[@{seg.data}]"
            if kind == "reply":
                if self.reply and hasattr(self.reply, "processed_plain_text"):
                    return f"[回复:{self.reply.processed_plain_text}]"
                return None
            return f"[{seg.type}:{str(seg.data)}]"
        except Exception as e:
            logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
            return f"[处理失败的{seg.type}消息]"

    def _generate_detailed_text(self) -> str:
        """Build the detailed text line: timestamp, sender tag and plain text."""
        info = self.message_info
        user_info = info.user_info
        timestamp = info.time
        name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
        return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
    """Message in the thinking state."""

    def __init__(
        self,
        message_id: str,
        chat_stream: ChatStream,
        bot_user_info: UserInfo,
        reply: Optional["MessageRecv"] = None,
        thinking_start_time: float = 0,
        timestamp: Optional[float] = None,
    ):
        # A thinking message carries no segment, so None is passed for it.
        super().__init__(
            message_id=message_id,
            chat_stream=chat_stream,
            bot_user_info=bot_user_info,
            message_segment=None,
            reply=reply,
            timestamp=timestamp,
            thinking_start_time=thinking_start_time,
        )

        # Attribute specific to the thinking state.
        self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
    """Message in the sending state."""

    def __init__(
        self,
        message_id: str,
        chat_stream: ChatStream,
        bot_user_info: UserInfo,
        sender_info: UserInfo | None,  # sender info, used for private-chat replies
        message_segment: Seg,
        reply: Optional["MessageRecv"] = None,
        is_head: bool = False,
        is_emoji: bool = False,
        thinking_start_time: float = 0,
        apply_set_reply_logic: bool = False,
    ):
        super().__init__(
            message_id=message_id,
            chat_stream=chat_stream,
            bot_user_info=bot_user_info,
            message_segment=message_segment,
            reply=reply,
            thinking_start_time=thinking_start_time,
        )

        # Attributes specific to the sending state.
        self.sender_info = sender_info
        self.reply_to_message_id = reply.message_info.message_id if reply else None
        self.is_head = is_head
        self.is_emoji = is_emoji
        self.apply_set_reply_logic = apply_set_reply_logic

    def set_reply(self, reply: Optional["MessageRecv"] = None):
        """Set the replied-to message and prepend a reply segment.

        When ``reply`` is given it replaces the stored reply. If a reply is
        then available, ``reply_to_message_id`` is updated and the current
        segment is wrapped in a ``seglist`` whose first element is a
        ``reply`` segment. Each call wraps the segment again.
        """
        # The former ``if True:`` wrapper stood in for a disabled
        # format_info/accept_format check and has been removed.
        if reply:
            self.reply = reply
        if self.reply:
            reply_id = self.reply.message_info.message_id
            self.reply_to_message_id = reply_id
            self.message_segment = Seg(
                type="seglist",
                data=[
                    Seg(type="reply", data=reply_id),
                    self.message_segment,
                ],
            )

    async def process(self) -> None:
        """Fill in the plain and detailed text when a segment is present."""
        if self.message_segment:
            self.processed_plain_text = await self._process_message_segments(self.message_segment)
            self.detailed_plain_text = self._generate_detailed_text()

    @classmethod
    def from_thinking(
        cls,
        thinking: MessageThinking,
        message_segment: Seg,
        is_head: bool = False,
        is_emoji: bool = False,
    ) -> "MessageSending":
        """Create a sending-state message from a thinking-state message.

        The thinking start time is carried over; without it the new message
        would start at 0 and ``update_thinking_time`` would report the time
        since the epoch instead of the real thinking duration.
        """
        return cls(
            message_id=thinking.message_info.message_id,
            chat_stream=thinking.chat_stream,
            message_segment=message_segment,
            bot_user_info=thinking.message_info.user_info,
            reply=thinking.reply,
            is_head=is_head,
            is_emoji=is_emoji,
            sender_info=None,
            thinking_start_time=thinking.thinking_start_time,
        )

    def to_dict(self):
        """Serialize the message, using the chat stream's user as ``user_info``."""
        ret = super().to_dict()
        ret["message_info"]["user_info"] = self.chat_stream.user_info.to_dict()
        return ret

    def is_private_message(self) -> bool:
        """Return True when the message has no group or no group id."""
        group = self.message_info.group_info
        return group is None or group.group_id is None
@dataclass
class MessageSet:
    """Collection of outgoing messages kept sorted by message time."""

    def __init__(self, chat_stream: "ChatStream", message_id: str):
        self.chat_stream = chat_stream
        self.message_id = message_id
        self.messages: list["MessageSending"] = []
        self.time = round(time.time(), 3)  # keep three decimal places

    def add_message(self, message: "MessageSending") -> None:
        """Add a message and keep the list sorted by ``message_info.time``.

        Raises:
            TypeError: If ``message`` is not a MessageSending.
        """
        if not isinstance(message, MessageSending):
            raise TypeError("MessageSet只能添加MessageSending类型的消息")
        self.messages.append(message)
        self.messages.sort(key=lambda x: x.message_info.time)

    def get_message_by_index(self, index: int) -> Optional["MessageSending"]:
        """Return the message at ``index``, or None when out of range."""
        if 0 <= index < len(self.messages):
            return self.messages[index]
        return None

    def get_message_by_time(self, target_time: float) -> Optional["MessageSending"]:
        """Return the message whose time is closest to ``target_time``.

        Relies on ``self.messages`` being sorted by time, which
        ``add_message`` maintains. On an exact tie between the two
        neighbours the later message is returned. Returns None when the set
        is empty.
        """
        if not self.messages:
            return None

        # Binary search for the first message at or after target_time,
        # clamped to the last index.
        left, right = 0, len(self.messages) - 1
        while left < right:
            mid = (left + right) // 2
            if self.messages[mid].message_info.time < target_time:
                left = mid + 1
            else:
                right = mid

        candidate = self.messages[left]
        # The lower bound alone is not necessarily the closest message; the
        # one just before it may be nearer to the target.
        if left > 0:
            previous = self.messages[left - 1]
            gap_before = target_time - previous.message_info.time
            gap_after = candidate.message_info.time - target_time
            if gap_before < gap_after:
                return previous
        return candidate

    def clear_messages(self) -> None:
        """Remove every message."""
        self.messages.clear()

    def remove_message(self, message: "MessageSending") -> bool:
        """Remove ``message``; return True if it was present."""
        if message in self.messages:
            self.messages.remove(message)
            return True
        return False

    def __str__(self) -> str:
        return f"MessageSet(id={self.message_id}, count={len(self.messages)})"

    def __len__(self) -> int:
        return len(self.messages)

View File

@@ -0,0 +1,216 @@
from ..person_info.person_info import person_info_manager
from src.common.logger_manager import get_logger
import asyncio
from dataclasses import dataclass, field
from .message import MessageRecv
from maim_message import BaseMessageInfo, GroupInfo
import hashlib
from typing import Dict
from collections import OrderedDict
import random
import time
from ...config.config import global_config
logger = get_logger("message_buffer")
# One buffered incoming message together with its buffering verdict.
@dataclass
class CacheMessages:
# The received message being buffered.
message: MessageRecv
cache_determination: asyncio.Event = field(default_factory=asyncio.Event) # set once buffering has produced a result
# Verdict flag; starts as "U", and MessageBuffer assigns "T" or "F" later.
result: str = "U"
# Debounce buffer for incoming messages, keyed per (platform, user, group).
# Each buffered message carries a verdict flag: "U" (undecided), "T" (process
# this one) or "F" (superseded by a newer message).
class MessageBuffer:
def __init__(self):
# id from get_person_id_ -> ordered mapping of message_id -> CacheMessages
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
# Guards every read/write of buffer_pool below.
self.lock = asyncio.Lock()
@staticmethod
def get_person_id_(platform: str, user_id: str, group_info: GroupInfo):
"""Build a unique id (MD5 hex digest) for a platform/user/group combination."""
if group_info:
group_id = group_info.group_id
else:
group_id = "私聊"
key = f"{platform}_{user_id}_{group_id}"
return hashlib.md5(key.encode()).hexdigest()
async def start_caching_messages(self, message: MessageRecv):
"""Add a message to the buffer and start its debounce timer."""
# Buffering disabled: only record the message interval and return.
if not global_config.message_buffer:
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
# NOTE(review): the task reference is not kept; asyncio documents that
# unreferenced tasks may be garbage-collected before finishing — confirm.
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
return
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
)
async with self.lock:
if person_id_ not in self.buffer_pool:
self.buffer_pool[person_id_] = OrderedDict()
# Mark this user's earlier undecided messages as superseded
for cache_msg in self.buffer_pool[person_id_].values():
if cache_msg.result == "U":
cache_msg.result = "F"
cache_msg.cache_determination.set()
logger.debug(f"被新消息覆盖信息id: {cache_msg.message.message_info.message_id}")
# Walk backwards to the most recent successfully processed (T) message
recent_f_count = 0
for msg_id in reversed(self.buffer_pool[person_id_]):
msg = self.buffer_pool[person_id_][msg_id]
if msg.result == "T":
break
elif msg.result == "F":
recent_f_count += 1
# Fast path: enough F messages (random threshold 3-5) piled up since the last T
if recent_f_count >= random.randint(3, 5):
new_msg = CacheMessages(message=message, result="T")
new_msg.cache_determination.set()
self.buffer_pool[person_id_][message.message_info.message_id] = new_msg
logger.debug(f"快速处理消息(已堆积{recent_f_count}条F): {message.message_info.message_id}")
return
# Add the new message as undecided
self.buffer_pool[person_id_][message.message_info.message_id] = CacheMessages(message=message)
# Start the debounce timer (the actual delay comes from the person's msg_interval value)
person_id = person_info_manager.get_person_id(
message.message_info.user_info.platform, message.message_info.user_info.user_id
)
asyncio.create_task(self.save_message_interval(person_id, message.message_info))
asyncio.create_task(self._debounce_processor(person_id_, message.message_id, person_id))
async def _debounce_processor(self, person_id_: str, message_id: str, person_id: str):
"""Sleep for the debounce interval, then mark the message T if still undecided."""
interval_time = await person_info_manager.get_value(person_id, "msg_interval")
# Only non-negative integer values (int or digit string) are accepted.
if not isinstance(interval_time, (int, str)) or not str(interval_time).isdigit():
logger.debug("debounce_processor无效的时间")
return
# Value is divided by 1000 (presumably milliseconds — confirm) with a 0.5 s floor.
interval_time = max(0.5, int(interval_time) / 1000)
await asyncio.sleep(interval_time)
async with self.lock:
if person_id_ not in self.buffer_pool or message_id not in self.buffer_pool[person_id_]:
logger.debug(f"消息已被清理msgid: {message_id}")
return
cache_msg = self.buffer_pool[person_id_][message_id]
if cache_msg.result == "U":
cache_msg.result = "T"
cache_msg.cache_determination.set()
async def query_buffer_result(self, message: MessageRecv) -> bool:
"""Wait for the buffering verdict of a message and clean up the pool.

Returns True when buffering is disabled or the verdict is "T"; returns
False when the message is unknown, the verdict is not "T", or waiting
times out after 10 seconds. On a "T" verdict the texts of earlier
buffered messages are merged into ``message.processed_plain_text``.
"""
if not global_config.message_buffer:
return True
person_id_ = self.get_person_id_(
message.message_info.platform, message.message_info.user_info.user_id, message.message_info.group_info
)
async with self.lock:
user_msgs = self.buffer_pool.get(person_id_, {})
cache_msg = user_msgs.get(message.message_info.message_id)
if not cache_msg:
logger.debug(f"查询异常消息不存在msgid: {message.message_info.message_id}")
return False # message does not exist or was already cleaned up
try:
await asyncio.wait_for(cache_msg.cache_determination.wait(), timeout=10)
# Computed before the loop below rebinds the name cache_msg.
result = cache_msg.result == "T"
if result:
async with self.lock: # take the lock again
# Drop all processed messages up to the current one, collecting the
# processed_plain_text of the earlier F messages along the way
keep_msgs = OrderedDict() # messages that come after the T message
collected_texts = [] # texts of the T message and the F messages before it
process_target_found = False
# Walk through every buffered message of this user
for msg_id, cache_msg in self.buffer_pool[person_id_].items():
# Reached the target (T) message
if msg_id == message.message_info.message_id:
process_target_found = True
# Collect the T message's text, if any
if (
hasattr(cache_msg.message, "processed_plain_text")
and cache_msg.message.processed_plain_text
):
collected_texts.append(cache_msg.message.processed_plain_text)
# Not added to keep_msgs: it and the F messages before it are done
# Messages after the target T message are kept
elif process_target_found:
keep_msgs[msg_id] = cache_msg
# Target not reached yet, so this is an earlier message (F or U)
else:
if cache_msg.result == "F":
# Collect this F message's text, if any
if (
hasattr(cache_msg.message, "processed_plain_text")
and cache_msg.message.processed_plain_text
):
collected_texts.append(cache_msg.message.processed_plain_text)
elif cache_msg.result == "U":
# A U message before the T message is unexpected; log it
logger.warning(
f"异常状态:在目标 T 消息 {message.message_info.message_id} 之前发现未处理的 U 消息 {cache_msg.message.message_info.message_id}"
)
# Its text is collected as well
if (
hasattr(cache_msg.message, "processed_plain_text")
and cache_msg.message.processed_plain_text
):
collected_texts.append(cache_msg.message.processed_plain_text)
# Update processed_plain_text of the current message
if collected_texts:
# De-duplicate with OrderedDict while keeping the original order
unique_texts = list(OrderedDict.fromkeys(collected_texts))
merged_text = "".join(unique_texts)
# Only update when the merged text is non-empty and differs
# from the current text
if merged_text and merged_text != message.processed_plain_text:
message.processed_plain_text = merged_text
# After merging text the message no longer counts as a pure emoji
if hasattr(message, "is_emoji"):
message.is_emoji = False
logger.debug(
f"合并了 {len(unique_texts)} 条消息的文本内容到当前消息 {message.message_info.message_id}"
)
# Replace the pool entry, keeping only messages after the T message
self.buffer_pool[person_id_] = keep_msgs
return result
except asyncio.TimeoutError:
logger.debug(f"查询超时消息id {message.message_info.message_id}")
return False
@staticmethod
async def save_message_interval(person_id: str, message: BaseMessageInfo):
# Appends the current time (ms) to the person's msg_interval_list, capped at 1000 entries.
# NOTE(review): assumes get_value returns a list here; a None result would raise on len() — confirm.
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
now_time_ms = int(round(time.time() * 1000))
if len(message_interval_list) < 1000:
message_interval_list.append(now_time_ms)
else:
# List is full: drop the oldest entry before appending.
message_interval_list.pop(0)
message_interval_list.append(now_time_ms)
# Extra person fields passed along with the update ("konw_time" is the key as written).
data = {
"platform": message.platform,
"user_id": message.user_info.user_id,
"nickname": message.user_info.user_nickname,
"konw_time": int(time.time()),
}
await person_info_manager.update_one_field(person_id, "msg_interval_list", message_interval_list, data)
# Module-level MessageBuffer instance.
message_buffer = MessageBuffer()

View File

@@ -0,0 +1,343 @@
# src/plugins/chat/message_sender.py
import asyncio
import time
from asyncio import Task
from typing import Union
from src.common.message.api import global_api
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉
from .message import MessageSending, MessageThinking, MessageSet
from .storage import MessageStorage
from ...config.config import global_config
from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("sender")
async def send_via_ws(message: MessageSending) -> None:
    """Send a message through ``global_api``.

    Raises:
        ValueError: When sending fails for any reason; the original
            exception is chained as the cause.
    """
    try:
        await global_api.send_message(message)
    except Exception as e:
        logger.error(f"WS发送失败: {e}")
        platform = message.message_info.platform
        raise ValueError(f"未找到平台:{platform} 的url配置请检查配置文件") from e
async def send_message(
    message: MessageSending,
) -> None:
    """Send a message after a simulated typing delay.

    Failures while sending are logged and not re-raised.
    """
    # Wait for the computed typing time before sending.
    delay = calculate_typing_time(
        input_string=message.processed_plain_text,
        thinking_start_time=message.thinking_start_time,
        is_emoji=message.is_emoji,
    )
    await asyncio.sleep(delay)

    message_preview = truncate_message(message.processed_plain_text)
    try:
        await send_via_ws(message)
    except Exception as e:
        logger.error(f"发送消息 '{message_preview}' 失败: {str(e)}")
    else:
        logger.success(f"发送消息 '{message_preview}' 成功")
class MessageSender:
    """Sender (no longer a singleton)."""

    def __init__(self):
        self.message_interval = (0.5, 1)  # range of the gap between messages, in seconds
        self.last_send_time = 0
        self._current_bot = None

    def set_bot(self, bot):
        """Remember ``bot`` as the current bot instance."""
        # Previously a no-op, so _current_bot was never assigned.
        self._current_bot = bot
class MessageContainer:
    """Container for the thinking/sending messages of one chat stream."""

    def __init__(self, chat_id: str, max_size: int = 100):
        self.chat_id = chat_id
        self.max_size = max_size
        self.messages: list[MessageThinking | MessageSending] = []
        self.last_send_time = 0
        self.thinking_wait_timeout = 20  # thinking wait timeout, in seconds

    def count_thinking_messages(self) -> int:
        """Return how many MessageThinking objects are in the container."""
        return len([entry for entry in self.messages if isinstance(entry, MessageThinking)])

    def get_timeout_sending_messages(self) -> list[MessageSending]:
        """Return the MessageSending objects that exceeded the thinking timeout.

        Only messages with a truthy ``thinking_start_time`` are considered;
        the result is ordered by ``thinking_start_time``, earliest first.
        """
        now = time.time()
        overdue = [
            entry
            for entry in self.messages
            if isinstance(entry, MessageSending)
            and entry.thinking_start_time
            and now - entry.thinking_start_time > self.thinking_wait_timeout
        ]
        return sorted(overdue, key=lambda x: x.thinking_start_time)

    def get_earliest_message(self):
        """Return the message with the smallest ``thinking_start_time``, or None."""
        oldest = None
        oldest_time = float("inf")
        for entry in self.messages:
            # Messages without the attribute are treated as starting at +inf,
            # so they are never selected.
            started = getattr(entry, "thinking_start_time", float("inf"))
            if started < oldest_time:
                oldest, oldest_time = entry, started
        return oldest

    def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]):
        """Queue a message; a MessageSet contributes each of its messages."""
        if isinstance(message, MessageSet):
            self.messages.extend(message.messages)
            return
        self.messages.append(message)

    def remove_message(self, message_to_remove: Union[MessageThinking, MessageSending]):
        """Remove the given message; return True if it was present, else False.

        Errors raised while searching or removing are logged and reported as
        False.
        """
        try:
            if message_to_remove not in self.messages:
                return False
            self.messages.remove(message_to_remove)
            return True
        except Exception as e:
            logger.exception(f"移除消息时发生错误: {e}")
            return False

    def has_messages(self) -> bool:
        """Return True when at least one message is queued."""
        return len(self.messages) > 0

    def get_all_messages(self) -> list[MessageThinking | MessageSending]:
        """Return a shallow copy of the queued messages."""
        return self.messages.copy()
class MessageManager:
"""管理所有聊天流的消息容器 (不再是单例)"""
def __init__(self):
    """Set up empty per-chat containers and the processor bookkeeping."""
    self.containers: dict[str, MessageContainer] = {}
    self._container_lock = asyncio.Lock()  # guards the containers dict
    self.storage = MessageStorage()
    self._running = True  # processor running flag
    self._processor_task: Task | None = None
    # self.message_sender = MessageSender()  # disabled: a global sender instance is used instead
async def start(self):
"""启动后台处理器任务。"""
# 检查是否已有任务在运行,避免重复启动
if self._processor_task is not None and not self._processor_task.done():
logger.warning("Processor task already running.")
return
self._processor_task = asyncio.create_task(self._start_processor_loop())
logger.debug("MessageManager processor task started.")
def stop(self):
"""停止后台处理器任务。"""
self._running = False
if self._processor_task is not None and not self._processor_task.done():
self._processor_task.cancel()
logger.debug("MessageManager processor task stopping.")
else:
logger.debug("MessageManager processor task not running or already stopped.")
async def get_container(self, chat_id: str) -> MessageContainer:
"""获取或创建聊天流的消息容器 (异步,使用锁)"""
async with self._container_lock:
if chat_id not in self.containers:
self.containers[chat_id] = MessageContainer(chat_id)
return self.containers[chat_id]
async def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
"""添加消息到对应容器"""
chat_stream = message.chat_stream
if not chat_stream:
logger.error("消息缺少 chat_stream无法添加到容器")
return # 或者抛出异常
container = await self.get_container(chat_stream.stream_id)
container.add_message(message)
def check_if_sending_message_exist(self, chat_id, thinking_id):
"""检查指定聊天流的容器中是否存在具有特定 thinking_id 的 MessageSending 消息 或 emoji 消息"""
# 这个方法现在是非异步的,因为它只读取数据
container = self.containers.get(chat_id) # 直接 get因为读取不需要锁
if container and container.has_messages():
for message in container.get_all_messages():
if isinstance(message, MessageSending):
msg_id = getattr(message.message_info, "message_id", None)
# 检查 message_id 是否匹配 thinking_id 或以 "me" 开头 (emoji)
if msg_id == thinking_id or (msg_id and msg_id.startswith("me")):
# logger.debug(f"检查到存在相同thinking_id或emoji的消息: {msg_id} for {thinking_id}")
return True
return False
async def _handle_sending_message(self, container: MessageContainer, message: MessageSending):
"""处理单个 MessageSending 消息 (包含 set_reply 逻辑)"""
try:
_ = message.update_thinking_time() # 更新思考时间
thinking_start_time = message.thinking_start_time
now_time = time.time()
# logger.debug(f"thinking_start_time:{thinking_start_time},now_time:{now_time}")
thinking_messages_count, thinking_messages_length = count_messages_between(
start_time=thinking_start_time, end_time=now_time, stream_id=message.chat_stream.stream_id
)
# print(f"message.reply:{message.reply}")
# --- 条件应用 set_reply 逻辑 ---
# logger.debug(
# f"[message.apply_set_reply_logic:{message.apply_set_reply_logic},message.is_head:{message.is_head},thinking_messages_count:{thinking_messages_count},thinking_messages_length:{thinking_messages_length},message.is_private_message():{message.is_private_message()}]"
# )
if (
message.apply_set_reply_logic # 检查标记
and message.is_head
and (thinking_messages_count > 3 or thinking_messages_length > 200)
and not message.is_private_message()
):
logger.debug(
f"[{message.chat_stream.stream_id}] 应用 set_reply 逻辑: {message.processed_plain_text[:20]}..."
)
message.set_reply(message.reply)
# --- 结束条件 set_reply ---
await message.process() # 预处理消息内容
# logger.debug(f"{message}")
# 使用全局 message_sender 实例
await send_message(message)
await self.storage.store_message(message, message.chat_stream)
# 移除消息要在发送 *之后*
container.remove_message(message)
# logger.debug(f"[{message.chat_stream.stream_id}] Sent and removed message: {message.message_info.message_id}")
except Exception as e:
logger.error(
f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}"
)
logger.exception("详细错误信息:")
# 考虑是否移除出错的消息,防止无限循环
removed = container.remove_message(message)
if removed:
logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。")
async def _process_chat_messages(self, chat_id: str):
"""处理单个聊天流消息 (合并后的逻辑)"""
container = await self.get_container(chat_id) # 获取容器是异步的了
if container.has_messages():
message_earliest = container.get_earliest_message()
if not message_earliest: # 如果最早消息为空,则退出
return
if isinstance(message_earliest, MessageThinking):
# --- 处理思考消息 (来自旧 sender) ---
message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time
# 减少控制台刷新频率或只在时间显著变化时打印
if int(thinking_time) % 5 == 0: # 每5秒打印一次
print(
f"消息 {message_earliest.message_info.message_id} 正在思考中,已思考 {int(thinking_time)}\r",
end="",
flush=True,
)
# 检查是否超时
if thinking_time > global_config.thinking_timeout:
logger.warning(
f"[{chat_id}] 消息思考超时 ({thinking_time:.1f}秒),移除消息 {message_earliest.message_info.message_id}"
)
container.remove_message(message_earliest)
print() # 超时后换行,避免覆盖下一条日志
elif isinstance(message_earliest, MessageSending):
# --- 处理发送消息 ---
await self._handle_sending_message(container, message_earliest)
# --- 处理超时发送消息 (来自旧 sender) ---
# 在处理完最早的消息后,检查是否有超时的发送消息
timeout_sending_messages = container.get_timeout_sending_messages()
if timeout_sending_messages:
logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息")
for msg in timeout_sending_messages:
# 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一)
if msg is message_earliest:
continue
logger.info(f"[{chat_id}] 处理超时发送消息: {msg.message_info.message_id}")
await self._handle_sending_message(container, msg) # 复用处理逻辑
# 清理空容器 (可选)
# async with self._container_lock:
# if not container.has_messages() and chat_id in self.containers:
# logger.debug(f"[{chat_id}] 容器已空,准备移除。")
# del self.containers[chat_id]
async def _start_processor_loop(self):
"""消息处理器主循环"""
while self._running:
tasks = []
# 使用异步锁保护迭代器创建过程
async with self._container_lock:
# 创建 keys 的快照以安全迭代
chat_ids = list(self.containers.keys())
for chat_id in chat_ids:
# 为每个 chat_id 创建一个处理任务
tasks.append(asyncio.create_task(self._process_chat_messages(chat_id)))
if tasks:
try:
# 等待当前批次的所有任务完成
await asyncio.gather(*tasks)
except Exception as e:
logger.error(f"消息处理循环 gather 出错: {e}")
# 等待一小段时间避免CPU空转
try:
await asyncio.sleep(0.1) # 稍微降低轮询频率
except asyncio.CancelledError:
logger.info("Processor loop sleep cancelled.")
break # 退出循环
logger.info("MessageManager processor loop finished.")
# --- Create global instances ---
message_manager = MessageManager()
message_sender = MessageSender()
# --- End of global instances ---

View File

@@ -0,0 +1,72 @@
import re
from typing import Union
from ...common.database import db
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from src.common.logger import get_module_logger
logger = get_module_logger("message_storage")
class MessageStorage:
    """Static helpers that persist chat messages and recall records to the database."""

    # Internal prompt tags (including everything between them) that are
    # stripped from message text before it is stored.
    _INTERNAL_TAG_PATTERN = re.compile(
        r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>",
        flags=re.DOTALL,
    )

    @staticmethod
    def _strip_internal_tags(text) -> str:
        """Return ``text`` without internal tag blocks; empty or missing text becomes ``""``."""
        if not text:
            return ""
        return MessageStorage._INTERNAL_TAG_PATTERN.sub("", text)

    @staticmethod
    async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
        """Store a message in the ``messages`` collection.

        Both text fields are filtered through ``_strip_internal_tags`` first.
        Failures are logged and swallowed; nothing is raised to the caller.
        """
        try:
            message_data = {
                "message_id": message.message_info.message_id,
                "time": message.message_info.time,
                "chat_id": chat_stream.stream_id,
                "chat_info": chat_stream.to_dict(),
                "user_info": message.message_info.user_info.to_dict(),
                # Store the filtered text, never the raw one.
                "processed_plain_text": MessageStorage._strip_internal_tags(message.processed_plain_text),
                "detailed_plain_text": MessageStorage._strip_internal_tags(message.detailed_plain_text),
                "memorized_times": message.memorized_times,
            }
            db.messages.insert_one(message_data)
        except Exception:
            logger.exception("存储消息失败")

    @staticmethod
    async def store_recalled_message(message_id: str, time: str, chat_stream: ChatStream) -> None:
        """Store a recalled-message record in the ``recalled_messages`` collection.

        The collection is created on first use. The record is inserted in
        either case; previously the insert was skipped whenever the collection
        had just been created, losing the first recalled message.
        """
        if "recalled_messages" not in db.list_collection_names():
            db.create_collection("recalled_messages")

        try:
            message_data = {
                "message_id": message_id,
                "time": time,
                "stream_id": chat_stream.stream_id,
            }
            db.recalled_messages.insert_one(message_data)
        except Exception:
            logger.exception("存储撤回消息失败")

    @staticmethod
    async def remove_recalled_message(time: Union[str, float]) -> None:
        """Delete recall records whose ``time`` is more than 300 seconds before ``time``.

        ``time`` is coerced with ``float()`` so that both numeric timestamps
        and numeric strings work (subtracting from a ``str`` raised TypeError).
        Failures are logged and swallowed.
        """
        try:
            # NOTE(review): assumes the stored "time" field is numeric so that
            # "$lt" compares as intended — confirm against the callers of
            # store_recalled_message.
            cutoff = float(time) - 300
            db.recalled_messages.delete_many({"time": {"$lt": cutoff}})
        except Exception:
            logger.exception("删除撤回消息失败")

    # Add further storage-related helpers here if needed.
# 如果需要其他存储相关的函数,可以在这里添加