feat: 史上最好的消息流重构和图片管理

This commit is contained in:
tcmofashi
2025-03-11 04:42:24 +08:00
parent 7899e67cb2
commit 8cbf9bb048
13 changed files with 272 additions and 235 deletions

View File

@@ -18,6 +18,7 @@ from .config import global_config
from .emoji_manager import emoji_manager from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .willing_manager import willing_manager from .willing_manager import willing_manager
from .chat_stream import chat_manager
# 创建LLM统计实例 # 创建LLM统计实例
llm_stats = LLMStatistics("llm_statistics.txt") llm_stats = LLMStatistics("llm_statistics.txt")
@@ -101,6 +102,8 @@ async def _(bot: Bot):
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
asyncio.create_task(chat_manager._initialize())
asyncio.create_task(chat_manager._auto_save_task())
@group_msg.handle() @group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State): async def _(bot: Bot, event: GroupMessageEvent, state: T_State):

View File

@@ -18,7 +18,7 @@ from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器 from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt from .utils import calculate_typing_time, is_mentioned_bot_in_message
from .utils_image import image_path_to_base64 from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器 from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg from .message_base import UserInfo, GroupInfo, Seg
@@ -45,8 +45,8 @@ class ChatBot:
self.bot = bot # 更新 bot 实例 self.bot = bot # 更新 bot 实例
group_info = await bot.get_group_info(group_id=event.group_id) # group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) # sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
# 白名单设定由nontbot侧完成 # 白名单设定由nontbot侧完成
if event.group_id: if event.group_id:
@@ -54,19 +54,32 @@ class ChatBot:
return return
if event.user_id in global_config.ban_user_id: if event.user_id in global_config.ban_user_id:
return return
user_info=UserInfo(
user_id=event.user_id,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform='qq'
)
group_info=GroupInfo(
group_id=event.group_id,
group_name=None,
platform='qq'
)
message_cq=MessageRecvCQ( message_cq=MessageRecvCQ(
message_id=event.message_id, message_id=event.message_id,
user_id=event.user_id, user_info=user_info,
raw_message=str(event.original_message), raw_message=str(event.original_message),
group_id=event.group_id, group_info=group_info,
reply_message=event.reply, reply_message=event.reply,
platform='qq' platform='qq'
) )
message_json=message_cq.to_dict() message_json=message_cq.to_dict()
# 进入maimbot # 进入maimbot
message=MessageRecv(**message_json) message=MessageRecv(message_json)
groupinfo=message.message_info.group_info groupinfo=message.message_info.group_info
userinfo=message.message_info.user_info userinfo=message.message_info.user_info
@@ -75,6 +88,7 @@ class ChatBot:
# 消息过滤涉及到config有待更新 # 消息过滤涉及到config有待更新
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo) chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
message.update_chat_stream(chat)
await relationship_manager.update_relationship(chat_stream=chat,) await relationship_manager.update_relationship(chat_stream=chat,)
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5) await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
@@ -99,7 +113,7 @@ class ChatBot:
await self.storage.store_message(message,chat, topic[0] if topic else None) await self.storage.store_message(message,chat, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text) is_mentioned = is_mentioned_bot_in_message(message)
reply_probability = await willing_manager.change_reply_willing_received( reply_probability = await willing_manager.change_reply_willing_received(
chat_stream=chat, chat_stream=chat,
topic=topic[0] if topic else None, topic=topic[0] if topic else None,

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import hashlib import hashlib
import time import time
import copy
from typing import Dict, Optional from typing import Dict, Optional
from loguru import logger from loguru import logger
@@ -86,9 +87,9 @@ class ChatManager:
self._ensure_collection() self._ensure_collection()
self._initialized = True self._initialized = True
# 在事件循环中启动初始化 # 在事件循环中启动初始化
asyncio.create_task(self._initialize()) # asyncio.create_task(self._initialize())
# 启动自动保存任务 # # 启动自动保存任务
asyncio.create_task(self._auto_save_task()) # asyncio.create_task(self._auto_save_task())
async def _initialize(self): async def _initialize(self):
"""异步初始化""" """异步初始化"""
@@ -122,12 +123,18 @@ class ChatManager:
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str: ) -> str:
"""生成聊天流唯一ID""" """生成聊天流唯一ID"""
# 组合关键信息 if group_info:
components = [ # 组合关键信息
platform, components = [
str(user_info.user_id), platform,
str(group_info.group_id) if group_info else "private", str(group_info.group_id)
] ]
else:
components = [
platform,
str(user_info.user_id),
"private"
]
# 使用MD5生成唯一ID # 使用MD5生成唯一ID
key = "_".join(components) key = "_".join(components)
@@ -153,10 +160,11 @@ class ChatManager:
if stream_id in self.streams: if stream_id in self.streams:
stream = self.streams[stream_id] stream = self.streams[stream_id]
# 更新用户信息和群组信息 # 更新用户信息和群组信息
stream.update_active_time()
stream=copy.deepcopy(stream)
stream.user_info = user_info stream.user_info = user_info
if group_info: if group_info:
stream.group_info = group_info stream.group_info = group_info
stream.update_active_time()
return stream return stream
# 检查数据库中是否存在 # 检查数据库中是否存在
@@ -180,7 +188,7 @@ class ChatManager:
# 保存到内存和数据库 # 保存到内存和数据库
self.streams[stream_id] = stream self.streams[stream_id] = stream
await self._save_stream(stream) await self._save_stream(stream)
return stream return copy.deepcopy(stream)
def get_stream(self, stream_id: str) -> Optional[ChatStream]: def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流""" """通过stream_id获取聊天流"""

View File

@@ -59,6 +59,7 @@ class CQCode:
params: Dict[str, str] params: Dict[str, str]
group_id: int group_id: int
user_id: int user_id: int
user_nickname: str
group_name: str = "" group_name: str = ""
user_nickname: str = "" user_nickname: str = ""
translated_segments: Optional[Union[Seg, List[Seg]]] = None translated_segments: Optional[Union[Seg, List[Seg]]] = None
@@ -68,9 +69,7 @@ class CQCode:
def __post_init__(self): def __post_init__(self):
"""初始化LLM实例""" """初始化LLM实例"""
self._llm = LLM_request( pass
model=global_config.vlm, temperature=0.4, max_tokens=300
)
def translate(self): def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象""" """根据CQ码类型进行相应的翻译处理转换为Seg对象"""
@@ -225,8 +224,7 @@ class CQCode:
group_id=msg.get("group_id", 0), group_id=msg.get("group_id", 0),
) )
content_seg = Seg( content_seg = Seg(
type="seglist", data=message_obj.message_segments type="seglist", data=message_obj.message_segment )
)
else: else:
content_seg = Seg(type="text", data="[空消息]") content_seg = Seg(type="text", data="[空消息]")
else: else:
@@ -241,7 +239,7 @@ class CQCode:
group_id=msg.get("group_id", 0), group_id=msg.get("group_id", 0),
) )
content_seg = Seg( content_seg = Seg(
type="seglist", data=message_obj.message_segments type="seglist", data=message_obj.message_segment
) )
else: else:
content_seg = Seg(type="text", data="[空消息]") content_seg = Seg(type="text", data="[空消息]")
@@ -272,7 +270,7 @@ class CQCode:
) )
segments = [] segments = []
if message_obj.user_id == global_config.BOT_QQ: if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append( segments.append(
Seg( Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: " type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
@@ -286,7 +284,7 @@ class CQCode:
) )
) )
segments.append(Seg(type="seglist", data=message_obj.message_segments)) segments.append(Seg(type="seglist", data=message_obj.message_segment))
segments.append(Seg(type="text", data="]")) segments.append(Seg(type="text", data="]"))
return segments return segments
else: else:
@@ -305,12 +303,13 @@ class CQCode:
class CQCode_tool: class CQCode_tool:
@staticmethod @staticmethod
def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
""" """
将CQ码字典转换为CQCode对象 将CQ码字典转换为CQCode对象
Args: Args:
cq_code: CQ码字典 cq_code: CQ码字典
msg: MessageCQ对象
reply: 回复消息的字典(可选) reply: 回复消息的字典(可选)
Returns: Returns:
@@ -326,7 +325,13 @@ class CQCode_tool:
params = cq_code.get("data", {}) params = cq_code.get("data", {})
instance = CQCode( instance = CQCode(
type=cq_type, params=params, group_id=0, user_id=0, reply_message=reply type=cq_type,
params=params,
group_id=msg.message_info.group_info.group_id,
user_id=msg.message_info.user_info.user_id,
user_nickname=msg.message_info.user_info.user_nickname,
group_name=msg.message_info.group_info.group_name,
reply_message=reply
) )
# 进行翻译处理 # 进行翻译处理
@@ -383,6 +388,25 @@ class CQCode_tool:
) )
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]" return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
cq_code_tool = CQCode_tool() cq_code_tool = CQCode_tool()

View File

@@ -239,7 +239,7 @@ class EmojiManager:
# 即使表情包已存在也检查是否需要同步到images集合 # 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription') description = existing_emoji.get('discription')
# 检查是否在images集合中存在 # 检查是否在images集合中存在
existing_image = await image_manager.db.db.images.find_one({'hash': image_hash}) existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image: if not existing_image:
# 同步到images集合 # 同步到images集合
image_doc = { image_doc = {
@@ -249,7 +249,7 @@ class EmojiManager:
'description': description, 'description': description,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
await image_manager.db.db.images.update_one( image_manager.db.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@@ -260,7 +260,7 @@ class EmojiManager:
continue continue
# 检查是否在images集合中已有描述 # 检查是否在images集合中已有描述
existing_description = await image_manager._get_description_from_db(image_hash, 'emoji') existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description: if existing_description:
description = existing_description description = existing_description
@@ -302,13 +302,13 @@ class EmojiManager:
'description': description, 'description': description,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
await image_manager.db.db.images.update_one( image_manager.db.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
) )
# 保存描述到image_descriptions集合 # 保存描述到image_descriptions集合
await image_manager._save_description_to_db(image_hash, description, 'emoji') image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}") logger.success(f"同步保存到images集合: {filename}")
else: else:
logger.warning(f"跳过表情包: {filename}") logger.warning(f"跳过表情包: {filename}")

View File

@@ -7,7 +7,7 @@ from nonebot import get_driver
from ...common.database import Database from ...common.database import Database
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .message import MessageRecv, MessageThinking, MessageSending from .message import MessageRecv, MessageThinking, MessageSending,Message
from .prompt_builder import prompt_builder from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .utils import process_llm_response from .utils import process_llm_response
@@ -144,7 +144,7 @@ class ResponseGenerator:
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str): # content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db( def _save_to_db(
self, self,
message: Message, message: MessageRecv,
sender_name: str, sender_name: str,
prompt: str, prompt: str,
prompt_check: str, prompt_check: str,
@@ -155,7 +155,7 @@ class ResponseGenerator:
self.db.db.reasoning_logs.insert_one( self.db.db.reasoning_logs.insert_one(
{ {
"time": time.time(), "time": time.time(),
"group_id": message.group_id, "chat_id": message.chat_stream.stream_id,
"user": sender_name, "user": sender_name,
"message": message.processed_plain_text, "message": message.processed_plain_text,
"model": self.current_model_type, "model": self.current_model_type,

View File

@@ -7,7 +7,7 @@ from loguru import logger
from .utils_image import image_manager from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream from .chat_stream import ChatStream, chat_manager
# 禁用SSL警告 # 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -25,8 +25,8 @@ class MessageRecv(MessageBase):
Args: Args:
message_dict: MessageCQ序列化后的字典 message_dict: MessageCQ序列化后的字典
""" """
message_info = BaseMessageInfo(**message_dict.get('message_info', {})) message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
message_segment = Seg(**message_dict.get('message_segment', {})) message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
raw_message = message_dict.get('raw_message') raw_message = message_dict.get('raw_message')
super().__init__( super().__init__(
@@ -39,7 +39,9 @@ class MessageRecv(MessageBase):
self.processed_plain_text = "" # 初始化为空字符串 self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串 self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False self.is_emoji=False
def update_chat_stream(self,chat_stream:ChatStream):
self.chat_stream=chat_stream
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本 """处理消息内容,生成纯文本和详细文本
@@ -83,12 +85,12 @@ class MessageRecv(MessageBase):
return seg.data return seg.data
elif seg.type == 'image': elif seg.type == 'image':
# 如果是base64图片数据 # 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data) return await image_manager.get_image_description(seg.data)
return '[图片]' return '[图片]'
elif seg.type == 'emoji': elif seg.type == 'emoji':
self.is_emoji=True self.is_emoji=True
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data) return await image_manager.get_emoji_description(seg.data)
return '[表情]' return '[表情]'
else: else:
@@ -217,11 +219,11 @@ class MessageProcessBase(Message):
return seg.data return seg.data
elif seg.type == 'image': elif seg.type == 'image':
# 如果是base64图片数据 # 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data) return await image_manager.get_image_description(seg.data)
return '[图片]' return '[图片]'
elif seg.type == 'emoji': elif seg.type == 'emoji':
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data) return await image_manager.get_emoji_description(seg.data)
return '[表情]' return '[表情]'
elif seg.type == 'at': elif seg.type == 'at':
@@ -296,10 +298,15 @@ class MessageSending(MessageProcessBase):
self.reply_to_message_id = reply.message_info.message_id if reply else None self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head self.is_head = is_head
self.is_emoji = is_emoji self.is_emoji = is_emoji
if is_head:
def set_reply(self, reply: Optional['MessageRecv']) -> None:
"""设置回复消息"""
if reply:
self.reply = reply
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(type='seglist', data=[ self.message_segment = Seg(type='seglist', data=[
Seg(type='reply', data=reply.message_info.message_id), Seg(type='reply', data=reply.message_info.message_id),
self.message_segment self.message_segment
]) ])
async def process(self) -> None: async def process(self) -> None:
@@ -329,7 +336,7 @@ class MessageSending(MessageProcessBase):
def to_dict(self): def to_dict(self):
ret= super().to_dict() ret= super().to_dict()
ret['mesage_info']['user_info']=self.chat_stream.user_info.to_dict() ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
return ret return ret
@dataclass @dataclass

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any, Dict from typing import List, Optional, Union, Any, Dict
@dataclass @dataclass
class Seg(dict): class Seg:
"""消息片段类,用于表示消息的不同部分 """消息片段类,用于表示消息的不同部分
Attributes: Attributes:
@@ -15,47 +15,25 @@ class Seg(dict):
""" """
type: str type: str
data: Union[str, List['Seg']] data: Union[str, List['Seg']]
translated_data: Optional[str] = None
def __init__(self, type: str, data: Union[str, List['Seg']], translated_data: Optional[str] = None): # def __init__(self, type: str, data: Union[str, List['Seg']],):
"""初始化实例,确保字典和属性同步""" # """初始化实例,确保字典和属性同步"""
# 先初始化字典 # # 先初始化字典
super().__init__(type=type, data=data) # self.type = type
if translated_data is not None: # self.data = data
self['translated_data'] = translated_data
# 再初始化属性
object.__setattr__(self, 'type', type)
object.__setattr__(self, 'data', data)
object.__setattr__(self, 'translated_data', translated_data)
# 验证数据类型 @classmethod
self._validate_data() def from_dict(cls, data: Dict) -> 'Seg':
"""从字典创建Seg实例"""
def _validate_data(self) -> None: type=data.get('type')
"""验证数据类型的正确性""" data=data.get('data')
if self.type == 'seglist' and not isinstance(self.data, list): if type == 'seglist':
raise ValueError("seglist类型的data必须是列表") data = [Seg.from_dict(seg) for seg in data]
elif self.type == 'text' and not isinstance(self.data, str): return cls(
raise ValueError("text类型的data必须是字符串") type=type,
elif self.type == 'image' and not isinstance(self.data, str): data=data
raise ValueError("image类型的data必须是字符串") )
def __setattr__(self, name: str, value: Any) -> None:
"""重写属性设置,同时更新字典值"""
# 更新属性
object.__setattr__(self, name, value)
# 同步更新字典
if name in ['type', 'data', 'translated_data']:
self[name] = value
def __setitem__(self, key: str, value: Any) -> None:
"""重写字典值设置,同时更新属性"""
# 更新字典
super().__setitem__(key, value)
# 同步更新属性
if key in ['type', 'data', 'translated_data']:
object.__setattr__(self, key, value)
def to_dict(self) -> Dict: def to_dict(self) -> Dict:
"""转换为字典格式""" """转换为字典格式"""
@@ -64,8 +42,6 @@ class Seg(dict):
result['data'] = [seg.to_dict() for seg in self.data] result['data'] = [seg.to_dict() for seg in self.data]
else: else:
result['data'] = self.data result['data'] = self.data
if self.translated_data is not None:
result['translated_data'] = self.translated_data
return result return result
@dataclass @dataclass
@@ -79,6 +55,7 @@ class GroupInfo:
"""转换为字典格式""" """转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None} return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'GroupInfo': def from_dict(cls, data: Dict) -> 'GroupInfo':
"""从字典创建GroupInfo实例 """从字典创建GroupInfo实例
@@ -106,6 +83,7 @@ class UserInfo:
"""转换为字典格式""" """转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None} return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'UserInfo': def from_dict(cls, data: Dict) -> 'UserInfo':
"""从字典创建UserInfo实例 """从字典创建UserInfo实例
@@ -126,7 +104,7 @@ class UserInfo:
class BaseMessageInfo: class BaseMessageInfo:
"""消息信息类""" """消息信息类"""
platform: Optional[str] = None platform: Optional[str] = None
message_id: Optional[int,str] = None message_id: Union[str,int,None] = None
time: Optional[int] = None time: Optional[int] = None
group_info: Optional[GroupInfo] = None group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None user_info: Optional[UserInfo] = None
@@ -141,6 +119,25 @@ class BaseMessageInfo:
else: else:
result[field] = value result[field] = value
return result return result
@classmethod
def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo(**data.get('group_info', {}))
user_info = UserInfo(**data.get('user_info', {}))
return cls(
platform=data.get('platform'),
message_id=data.get('message_id'),
time=data.get('time'),
group_info=group_info,
user_info=user_info
)
@dataclass @dataclass
class MessageBase: class MessageBase:

View File

@@ -27,27 +27,10 @@ class MessageCQ(MessageBase):
def __init__( def __init__(
self, self,
message_id: int, message_id: int,
user_id: int, user_info: UserInfo,
group_id: Optional[int] = None, group_info: Optional[GroupInfo] = None,
platform: str = "qq" platform: str = "qq"
): ):
# 构造用户信息
user_info = UserInfo(
platform=platform,
user_id=user_id,
user_nickname=get_user_nickname(user_id),
user_cardname=get_user_cardname(user_id) if group_id else None
)
# 构造群组信息(如果有)
group_info = None
if group_id:
group_info = GroupInfo(
platform=platform,
group_id=group_id,
group_name=get_groupname(group_id)
)
# 构造基础消息信息 # 构造基础消息信息
message_info = BaseMessageInfo( message_info = BaseMessageInfo(
platform=platform, platform=platform,
@@ -56,7 +39,6 @@ class MessageCQ(MessageBase):
group_info=group_info, group_info=group_info,
user_info=user_info user_info=user_info
) )
# 调用父类初始化message_segment 由子类设置 # 调用父类初始化message_segment 由子类设置
super().__init__( super().__init__(
message_info=message_info, message_info=message_info,
@@ -71,14 +53,17 @@ class MessageRecvCQ(MessageCQ):
def __init__( def __init__(
self, self,
message_id: int, message_id: int,
user_id: int, user_info: UserInfo,
raw_message: str, raw_message: str,
group_id: Optional[int] = None, group_info: Optional[GroupInfo] = None,
platform: str = "qq",
reply_message: Optional[Dict] = None, reply_message: Optional[Dict] = None,
platform: str = "qq"
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__(message_id, user_id, group_id, platform) super().__init__(message_id, user_info, group_info, platform)
if group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段 # 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message) self.message_segment = self._parse_message(raw_message, reply_message)
@@ -117,7 +102,7 @@ class MessageRecvCQ(MessageCQ):
# 转换CQ码为Seg对象 # 转换CQ码为Seg对象
for code_item in cq_code_dict_list: for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, reply=reply_message) message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
if message_obj.translated_segments: if message_obj.translated_segments:
segments.append(message_obj.translated_segments) segments.append(message_obj.translated_segments)
@@ -142,13 +127,14 @@ class MessageSendCQ(MessageCQ):
data: Dict data: Dict
): ):
# 调用父类初始化 # 调用父类初始化
message_info = BaseMessageInfo(**data.get('message_info', {})) message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {})) message_segment = Seg.from_dict(data.get('message_segment', {}))
super().__init__( super().__init__(
message_info.message_id, message_info.message_id,
message_info.user_info.user_id, message_info.user_info,
message_info.group_info.group_id if message_info.group_info else None, message_info.group_info if message_info.group_info else None,
message_info.platform) message_info.platform
)
self.message_segment = message_segment self.message_segment = message_segment
self.raw_message = self._generate_raw_message() self.raw_message = self._generate_raw_message()
@@ -171,11 +157,9 @@ class MessageSendCQ(MessageCQ):
if seg.type == 'text': if seg.type == 'text':
return str(seg.data) return str(seg.data)
elif seg.type == 'image': elif seg.type == 'image':
# 如果是base64图片数据 return cq_code_tool.create_image_cq_base64(seg.data)
if seg.data.startswith(('data:', 'base64:')): elif seg.type == 'emoji':
return cq_code_tool.create_emoji_cq_base64(seg.data) return cq_code_tool.create_emoji_cq_base64(seg.data)
# 如果是表情包(本地文件)
return cq_code_tool.create_emoji_cq(seg.data)
elif seg.type == 'at': elif seg.type == 'at':
return f"[CQ:at,qq={seg.data}]" return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply': elif seg.type == 'reply':

View File

@@ -41,10 +41,10 @@ class Message_Sender:
message=message_send.raw_message, message=message_send.raw_message,
auto_escape=False auto_escape=False
) )
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功") print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功")
except Exception as e: except Exception as e:
print(f"发生错误 {e}") print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败") print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败")
else: else:
try: try:
await self._current_bot.send_private_msg( await self._current_bot.send_private_msg(
@@ -52,10 +52,10 @@ class Message_Sender:
message=message_send.raw_message, message=message_send.raw_message,
auto_escape=False auto_escape=False
) )
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功") print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功")
except Exception as e: except Exception as e:
print(f"发生错误 {e}") print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败") print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败")
class MessageContainer: class MessageContainer:
@@ -137,11 +137,7 @@ class MessageManager:
return self.containers[chat_id] return self.containers[chat_id]
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None: def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
chat_stream = chat_manager.get_stream_by_info( chat_stream = message.chat_stream
platform=message.message_info.platform,
user_info=message.message_info.user_info,
group_info=message.message_info.group_info
)
if not chat_stream: if not chat_stream:
raise ValueError("无法找到对应的聊天流") raise ValueError("无法找到对应的聊天流")
container = self.get_container(chat_stream.stream_id) container = self.get_container(chat_stream.stream_id)
@@ -165,13 +161,14 @@ class MessageManager:
else: else:
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中") print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
if message_earliest.is_head and message_earliest.update_thinking_time() > 30: if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
await message_sender.send_message(message_earliest) await message_sender.send_message(message_earliest.set_reply())
else: else:
await message_sender.send_message(message_earliest) await message_sender.send_message(message_earliest)
if message_earliest.is_emoji: # if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]" # message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None) await message_earliest.process()
await self.storage.store_message(message_earliest, message_earliest.chat_stream,None)
container.remove_message(message_earliest) container.remove_message(message_earliest)
@@ -184,13 +181,14 @@ class MessageManager:
try: try:
if msg.is_head and msg.update_thinking_time() > 30: if msg.is_head and msg.update_thinking_time() > 30:
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id) await message_sender.send_message(msg.set_reply())
else: else:
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False) await message_sender.send_message(msg)
if msg.is_emoji: # if msg.is_emoji:
msg.processed_plain_text = "[表情包]" # msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None) await msg.process()
await self.storage.store_message(msg,msg.chat_stream, None)
if not container.remove_message(msg): if not container.remove_message(msg):
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息") print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")

View File

@@ -23,12 +23,12 @@ class Relationship:
saved = False saved = False
def __init__(self, chat:ChatStream=None,data:dict=None): def __init__(self, chat:ChatStream=None,data:dict=None):
self.user_id=chat.user_info.user_id if chat.user_info else data.get('user_id',0) self.user_id=chat.user_info.user_id if chat else data.get('user_id',0)
self.platform=chat.platform if chat.user_info else data.get('platform','') self.platform=chat.platform if chat else data.get('platform','')
self.nickname=chat.user_info.user_nickname if chat.user_info else data.get('nickname','') self.nickname=chat.user_info.user_nickname if chat else data.get('nickname','')
self.relationship_value=data.get('relationship_value',0) self.relationship_value=data.get('relationship_value',0) if data else 0
self.age=data.get('age',0) self.age=data.get('age',0) if data else 0
self.gender=data.get('gender','') self.gender=data.get('gender','') if data else ''
class RelationshipManager: class RelationshipManager:

View File

@@ -59,7 +59,7 @@ class ImageManager:
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True) self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)]) self.db.db.image_descriptions.create_index([('type', 1)])
async def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
Args: Args:
@@ -69,13 +69,13 @@ class ImageManager:
Returns: Returns:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
result = await self.db.db.image_descriptions.find_one({ result= self.db.db.image_descriptions.find_one({
'hash': image_hash, 'hash': image_hash,
'type': description_type 'type': description_type
}) })
return result['description'] if result else None return result['description'] if result else None
async def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None: def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库 """保存图片描述到数据库
Args: Args:
@@ -83,7 +83,7 @@ class ImageManager:
description: 描述文本 description: 描述文本
description_type: 描述类型 ('emoji''image') description_type: 描述类型 ('emoji''image')
""" """
await self.db.db.image_descriptions.update_one( self.db.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type}, {'hash': image_hash, 'type': description_type},
{ {
'$set': { '$set': {
@@ -253,8 +253,9 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述 # 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'emoji') cached_description = self._get_description_from_db(image_hash, 'emoji')
if cached_description: if cached_description:
logger.info(f"缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
@@ -281,7 +282,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
await self.db.db.images.update_one( self.db.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@@ -291,7 +292,7 @@ class ImageManager:
logger.error(f"保存表情包文件失败: {str(e)}") logger.error(f"保存表情包文件失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库
await self._save_description_to_db(image_hash, description, 'emoji') self._save_description_to_db(image_hash, description, 'emoji')
return f"[表情包:{description}]" return f"[表情包:{description}]"
except Exception as e: except Exception as e:
@@ -306,7 +307,7 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述 # 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'image') cached_description = self._get_description_from_db(image_hash, 'image')
if cached_description: if cached_description:
return f"[图片:{cached_description}]" return f"[图片:{cached_description}]"
@@ -334,7 +335,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
await self.db.db.images.update_one( self.db.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@@ -357,80 +358,6 @@ class ImageManager:
image_manager = ImageManager() image_manager = ImageManager()
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
def image_path_to_base64(image_path: str) -> str: def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码 """将图片路径转换为base64编码
Args: Args:

View File

@@ -7,10 +7,11 @@ from typing import Tuple, Union
import aiohttp import aiohttp
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
import base64
from PIL import Image
import io
from ...common.database import Database from ...common.database import Database
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils_image import compress_base64_image_by_scale
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -405,3 +406,77 @@ class LLM_request:
) )
return embedding return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data