feat: 史上最好的消息流重构和图片管理

This commit is contained in:
tcmofashi
2025-03-11 04:42:24 +08:00
parent 7899e67cb2
commit 8cbf9bb048
13 changed files with 272 additions and 235 deletions

View File

@@ -18,6 +18,7 @@ from .config import global_config
from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager
from .willing_manager import willing_manager
from .chat_stream import chat_manager
# 创建LLM统计实例
llm_stats = LLMStatistics("llm_statistics.txt")
@@ -101,6 +102,8 @@ async def _(bot: Bot):
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
asyncio.create_task(chat_manager._initialize())
asyncio.create_task(chat_manager._auto_save_task())
@group_msg.handle()
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):

View File

@@ -18,7 +18,7 @@ from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager
from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
from .utils import calculate_typing_time, is_mentioned_bot_in_message
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
@@ -45,8 +45,8 @@ class ChatBot:
self.bot = bot # 更新 bot 实例
group_info = await bot.get_group_info(group_id=event.group_id)
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
# group_info = await bot.get_group_info(group_id=event.group_id)
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
# 白名单设定由nontbot侧完成
if event.group_id:
@@ -55,18 +55,31 @@ class ChatBot:
if event.user_id in global_config.ban_user_id:
return
user_info=UserInfo(
user_id=event.user_id,
user_nickname=event.sender.nickname,
user_cardname=event.sender.card or None,
platform='qq'
)
group_info=GroupInfo(
group_id=event.group_id,
group_name=None,
platform='qq'
)
message_cq=MessageRecvCQ(
message_id=event.message_id,
user_id=event.user_id,
user_info=user_info,
raw_message=str(event.original_message),
group_id=event.group_id,
group_info=group_info,
reply_message=event.reply,
platform='qq'
)
message_json=message_cq.to_dict()
# 进入maimbot
message=MessageRecv(**message_json)
message=MessageRecv(message_json)
groupinfo=message.message_info.group_info
userinfo=message.message_info.user_info
@@ -75,6 +88,7 @@ class ChatBot:
# 消息过滤涉及到config有待更新
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
message.update_chat_stream(chat)
await relationship_manager.update_relationship(chat_stream=chat,)
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
@@ -99,7 +113,7 @@ class ChatBot:
await self.storage.store_message(message,chat, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
is_mentioned = is_mentioned_bot_in_message(message)
reply_probability = await willing_manager.change_reply_willing_received(
chat_stream=chat,
topic=topic[0] if topic else None,

View File

@@ -1,6 +1,7 @@
import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from loguru import logger
@@ -86,9 +87,9 @@ class ChatManager:
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
asyncio.create_task(self._initialize())
# 启动自动保存任务
asyncio.create_task(self._auto_save_task())
# asyncio.create_task(self._initialize())
# # 启动自动保存任务
# asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
@@ -122,12 +123,18 @@ class ChatManager:
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID"""
# 组合关键信息
components = [
platform,
str(user_info.user_id),
str(group_info.group_id) if group_info else "private",
]
if group_info:
# 组合关键信息
components = [
platform,
str(group_info.group_id)
]
else:
components = [
platform,
str(user_info.user_id),
"private"
]
# 使用MD5生成唯一ID
key = "_".join(components)
@@ -153,10 +160,11 @@ class ChatManager:
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream=copy.deepcopy(stream)
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
return stream
# 检查数据库中是否存在
@@ -180,7 +188,7 @@ class ChatManager:
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return stream
return copy.deepcopy(stream)
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""

View File

@@ -59,6 +59,7 @@ class CQCode:
params: Dict[str, str]
group_id: int
user_id: int
user_nickname: str
group_name: str = ""
user_nickname: str = ""
translated_segments: Optional[Union[Seg, List[Seg]]] = None
@@ -68,9 +69,7 @@ class CQCode:
def __post_init__(self):
"""初始化LLM实例"""
self._llm = LLM_request(
model=global_config.vlm, temperature=0.4, max_tokens=300
)
pass
def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象"""
@@ -225,8 +224,7 @@ class CQCode:
group_id=msg.get("group_id", 0),
)
content_seg = Seg(
type="seglist", data=message_obj.message_segments
)
type="seglist", data=message_obj.message_segment )
else:
content_seg = Seg(type="text", data="[空消息]")
else:
@@ -241,7 +239,7 @@ class CQCode:
group_id=msg.get("group_id", 0),
)
content_seg = Seg(
type="seglist", data=message_obj.message_segments
type="seglist", data=message_obj.message_segment
)
else:
content_seg = Seg(type="text", data="[空消息]")
@@ -272,7 +270,7 @@ class CQCode:
)
segments = []
if message_obj.user_id == global_config.BOT_QQ:
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append(
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
@@ -286,7 +284,7 @@ class CQCode:
)
)
segments.append(Seg(type="seglist", data=message_obj.message_segments))
segments.append(Seg(type="seglist", data=message_obj.message_segment))
segments.append(Seg(type="text", data="]"))
return segments
else:
@@ -305,12 +303,13 @@ class CQCode:
class CQCode_tool:
@staticmethod
def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
"""
将CQ码字典转换为CQCode对象
Args:
cq_code: CQ码字典
msg: MessageCQ对象
reply: 回复消息的字典(可选)
Returns:
@@ -326,7 +325,13 @@ class CQCode_tool:
params = cq_code.get("data", {})
instance = CQCode(
type=cq_type, params=params, group_id=0, user_id=0, reply_message=reply
type=cq_type,
params=params,
group_id=msg.message_info.group_info.group_id,
user_id=msg.message_info.user_info.user_id,
user_nickname=msg.message_info.user_info.user_nickname,
group_name=msg.message_info.group_info.group_name,
reply_message=reply
)
# 进行翻译处理
@@ -384,5 +389,24 @@ class CQCode_tool:
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&")
.replace("[", "[")
.replace("]", "]")
.replace(",", ",")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
cq_code_tool = CQCode_tool()

View File

@@ -239,7 +239,7 @@ class EmojiManager:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = await image_manager.db.db.images.find_one({'hash': image_hash})
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
@@ -249,7 +249,7 @@ class EmojiManager:
'description': description,
'timestamp': int(time.time())
}
await image_manager.db.db.images.update_one(
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -260,7 +260,7 @@ class EmojiManager:
continue
# 检查是否在images集合中已有描述
existing_description = await image_manager._get_description_from_db(image_hash, 'emoji')
existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description:
description = existing_description
@@ -302,13 +302,13 @@ class EmojiManager:
'description': description,
'timestamp': int(time.time())
}
await image_manager.db.db.images.update_one(
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
await image_manager._save_description_to_db(image_hash, description, 'emoji')
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}")
else:
logger.warning(f"跳过表情包: {filename}")

View File

@@ -7,7 +7,7 @@ from nonebot import get_driver
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
from .message import MessageRecv, MessageThinking, MessageSending
from .message import MessageRecv, MessageThinking, MessageSending,Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response
@@ -144,7 +144,7 @@ class ResponseGenerator:
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(
self,
message: Message,
message: MessageRecv,
sender_name: str,
prompt: str,
prompt_check: str,
@@ -155,7 +155,7 @@ class ResponseGenerator:
self.db.db.reasoning_logs.insert_one(
{
"time": time.time(),
"group_id": message.group_id,
"chat_id": message.chat_stream.stream_id,
"user": sender_name,
"message": message.processed_plain_text,
"model": self.current_model_type,

View File

@@ -7,7 +7,7 @@ from loguru import logger
from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .chat_stream import ChatStream
from .chat_stream import ChatStream, chat_manager
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -25,8 +25,8 @@ class MessageRecv(MessageBase):
Args:
message_dict: MessageCQ序列化后的字典
"""
message_info = BaseMessageInfo(**message_dict.get('message_info', {}))
message_segment = Seg(**message_dict.get('message_segment', {}))
message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
raw_message = message_dict.get('raw_message')
super().__init__(
@@ -39,6 +39,8 @@ class MessageRecv(MessageBase):
self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False
def update_chat_stream(self,chat_stream:ChatStream):
self.chat_stream=chat_stream
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
@@ -83,12 +85,12 @@ class MessageRecv(MessageBase):
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
self.is_emoji=True
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
else:
@@ -217,11 +219,11 @@ class MessageProcessBase(Message):
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
elif seg.type == 'at':
@@ -296,10 +298,15 @@ class MessageSending(MessageProcessBase):
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
if is_head:
def set_reply(self, reply: Optional['MessageRecv']) -> None:
"""设置回复消息"""
if reply:
self.reply = reply
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(type='seglist', data=[
Seg(type='reply', data=reply.message_info.message_id),
self.message_segment
Seg(type='reply', data=reply.message_info.message_id),
self.message_segment
])
async def process(self) -> None:
@@ -329,7 +336,7 @@ class MessageSending(MessageProcessBase):
def to_dict(self):
ret= super().to_dict()
ret['mesage_info']['user_info']=self.chat_stream.user_info.to_dict()
ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
return ret
@dataclass

View File

@@ -2,7 +2,7 @@ from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any, Dict
@dataclass
class Seg(dict):
class Seg:
"""消息片段类,用于表示消息的不同部分
Attributes:
@@ -15,47 +15,25 @@ class Seg(dict):
"""
type: str
data: Union[str, List['Seg']]
translated_data: Optional[str] = None
def __init__(self, type: str, data: Union[str, List['Seg']], translated_data: Optional[str] = None):
"""初始化实例,确保字典和属性同步"""
# 先初始化字典
super().__init__(type=type, data=data)
if translated_data is not None:
self['translated_data'] = translated_data
# 再初始化属性
object.__setattr__(self, 'type', type)
object.__setattr__(self, 'data', data)
object.__setattr__(self, 'translated_data', translated_data)
# def __init__(self, type: str, data: Union[str, List['Seg']],):
# """初始化实例,确保字典和属性同步"""
# # 先初始化字典
# self.type = type
# self.data = data
# 验证数据类型
self._validate_data()
def _validate_data(self) -> None:
"""验证数据类型的正确性"""
if self.type == 'seglist' and not isinstance(self.data, list):
raise ValueError("seglist类型的data必须是列表")
elif self.type == 'text' and not isinstance(self.data, str):
raise ValueError("text类型的data必须是字符串")
elif self.type == 'image' and not isinstance(self.data, str):
raise ValueError("image类型的data必须是字符串")
def __setattr__(self, name: str, value: Any) -> None:
"""重写属性设置,同时更新字典值"""
# 更新属性
object.__setattr__(self, name, value)
# 同步更新字典
if name in ['type', 'data', 'translated_data']:
self[name] = value
def __setitem__(self, key: str, value: Any) -> None:
"""重写字典值设置,同时更新属性"""
# 更新字典
super().__setitem__(key, value)
# 同步更新属性
if key in ['type', 'data', 'translated_data']:
object.__setattr__(self, key, value)
@classmethod
def from_dict(cls, data: Dict) -> 'Seg':
"""从字典创建Seg实例"""
type=data.get('type')
data=data.get('data')
if type == 'seglist':
data = [Seg.from_dict(seg) for seg in data]
return cls(
type=type,
data=data
)
def to_dict(self) -> Dict:
"""转换为字典格式"""
@@ -64,8 +42,6 @@ class Seg(dict):
result['data'] = [seg.to_dict() for seg in self.data]
else:
result['data'] = self.data
if self.translated_data is not None:
result['translated_data'] = self.translated_data
return result
@dataclass
@@ -79,6 +55,7 @@ class GroupInfo:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'GroupInfo':
"""从字典创建GroupInfo实例
@@ -106,6 +83,7 @@ class UserInfo:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'UserInfo':
"""从字典创建UserInfo实例
@@ -126,7 +104,7 @@ class UserInfo:
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Optional[int,str] = None
message_id: Union[str,int,None] = None
time: Optional[int] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
@@ -141,6 +119,25 @@ class BaseMessageInfo:
else:
result[field] = value
return result
@classmethod
def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo(**data.get('group_info', {}))
user_info = UserInfo(**data.get('user_info', {}))
return cls(
platform=data.get('platform'),
message_id=data.get('message_id'),
time=data.get('time'),
group_info=group_info,
user_info=user_info
)
@dataclass
class MessageBase:

View File

@@ -27,27 +27,10 @@ class MessageCQ(MessageBase):
def __init__(
self,
message_id: int,
user_id: int,
group_id: Optional[int] = None,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
platform: str = "qq"
):
# 构造用户信息
user_info = UserInfo(
platform=platform,
user_id=user_id,
user_nickname=get_user_nickname(user_id),
user_cardname=get_user_cardname(user_id) if group_id else None
)
# 构造群组信息(如果有)
group_info = None
if group_id:
group_info = GroupInfo(
platform=platform,
group_id=group_id,
group_name=get_groupname(group_id)
)
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
@@ -56,7 +39,6 @@ class MessageCQ(MessageBase):
group_info=group_info,
user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(
message_info=message_info,
@@ -71,14 +53,17 @@ class MessageRecvCQ(MessageCQ):
def __init__(
self,
message_id: int,
user_id: int,
user_info: UserInfo,
raw_message: str,
group_id: Optional[int] = None,
group_info: Optional[GroupInfo] = None,
platform: str = "qq",
reply_message: Optional[Dict] = None,
platform: str = "qq"
):
# 调用父类初始化
super().__init__(message_id, user_id, group_id, platform)
super().__init__(message_id, user_info, group_info, platform)
if group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message)
@@ -117,7 +102,7 @@ class MessageRecvCQ(MessageCQ):
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, reply=reply_message)
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
if message_obj.translated_segments:
segments.append(message_obj.translated_segments)
@@ -142,13 +127,14 @@ class MessageSendCQ(MessageCQ):
data: Dict
):
# 调用父类初始化
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
message_segment = Seg.from_dict(data.get('message_segment', {}))
super().__init__(
message_info.message_id,
message_info.user_info.user_id,
message_info.group_info.group_id if message_info.group_info else None,
message_info.platform)
message_info.user_info,
message_info.group_info if message_info.group_info else None,
message_info.platform
)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message()
@@ -171,11 +157,9 @@ class MessageSendCQ(MessageCQ):
if seg.type == 'text':
return str(seg.data)
elif seg.type == 'image':
# 如果是base64图片数据
if seg.data.startswith(('data:', 'base64:')):
return cq_code_tool.create_emoji_cq_base64(seg.data)
# 如果是表情包(本地文件)
return cq_code_tool.create_emoji_cq(seg.data)
return cq_code_tool.create_image_cq_base64(seg.data)
elif seg.type == 'emoji':
return cq_code_tool.create_emoji_cq_base64(seg.data)
elif seg.type == 'at':
return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply':

View File

@@ -41,10 +41,10 @@ class Message_Sender:
message=message_send.raw_message,
auto_escape=False
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功")
except Exception as e:
print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败")
else:
try:
await self._current_bot.send_private_msg(
@@ -52,10 +52,10 @@ class Message_Sender:
message=message_send.raw_message,
auto_escape=False
)
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功")
except Exception as e:
print(f"发生错误 {e}")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败")
class MessageContainer:
@@ -137,11 +137,7 @@ class MessageManager:
return self.containers[chat_id]
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
chat_stream = chat_manager.get_stream_by_info(
platform=message.message_info.platform,
user_info=message.message_info.user_info,
group_info=message.message_info.group_info
)
chat_stream = message.chat_stream
if not chat_stream:
raise ValueError("无法找到对应的聊天流")
container = self.get_container(chat_stream.stream_id)
@@ -165,13 +161,14 @@ class MessageManager:
else:
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
await message_sender.send_message(message_earliest)
await message_sender.send_message(message_earliest.set_reply())
else:
await message_sender.send_message(message_earliest)
if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None)
# if message_earliest.is_emoji:
# message_earliest.processed_plain_text = "[表情包]"
await message_earliest.process()
await self.storage.store_message(message_earliest, message_earliest.chat_stream,None)
container.remove_message(message_earliest)
@@ -184,13 +181,14 @@ class MessageManager:
try:
if msg.is_head and msg.update_thinking_time() > 30:
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
await message_sender.send_message(msg.set_reply())
else:
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False)
await message_sender.send_message(msg)
if msg.is_emoji:
msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None)
# if msg.is_emoji:
# msg.processed_plain_text = "[表情包]"
await msg.process()
await self.storage.store_message(msg,msg.chat_stream, None)
if not container.remove_message(msg):
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")

View File

@@ -23,12 +23,12 @@ class Relationship:
saved = False
def __init__(self, chat:ChatStream=None,data:dict=None):
self.user_id=chat.user_info.user_id if chat.user_info else data.get('user_id',0)
self.platform=chat.platform if chat.user_info else data.get('platform','')
self.nickname=chat.user_info.user_nickname if chat.user_info else data.get('nickname','')
self.relationship_value=data.get('relationship_value',0)
self.age=data.get('age',0)
self.gender=data.get('gender','')
self.user_id=chat.user_info.user_id if chat else data.get('user_id',0)
self.platform=chat.platform if chat else data.get('platform','')
self.nickname=chat.user_info.user_nickname if chat else data.get('nickname','')
self.relationship_value=data.get('relationship_value',0) if data else 0
self.age=data.get('age',0) if data else 0
self.gender=data.get('gender','') if data else ''
class RelationshipManager:

View File

@@ -59,7 +59,7 @@ class ImageManager:
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
async def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
Args:
@@ -69,13 +69,13 @@ class ImageManager:
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result = await self.db.db.image_descriptions.find_one({
result= self.db.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
return result['description'] if result else None
async def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
@@ -83,7 +83,7 @@ class ImageManager:
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
await self.db.db.image_descriptions.update_one(
self.db.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
@@ -253,8 +253,9 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'emoji')
cached_description = self._get_description_from_db(image_hash, 'emoji')
if cached_description:
logger.info(f"缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
# 调用AI获取描述
@@ -281,7 +282,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
await self.db.db.images.update_one(
self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -291,7 +292,7 @@ class ImageManager:
logger.error(f"保存表情包文件失败: {str(e)}")
# 保存描述到数据库
await self._save_description_to_db(image_hash, description, 'emoji')
self._save_description_to_db(image_hash, description, 'emoji')
return f"[表情包:{description}]"
except Exception as e:
@@ -306,7 +307,7 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'image')
cached_description = self._get_description_from_db(image_hash, 'image')
if cached_description:
return f"[图片:{cached_description}]"
@@ -334,7 +335,7 @@ class ImageManager:
'description': description,
'timestamp': timestamp
}
await self.db.db.images.update_one(
self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
@@ -357,80 +358,6 @@ class ImageManager:
image_manager = ImageManager()
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data
def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码
Args:

View File

@@ -7,10 +7,11 @@ from typing import Tuple, Union
import aiohttp
from loguru import logger
from nonebot import get_driver
import base64
from PIL import Image
import io
from ...common.database import Database
from ..chat.config import global_config
from ..chat.utils_image import compress_base64_image_by_scale
driver = get_driver()
config = driver.config
@@ -405,3 +406,77 @@ class LLM_request:
)
return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data