This commit is contained in:
tcmofashi
2025-04-17 15:53:26 +08:00
52 changed files with 299 additions and 336 deletions

View File

@@ -103,7 +103,8 @@ class ChatManager:
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
@staticmethod
def _ensure_collection():
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in db.list_collection_names():
db.create_collection("chat_streams")
@@ -111,7 +112,8 @@ class ChatManager:
db.chat_streams.create_index([("stream_id", 1)], unique=True)
db.chat_streams.create_index([("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)])
def _generate_stream_id(self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
@staticmethod
def _generate_stream_id(platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None) -> str:
"""生成聊天流唯一ID"""
if group_info:
# 组合关键信息
@@ -188,7 +190,8 @@ class ChatManager:
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
@staticmethod
async def _save_stream(stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
db.chat_streams.update_one({"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True)

View File

@@ -82,7 +82,8 @@ class EmojiManager:
if not self._initialized:
raise RuntimeError("EmojiManager not initialized")
def _ensure_emoji_collection(self):
@staticmethod
def _ensure_emoji_collection():
"""确保emoji集合存在并创建索引
这个函数用于确保MongoDB数据库中存在emoji集合,并创建必要的索引。
@@ -193,7 +194,8 @@ class EmojiManager:
logger.error(f"[错误] 获取表情包失败: {str(e)}")
return None
async def _get_emoji_description(self, image_base64: str) -> str:
@staticmethod
async def _get_emoji_description(image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能"""
try:
@@ -554,7 +556,8 @@ class EmojiManager:
self.check_emoji_file_full()
await asyncio.sleep(global_config.EMOJI_CHECK_INTERVAL * 60)
async def delete_all_images(self):
@staticmethod
async def delete_all_images():
"""删除 data/image 目录下的所有文件"""
try:
image_dir = os.path.join("data", "image")

View File

@@ -31,7 +31,7 @@ class Message(MessageBase):
def __init__(
self,
message_id: str,
time: float,
timestamp: float,
chat_stream: ChatStream,
user_info: UserInfo,
message_segment: Optional[Seg] = None,
@@ -43,7 +43,7 @@ class Message(MessageBase):
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=time,
time=timestamp,
group_info=chat_stream.group_info,
user_info=user_info,
)
@@ -143,7 +143,7 @@ class MessageRecv(Message):
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time = self.message_info.time
timestamp = self.message_info.time
user_info = self.message_info.user_info
# name = (
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -151,7 +151,7 @@ class MessageRecv(Message):
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
# )
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
return f"[{time}] {name}: {self.processed_plain_text}\n"
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass
@@ -170,7 +170,7 @@ class MessageProcessBase(Message):
# 调用父类初始化
super().__init__(
message_id=message_id,
time=round(time.time(), 3), # 保留3位小数
timestamp=round(time.time(), 3), # 保留3位小数
chat_stream=chat_stream,
user_info=bot_user_info,
message_segment=message_segment,
@@ -242,7 +242,7 @@ class MessageProcessBase(Message):
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
# time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
time = self.message_info.time
timestamp = self.message_info.time
user_info = self.message_info.user_info
# name = (
# f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -250,7 +250,7 @@ class MessageProcessBase(Message):
# else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
# )
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
return f"[{time}] {name}: {self.processed_plain_text}\n"
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@dataclass

View File

@@ -26,7 +26,8 @@ class MessageBuffer:
self.buffer_pool: Dict[str, OrderedDict[str, CacheMessages]] = {}
self.lock = asyncio.Lock()
def get_person_id_(self, platform: str, user_id: str, group_info: GroupInfo):
@staticmethod
def get_person_id_(platform: str, user_id: str, group_info: GroupInfo):
"""获取唯一id"""
if group_info:
group_id = group_info.group_id
@@ -150,20 +151,20 @@ class MessageBuffer:
keep_msgs[msg_id] = msg
elif msg.result == "F":
# 收集F消息的文本内容
F_type = "seglist"
f_type = "seglist"
if msg.message.message_segment.type != "seglist":
F_type = msg.message.message_segment.type
f_type = msg.message.message_segment.type
else:
if (
isinstance(msg.message.message_segment.data, list)
and all(isinstance(x, Seg) for x in msg.message.message_segment.data)
and len(msg.message.message_segment.data) == 1
):
F_type = msg.message.message_segment.data[0].type
f_type = msg.message.message_segment.data[0].type
if hasattr(msg.message, "processed_plain_text") and msg.message.processed_plain_text:
if F_type == "text":
if f_type == "text":
combined_text.append(msg.message.processed_plain_text)
elif F_type != "text":
elif f_type != "text":
is_update = False
elif msg.result == "U":
logger.debug(f"异常未处理信息id {msg.message.message_info.message_id}")
@@ -185,7 +186,8 @@ class MessageBuffer:
logger.debug(f"查询超时消息id {message.message_info.message_id}")
return False
async def save_message_interval(self, person_id: str, message: BaseMessageInfo):
@staticmethod
async def save_message_interval(person_id: str, message: BaseMessageInfo):
message_interval_list = await person_info_manager.get_value(person_id, "msg_interval_list")
now_time_ms = int(round(time.time() * 1000))
if len(message_interval_list) < 1000:

View File

@@ -35,7 +35,8 @@ class MessageSender:
"""设置当前bot实例"""
pass
def get_recalled_messages(self, stream_id: str) -> list:
@staticmethod
def get_recalled_messages(stream_id: str) -> list:
"""获取所有撤回的消息"""
recalled_messages = []
@@ -43,7 +44,8 @@ class MessageSender:
# 按thinking_start_time排序时间早的在前面
return recalled_messages
async def send_via_ws(self, message: MessageSending) -> None:
@staticmethod
async def send_via_ws(message: MessageSending) -> None:
try:
await global_api.send_message(message)
except Exception as e:

View File

@@ -135,7 +135,7 @@ async def get_recent_group_messages(chat_id: str, limit: int = 12) -> list:
msg = Message(
message_id=msg_data["message_id"],
chat_stream=chat_stream,
time=msg_data["time"],
timestamp=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""),
detailed_plain_text=msg_data.get("detailed_plain_text", ""),

View File

@@ -38,7 +38,8 @@ class ImageManager:
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
def _ensure_image_collection(self):
@staticmethod
def _ensure_image_collection():
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
@@ -50,7 +51,8 @@ class ImageManager:
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
def _ensure_description_collection(self):
@staticmethod
def _ensure_description_collection():
"""确保image_descriptions集合存在并创建索引"""
if "image_descriptions" not in db.list_collection_names():
db.create_collection("image_descriptions")
@@ -60,7 +62,8 @@ class ImageManager:
# 创建新的复合索引
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
@staticmethod
def _get_description_from_db(image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
Args:
@@ -73,7 +76,8 @@ class ImageManager:
result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
return result["description"] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
@staticmethod
def _save_description_to_db(image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
@@ -226,7 +230,8 @@ class ImageManager:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
def transform_gif(self, gif_base64: str) -> str:
@staticmethod
def transform_gif(gif_base64: str) -> str:
"""将GIF转换为水平拼接的静态图像
Args: