feat: 史上最好的消息流重构和图片管理
This commit is contained in:
@@ -18,6 +18,7 @@ from .config import global_config
|
||||
from .emoji_manager import emoji_manager
|
||||
from .relationship_manager import relationship_manager
|
||||
from .willing_manager import willing_manager
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
# 创建LLM统计实例
|
||||
llm_stats = LLMStatistics("llm_statistics.txt")
|
||||
@@ -101,6 +102,8 @@ async def _(bot: Bot):
|
||||
|
||||
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
||||
print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m")
|
||||
asyncio.create_task(chat_manager._initialize())
|
||||
asyncio.create_task(chat_manager._auto_save_task())
|
||||
|
||||
@group_msg.handle()
|
||||
async def _(bot: Bot, event: GroupMessageEvent, state: T_State):
|
||||
|
||||
@@ -18,7 +18,7 @@ from .chat_stream import chat_manager
|
||||
from .message_sender import message_manager # 导入新的消息管理器
|
||||
from .relationship_manager import relationship_manager
|
||||
from .storage import MessageStorage
|
||||
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
||||
from .utils import calculate_typing_time, is_mentioned_bot_in_message
|
||||
from .utils_image import image_path_to_base64
|
||||
from .willing_manager import willing_manager # 导入意愿管理器
|
||||
from .message_base import UserInfo, GroupInfo, Seg
|
||||
@@ -45,8 +45,8 @@ class ChatBot:
|
||||
|
||||
self.bot = bot # 更新 bot 实例
|
||||
|
||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||
# group_info = await bot.get_group_info(group_id=event.group_id)
|
||||
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||
|
||||
# 白名单设定由nontbot侧完成
|
||||
if event.group_id:
|
||||
@@ -55,18 +55,31 @@ class ChatBot:
|
||||
if event.user_id in global_config.ban_user_id:
|
||||
return
|
||||
|
||||
user_info=UserInfo(
|
||||
user_id=event.user_id,
|
||||
user_nickname=event.sender.nickname,
|
||||
user_cardname=event.sender.card or None,
|
||||
platform='qq'
|
||||
)
|
||||
|
||||
group_info=GroupInfo(
|
||||
group_id=event.group_id,
|
||||
group_name=None,
|
||||
platform='qq'
|
||||
)
|
||||
|
||||
message_cq=MessageRecvCQ(
|
||||
message_id=event.message_id,
|
||||
user_id=event.user_id,
|
||||
user_info=user_info,
|
||||
raw_message=str(event.original_message),
|
||||
group_id=event.group_id,
|
||||
group_info=group_info,
|
||||
reply_message=event.reply,
|
||||
platform='qq'
|
||||
)
|
||||
message_json=message_cq.to_dict()
|
||||
|
||||
# 进入maimbot
|
||||
message=MessageRecv(**message_json)
|
||||
message=MessageRecv(message_json)
|
||||
|
||||
groupinfo=message.message_info.group_info
|
||||
userinfo=message.message_info.user_info
|
||||
@@ -75,6 +88,7 @@ class ChatBot:
|
||||
# 消息过滤,涉及到config有待更新
|
||||
|
||||
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
|
||||
message.update_chat_stream(chat)
|
||||
await relationship_manager.update_relationship(chat_stream=chat,)
|
||||
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
|
||||
|
||||
@@ -99,7 +113,7 @@ class ChatBot:
|
||||
|
||||
await self.storage.store_message(message,chat, topic[0] if topic else None)
|
||||
|
||||
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
||||
is_mentioned = is_mentioned_bot_in_message(message)
|
||||
reply_probability = await willing_manager.change_reply_willing_received(
|
||||
chat_stream=chat,
|
||||
topic=topic[0] if topic else None,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
@@ -86,9 +87,9 @@ class ChatManager:
|
||||
self._ensure_collection()
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
asyncio.create_task(self._initialize())
|
||||
# 启动自动保存任务
|
||||
asyncio.create_task(self._auto_save_task())
|
||||
# asyncio.create_task(self._initialize())
|
||||
# # 启动自动保存任务
|
||||
# asyncio.create_task(self._auto_save_task())
|
||||
|
||||
async def _initialize(self):
|
||||
"""异步初始化"""
|
||||
@@ -122,11 +123,17 @@ class ChatManager:
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [
|
||||
platform,
|
||||
str(group_info.group_id)
|
||||
]
|
||||
else:
|
||||
components = [
|
||||
platform,
|
||||
str(user_info.user_id),
|
||||
str(group_info.group_id) if group_info else "private",
|
||||
"private"
|
||||
]
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
@@ -153,10 +160,11 @@ class ChatManager:
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
# 更新用户信息和群组信息
|
||||
stream.update_active_time()
|
||||
stream=copy.deepcopy(stream)
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
stream.update_active_time()
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
@@ -180,7 +188,7 @@ class ChatManager:
|
||||
# 保存到内存和数据库
|
||||
self.streams[stream_id] = stream
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
return copy.deepcopy(stream)
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
"""通过stream_id获取聊天流"""
|
||||
|
||||
@@ -59,6 +59,7 @@ class CQCode:
|
||||
params: Dict[str, str]
|
||||
group_id: int
|
||||
user_id: int
|
||||
user_nickname: str
|
||||
group_name: str = ""
|
||||
user_nickname: str = ""
|
||||
translated_segments: Optional[Union[Seg, List[Seg]]] = None
|
||||
@@ -68,9 +69,7 @@ class CQCode:
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化LLM实例"""
|
||||
self._llm = LLM_request(
|
||||
model=global_config.vlm, temperature=0.4, max_tokens=300
|
||||
)
|
||||
pass
|
||||
|
||||
def translate(self):
|
||||
"""根据CQ码类型进行相应的翻译处理,转换为Seg对象"""
|
||||
@@ -225,8 +224,7 @@ class CQCode:
|
||||
group_id=msg.get("group_id", 0),
|
||||
)
|
||||
content_seg = Seg(
|
||||
type="seglist", data=message_obj.message_segments
|
||||
)
|
||||
type="seglist", data=message_obj.message_segment )
|
||||
else:
|
||||
content_seg = Seg(type="text", data="[空消息]")
|
||||
else:
|
||||
@@ -241,7 +239,7 @@ class CQCode:
|
||||
group_id=msg.get("group_id", 0),
|
||||
)
|
||||
content_seg = Seg(
|
||||
type="seglist", data=message_obj.message_segments
|
||||
type="seglist", data=message_obj.message_segment
|
||||
)
|
||||
else:
|
||||
content_seg = Seg(type="text", data="[空消息]")
|
||||
@@ -272,7 +270,7 @@ class CQCode:
|
||||
)
|
||||
|
||||
segments = []
|
||||
if message_obj.user_id == global_config.BOT_QQ:
|
||||
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
|
||||
segments.append(
|
||||
Seg(
|
||||
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
|
||||
@@ -286,7 +284,7 @@ class CQCode:
|
||||
)
|
||||
)
|
||||
|
||||
segments.append(Seg(type="seglist", data=message_obj.message_segments))
|
||||
segments.append(Seg(type="seglist", data=message_obj.message_segment))
|
||||
segments.append(Seg(type="text", data="]"))
|
||||
return segments
|
||||
else:
|
||||
@@ -305,12 +303,13 @@ class CQCode:
|
||||
|
||||
class CQCode_tool:
|
||||
@staticmethod
|
||||
def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
|
||||
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
|
||||
"""
|
||||
将CQ码字典转换为CQCode对象
|
||||
|
||||
Args:
|
||||
cq_code: CQ码字典
|
||||
msg: MessageCQ对象
|
||||
reply: 回复消息的字典(可选)
|
||||
|
||||
Returns:
|
||||
@@ -326,7 +325,13 @@ class CQCode_tool:
|
||||
params = cq_code.get("data", {})
|
||||
|
||||
instance = CQCode(
|
||||
type=cq_type, params=params, group_id=0, user_id=0, reply_message=reply
|
||||
type=cq_type,
|
||||
params=params,
|
||||
group_id=msg.message_info.group_info.group_id,
|
||||
user_id=msg.message_info.user_info.user_id,
|
||||
user_nickname=msg.message_info.user_info.user_nickname,
|
||||
group_name=msg.message_info.group_info.group_name,
|
||||
reply_message=reply
|
||||
)
|
||||
|
||||
# 进行翻译处理
|
||||
@@ -384,5 +389,24 @@ class CQCode_tool:
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
||||
|
||||
@staticmethod
|
||||
def create_image_cq_base64(base64_data: str) -> str:
|
||||
"""
|
||||
创建表情包CQ码
|
||||
Args:
|
||||
base64_data: base64编码的表情包数据
|
||||
Returns:
|
||||
表情包CQ码字符串
|
||||
"""
|
||||
# 转义base64数据
|
||||
escaped_base64 = (
|
||||
base64_data.replace("&", "&")
|
||||
.replace("[", "[")
|
||||
.replace("]", "]")
|
||||
.replace(",", ",")
|
||||
)
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
|
||||
|
||||
|
||||
cq_code_tool = CQCode_tool()
|
||||
|
||||
@@ -239,7 +239,7 @@ class EmojiManager:
|
||||
# 即使表情包已存在,也检查是否需要同步到images集合
|
||||
description = existing_emoji.get('discription')
|
||||
# 检查是否在images集合中存在
|
||||
existing_image = await image_manager.db.db.images.find_one({'hash': image_hash})
|
||||
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
|
||||
if not existing_image:
|
||||
# 同步到images集合
|
||||
image_doc = {
|
||||
@@ -249,7 +249,7 @@ class EmojiManager:
|
||||
'description': description,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
await image_manager.db.db.images.update_one(
|
||||
image_manager.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
@@ -260,7 +260,7 @@ class EmojiManager:
|
||||
continue
|
||||
|
||||
# 检查是否在images集合中已有描述
|
||||
existing_description = await image_manager._get_description_from_db(image_hash, 'emoji')
|
||||
existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
|
||||
|
||||
if existing_description:
|
||||
description = existing_description
|
||||
@@ -302,13 +302,13 @@ class EmojiManager:
|
||||
'description': description,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
await image_manager.db.db.images.update_one(
|
||||
image_manager.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
)
|
||||
# 保存描述到image_descriptions集合
|
||||
await image_manager._save_description_to_db(image_hash, description, 'emoji')
|
||||
image_manager._save_description_to_db(image_hash, description, 'emoji')
|
||||
logger.success(f"同步保存到images集合: {filename}")
|
||||
else:
|
||||
logger.warning(f"跳过表情包: {filename}")
|
||||
|
||||
@@ -7,7 +7,7 @@ from nonebot import get_driver
|
||||
from ...common.database import Database
|
||||
from ..models.utils_model import LLM_request
|
||||
from .config import global_config
|
||||
from .message import MessageRecv, MessageThinking, MessageSending
|
||||
from .message import MessageRecv, MessageThinking, MessageSending,Message
|
||||
from .prompt_builder import prompt_builder
|
||||
from .relationship_manager import relationship_manager
|
||||
from .utils import process_llm_response
|
||||
@@ -144,7 +144,7 @@ class ResponseGenerator:
|
||||
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
||||
def _save_to_db(
|
||||
self,
|
||||
message: Message,
|
||||
message: MessageRecv,
|
||||
sender_name: str,
|
||||
prompt: str,
|
||||
prompt_check: str,
|
||||
@@ -155,7 +155,7 @@ class ResponseGenerator:
|
||||
self.db.db.reasoning_logs.insert_one(
|
||||
{
|
||||
"time": time.time(),
|
||||
"group_id": message.group_id,
|
||||
"chat_id": message.chat_stream.stream_id,
|
||||
"user": sender_name,
|
||||
"message": message.processed_plain_text,
|
||||
"model": self.current_model_type,
|
||||
|
||||
@@ -7,7 +7,7 @@ from loguru import logger
|
||||
|
||||
from .utils_image import image_manager
|
||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||
from .chat_stream import ChatStream
|
||||
from .chat_stream import ChatStream, chat_manager
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
@@ -25,8 +25,8 @@ class MessageRecv(MessageBase):
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
"""
|
||||
message_info = BaseMessageInfo(**message_dict.get('message_info', {}))
|
||||
message_segment = Seg(**message_dict.get('message_segment', {}))
|
||||
message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
|
||||
message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
|
||||
raw_message = message_dict.get('raw_message')
|
||||
|
||||
super().__init__(
|
||||
@@ -39,6 +39,8 @@ class MessageRecv(MessageBase):
|
||||
self.processed_plain_text = "" # 初始化为空字符串
|
||||
self.detailed_plain_text = "" # 初始化为空字符串
|
||||
self.is_emoji=False
|
||||
def update_chat_stream(self,chat_stream:ChatStream):
|
||||
self.chat_stream=chat_stream
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本
|
||||
@@ -83,12 +85,12 @@ class MessageRecv(MessageBase):
|
||||
return seg.data
|
||||
elif seg.type == 'image':
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return '[图片]'
|
||||
elif seg.type == 'emoji':
|
||||
self.is_emoji=True
|
||||
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return '[表情]'
|
||||
else:
|
||||
@@ -217,11 +219,11 @@ class MessageProcessBase(Message):
|
||||
return seg.data
|
||||
elif seg.type == 'image':
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return '[图片]'
|
||||
elif seg.type == 'emoji':
|
||||
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return '[表情]'
|
||||
elif seg.type == 'at':
|
||||
@@ -296,7 +298,12 @@ class MessageSending(MessageProcessBase):
|
||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
||||
self.is_head = is_head
|
||||
self.is_emoji = is_emoji
|
||||
if is_head:
|
||||
|
||||
def set_reply(self, reply: Optional['MessageRecv']) -> None:
|
||||
"""设置回复消息"""
|
||||
if reply:
|
||||
self.reply = reply
|
||||
self.reply_to_message_id = self.reply.message_info.message_id
|
||||
self.message_segment = Seg(type='seglist', data=[
|
||||
Seg(type='reply', data=reply.message_info.message_id),
|
||||
self.message_segment
|
||||
@@ -329,7 +336,7 @@ class MessageSending(MessageProcessBase):
|
||||
|
||||
def to_dict(self):
|
||||
ret= super().to_dict()
|
||||
ret['mesage_info']['user_info']=self.chat_stream.user_info.to_dict()
|
||||
ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
|
||||
return ret
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Union, Any, Dict
|
||||
|
||||
@dataclass
|
||||
class Seg(dict):
|
||||
class Seg:
|
||||
"""消息片段类,用于表示消息的不同部分
|
||||
|
||||
Attributes:
|
||||
@@ -15,47 +15,25 @@ class Seg(dict):
|
||||
"""
|
||||
type: str
|
||||
data: Union[str, List['Seg']]
|
||||
translated_data: Optional[str] = None
|
||||
|
||||
def __init__(self, type: str, data: Union[str, List['Seg']], translated_data: Optional[str] = None):
|
||||
"""初始化实例,确保字典和属性同步"""
|
||||
# 先初始化字典
|
||||
super().__init__(type=type, data=data)
|
||||
if translated_data is not None:
|
||||
self['translated_data'] = translated_data
|
||||
|
||||
# 再初始化属性
|
||||
object.__setattr__(self, 'type', type)
|
||||
object.__setattr__(self, 'data', data)
|
||||
object.__setattr__(self, 'translated_data', translated_data)
|
||||
# def __init__(self, type: str, data: Union[str, List['Seg']],):
|
||||
# """初始化实例,确保字典和属性同步"""
|
||||
# # 先初始化字典
|
||||
# self.type = type
|
||||
# self.data = data
|
||||
|
||||
# 验证数据类型
|
||||
self._validate_data()
|
||||
|
||||
def _validate_data(self) -> None:
|
||||
"""验证数据类型的正确性"""
|
||||
if self.type == 'seglist' and not isinstance(self.data, list):
|
||||
raise ValueError("seglist类型的data必须是列表")
|
||||
elif self.type == 'text' and not isinstance(self.data, str):
|
||||
raise ValueError("text类型的data必须是字符串")
|
||||
elif self.type == 'image' and not isinstance(self.data, str):
|
||||
raise ValueError("image类型的data必须是字符串")
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
"""重写属性设置,同时更新字典值"""
|
||||
# 更新属性
|
||||
object.__setattr__(self, name, value)
|
||||
# 同步更新字典
|
||||
if name in ['type', 'data', 'translated_data']:
|
||||
self[name] = value
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
"""重写字典值设置,同时更新属性"""
|
||||
# 更新字典
|
||||
super().__setitem__(key, value)
|
||||
# 同步更新属性
|
||||
if key in ['type', 'data', 'translated_data']:
|
||||
object.__setattr__(self, key, value)
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'Seg':
|
||||
"""从字典创建Seg实例"""
|
||||
type=data.get('type')
|
||||
data=data.get('data')
|
||||
if type == 'seglist':
|
||||
data = [Seg.from_dict(seg) for seg in data]
|
||||
return cls(
|
||||
type=type,
|
||||
data=data
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
@@ -64,8 +42,6 @@ class Seg(dict):
|
||||
result['data'] = [seg.to_dict() for seg in self.data]
|
||||
else:
|
||||
result['data'] = self.data
|
||||
if self.translated_data is not None:
|
||||
result['translated_data'] = self.translated_data
|
||||
return result
|
||||
|
||||
@dataclass
|
||||
@@ -79,6 +55,7 @@ class GroupInfo:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'GroupInfo':
|
||||
"""从字典创建GroupInfo实例
|
||||
|
||||
@@ -106,6 +83,7 @@ class UserInfo:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'UserInfo':
|
||||
"""从字典创建UserInfo实例
|
||||
|
||||
@@ -126,7 +104,7 @@ class UserInfo:
|
||||
class BaseMessageInfo:
|
||||
"""消息信息类"""
|
||||
platform: Optional[str] = None
|
||||
message_id: Optional[int,str] = None
|
||||
message_id: Union[str,int,None] = None
|
||||
time: Optional[int] = None
|
||||
group_info: Optional[GroupInfo] = None
|
||||
user_info: Optional[UserInfo] = None
|
||||
@@ -141,6 +119,25 @@ class BaseMessageInfo:
|
||||
else:
|
||||
result[field] = value
|
||||
return result
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
|
||||
"""从字典创建BaseMessageInfo实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
BaseMessageInfo: 新的实例
|
||||
"""
|
||||
group_info = GroupInfo(**data.get('group_info', {}))
|
||||
user_info = UserInfo(**data.get('user_info', {}))
|
||||
return cls(
|
||||
platform=data.get('platform'),
|
||||
message_id=data.get('message_id'),
|
||||
time=data.get('time'),
|
||||
group_info=group_info,
|
||||
user_info=user_info
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class MessageBase:
|
||||
|
||||
@@ -27,27 +27,10 @@ class MessageCQ(MessageBase):
|
||||
def __init__(
|
||||
self,
|
||||
message_id: int,
|
||||
user_id: int,
|
||||
group_id: Optional[int] = None,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
platform: str = "qq"
|
||||
):
|
||||
# 构造用户信息
|
||||
user_info = UserInfo(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
user_nickname=get_user_nickname(user_id),
|
||||
user_cardname=get_user_cardname(user_id) if group_id else None
|
||||
)
|
||||
|
||||
# 构造群组信息(如果有)
|
||||
group_info = None
|
||||
if group_id:
|
||||
group_info = GroupInfo(
|
||||
platform=platform,
|
||||
group_id=group_id,
|
||||
group_name=get_groupname(group_id)
|
||||
)
|
||||
|
||||
# 构造基础消息信息
|
||||
message_info = BaseMessageInfo(
|
||||
platform=platform,
|
||||
@@ -56,7 +39,6 @@ class MessageCQ(MessageBase):
|
||||
group_info=group_info,
|
||||
user_info=user_info
|
||||
)
|
||||
|
||||
# 调用父类初始化,message_segment 由子类设置
|
||||
super().__init__(
|
||||
message_info=message_info,
|
||||
@@ -71,14 +53,17 @@ class MessageRecvCQ(MessageCQ):
|
||||
def __init__(
|
||||
self,
|
||||
message_id: int,
|
||||
user_id: int,
|
||||
user_info: UserInfo,
|
||||
raw_message: str,
|
||||
group_id: Optional[int] = None,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
platform: str = "qq",
|
||||
reply_message: Optional[Dict] = None,
|
||||
platform: str = "qq"
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(message_id, user_id, group_id, platform)
|
||||
super().__init__(message_id, user_info, group_info, platform)
|
||||
|
||||
if group_info.group_name is None:
|
||||
group_info.group_name = get_groupname(group_info.group_id)
|
||||
|
||||
# 解析消息段
|
||||
self.message_segment = self._parse_message(raw_message, reply_message)
|
||||
@@ -117,7 +102,7 @@ class MessageRecvCQ(MessageCQ):
|
||||
|
||||
# 转换CQ码为Seg对象
|
||||
for code_item in cq_code_dict_list:
|
||||
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, reply=reply_message)
|
||||
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
|
||||
if message_obj.translated_segments:
|
||||
segments.append(message_obj.translated_segments)
|
||||
|
||||
@@ -142,13 +127,14 @@ class MessageSendCQ(MessageCQ):
|
||||
data: Dict
|
||||
):
|
||||
# 调用父类初始化
|
||||
message_info = BaseMessageInfo(**data.get('message_info', {}))
|
||||
message_segment = Seg(**data.get('message_segment', {}))
|
||||
message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
|
||||
message_segment = Seg.from_dict(data.get('message_segment', {}))
|
||||
super().__init__(
|
||||
message_info.message_id,
|
||||
message_info.user_info.user_id,
|
||||
message_info.group_info.group_id if message_info.group_info else None,
|
||||
message_info.platform)
|
||||
message_info.user_info,
|
||||
message_info.group_info if message_info.group_info else None,
|
||||
message_info.platform
|
||||
)
|
||||
|
||||
self.message_segment = message_segment
|
||||
self.raw_message = self._generate_raw_message()
|
||||
@@ -171,11 +157,9 @@ class MessageSendCQ(MessageCQ):
|
||||
if seg.type == 'text':
|
||||
return str(seg.data)
|
||||
elif seg.type == 'image':
|
||||
# 如果是base64图片数据
|
||||
if seg.data.startswith(('data:', 'base64:')):
|
||||
return cq_code_tool.create_image_cq_base64(seg.data)
|
||||
elif seg.type == 'emoji':
|
||||
return cq_code_tool.create_emoji_cq_base64(seg.data)
|
||||
# 如果是表情包(本地文件)
|
||||
return cq_code_tool.create_emoji_cq(seg.data)
|
||||
elif seg.type == 'at':
|
||||
return f"[CQ:at,qq={seg.data}]"
|
||||
elif seg.type == 'reply':
|
||||
|
||||
@@ -41,10 +41,10 @@ class Message_Sender:
|
||||
message=message_send.raw_message,
|
||||
auto_escape=False
|
||||
)
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功")
|
||||
except Exception as e:
|
||||
print(f"发生错误 {e}")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败")
|
||||
else:
|
||||
try:
|
||||
await self._current_bot.send_private_msg(
|
||||
@@ -52,10 +52,10 @@ class Message_Sender:
|
||||
message=message_send.raw_message,
|
||||
auto_escape=False
|
||||
)
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功")
|
||||
except Exception as e:
|
||||
print(f"发生错误 {e}")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
|
||||
print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败")
|
||||
|
||||
|
||||
class MessageContainer:
|
||||
@@ -137,11 +137,7 @@ class MessageManager:
|
||||
return self.containers[chat_id]
|
||||
|
||||
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
|
||||
chat_stream = chat_manager.get_stream_by_info(
|
||||
platform=message.message_info.platform,
|
||||
user_info=message.message_info.user_info,
|
||||
group_info=message.message_info.group_info
|
||||
)
|
||||
chat_stream = message.chat_stream
|
||||
if not chat_stream:
|
||||
raise ValueError("无法找到对应的聊天流")
|
||||
container = self.get_container(chat_stream.stream_id)
|
||||
@@ -165,13 +161,14 @@ class MessageManager:
|
||||
else:
|
||||
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
|
||||
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
|
||||
await message_sender.send_message(message_earliest)
|
||||
await message_sender.send_message(message_earliest.set_reply())
|
||||
else:
|
||||
await message_sender.send_message(message_earliest)
|
||||
|
||||
if message_earliest.is_emoji:
|
||||
message_earliest.processed_plain_text = "[表情包]"
|
||||
await self.storage.store_message(message_earliest, None)
|
||||
# if message_earliest.is_emoji:
|
||||
# message_earliest.processed_plain_text = "[表情包]"
|
||||
await message_earliest.process()
|
||||
await self.storage.store_message(message_earliest, message_earliest.chat_stream,None)
|
||||
|
||||
container.remove_message(message_earliest)
|
||||
|
||||
@@ -184,13 +181,14 @@ class MessageManager:
|
||||
|
||||
try:
|
||||
if msg.is_head and msg.update_thinking_time() > 30:
|
||||
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id)
|
||||
await message_sender.send_message(msg.set_reply())
|
||||
else:
|
||||
await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False)
|
||||
await message_sender.send_message(msg)
|
||||
|
||||
if msg.is_emoji:
|
||||
msg.processed_plain_text = "[表情包]"
|
||||
await self.storage.store_message(msg, None)
|
||||
# if msg.is_emoji:
|
||||
# msg.processed_plain_text = "[表情包]"
|
||||
await msg.process()
|
||||
await self.storage.store_message(msg,msg.chat_stream, None)
|
||||
|
||||
if not container.remove_message(msg):
|
||||
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息")
|
||||
|
||||
@@ -23,12 +23,12 @@ class Relationship:
|
||||
saved = False
|
||||
|
||||
def __init__(self, chat:ChatStream=None,data:dict=None):
|
||||
self.user_id=chat.user_info.user_id if chat.user_info else data.get('user_id',0)
|
||||
self.platform=chat.platform if chat.user_info else data.get('platform','')
|
||||
self.nickname=chat.user_info.user_nickname if chat.user_info else data.get('nickname','')
|
||||
self.relationship_value=data.get('relationship_value',0)
|
||||
self.age=data.get('age',0)
|
||||
self.gender=data.get('gender','')
|
||||
self.user_id=chat.user_info.user_id if chat else data.get('user_id',0)
|
||||
self.platform=chat.platform if chat else data.get('platform','')
|
||||
self.nickname=chat.user_info.user_nickname if chat else data.get('nickname','')
|
||||
self.relationship_value=data.get('relationship_value',0) if data else 0
|
||||
self.age=data.get('age',0) if data else 0
|
||||
self.gender=data.get('gender','') if data else ''
|
||||
|
||||
|
||||
class RelationshipManager:
|
||||
|
||||
@@ -59,7 +59,7 @@ class ImageManager:
|
||||
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
|
||||
self.db.db.image_descriptions.create_index([('type', 1)])
|
||||
|
||||
async def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
||||
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
Args:
|
||||
@@ -69,13 +69,13 @@ class ImageManager:
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
result = await self.db.db.image_descriptions.find_one({
|
||||
result= self.db.db.image_descriptions.find_one({
|
||||
'hash': image_hash,
|
||||
'type': description_type
|
||||
})
|
||||
return result['description'] if result else None
|
||||
|
||||
async def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
|
||||
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
|
||||
"""保存图片描述到数据库
|
||||
|
||||
Args:
|
||||
@@ -83,7 +83,7 @@ class ImageManager:
|
||||
description: 描述文本
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
await self.db.db.image_descriptions.update_one(
|
||||
self.db.db.image_descriptions.update_one(
|
||||
{'hash': image_hash, 'type': description_type},
|
||||
{
|
||||
'$set': {
|
||||
@@ -253,8 +253,9 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = await self._get_description_from_db(image_hash, 'emoji')
|
||||
cached_description = self._get_description_from_db(image_hash, 'emoji')
|
||||
if cached_description:
|
||||
logger.info(f"缓存表情包描述: {cached_description}")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
@@ -281,7 +282,7 @@ class ImageManager:
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
await self.db.db.images.update_one(
|
||||
self.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
@@ -291,7 +292,7 @@ class ImageManager:
|
||||
logger.error(f"保存表情包文件失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
await self._save_description_to_db(image_hash, description, 'emoji')
|
||||
self._save_description_to_db(image_hash, description, 'emoji')
|
||||
|
||||
return f"[表情包:{description}]"
|
||||
except Exception as e:
|
||||
@@ -306,7 +307,7 @@ class ImageManager:
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = await self._get_description_from_db(image_hash, 'image')
|
||||
cached_description = self._get_description_from_db(image_hash, 'image')
|
||||
if cached_description:
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
@@ -334,7 +335,7 @@ class ImageManager:
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
await self.db.db.images.update_one(
|
||||
self.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
@@ -357,80 +358,6 @@ class ImageManager:
|
||||
image_manager = ImageManager()
|
||||
|
||||
|
||||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||||
"""压缩base64格式的图片到指定大小
|
||||
Args:
|
||||
base64_data: base64编码的图片数据
|
||||
target_size: 目标文件大小(字节),默认0.8MB
|
||||
Returns:
|
||||
str: 压缩后的base64图片数据
|
||||
"""
|
||||
try:
|
||||
# 将base64转换为字节数据
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# 如果已经小于目标大小,直接返回原图
|
||||
if len(image_data) <= 2*1024*1024:
|
||||
return base64_data
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 获取原始尺寸
|
||||
original_width, original_height = img.size
|
||||
|
||||
# 计算缩放比例
|
||||
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
|
||||
|
||||
# 计算新的尺寸
|
||||
new_width = int(original_width * scale)
|
||||
new_height = int(original_height * scale)
|
||||
|
||||
# 创建内存缓冲区
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
# 如果是GIF,处理所有帧
|
||||
if getattr(img, "is_animated", False):
|
||||
frames = []
|
||||
for frame_idx in range(img.n_frames):
|
||||
img.seek(frame_idx)
|
||||
new_frame = img.copy()
|
||||
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
|
||||
frames.append(new_frame)
|
||||
|
||||
# 保存到缓冲区
|
||||
frames[0].save(
|
||||
output_buffer,
|
||||
format='GIF',
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
optimize=True,
|
||||
duration=img.info.get('duration', 100),
|
||||
loop=img.info.get('loop', 0)
|
||||
)
|
||||
else:
|
||||
# 处理静态图片
|
||||
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存到缓冲区,保持原始格式
|
||||
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
|
||||
resized_img.save(output_buffer, format='PNG', optimize=True)
|
||||
else:
|
||||
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
|
||||
|
||||
# 获取压缩后的数据并转换为base64
|
||||
compressed_data = output_buffer.getvalue()
|
||||
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
|
||||
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
|
||||
|
||||
return base64.b64encode(compressed_data).decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return base64_data
|
||||
|
||||
def image_path_to_base64(image_path: str) -> str:
|
||||
"""将图片路径转换为base64编码
|
||||
Args:
|
||||
|
||||
@@ -7,10 +7,11 @@ from typing import Tuple, Union
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
from ...common.database import Database
|
||||
from ..chat.config import global_config
|
||||
from ..chat.utils_image import compress_base64_image_by_scale
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -405,3 +406,77 @@ class LLM_request:
|
||||
)
|
||||
return embedding
|
||||
|
||||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||||
"""压缩base64格式的图片到指定大小
|
||||
Args:
|
||||
base64_data: base64编码的图片数据
|
||||
target_size: 目标文件大小(字节),默认0.8MB
|
||||
Returns:
|
||||
str: 压缩后的base64图片数据
|
||||
"""
|
||||
try:
|
||||
# 将base64转换为字节数据
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# 如果已经小于目标大小,直接返回原图
|
||||
if len(image_data) <= 2*1024*1024:
|
||||
return base64_data
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 获取原始尺寸
|
||||
original_width, original_height = img.size
|
||||
|
||||
# 计算缩放比例
|
||||
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
|
||||
|
||||
# 计算新的尺寸
|
||||
new_width = int(original_width * scale)
|
||||
new_height = int(original_height * scale)
|
||||
|
||||
# 创建内存缓冲区
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
# 如果是GIF,处理所有帧
|
||||
if getattr(img, "is_animated", False):
|
||||
frames = []
|
||||
for frame_idx in range(img.n_frames):
|
||||
img.seek(frame_idx)
|
||||
new_frame = img.copy()
|
||||
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
|
||||
frames.append(new_frame)
|
||||
|
||||
# 保存到缓冲区
|
||||
frames[0].save(
|
||||
output_buffer,
|
||||
format='GIF',
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
optimize=True,
|
||||
duration=img.info.get('duration', 100),
|
||||
loop=img.info.get('loop', 0)
|
||||
)
|
||||
else:
|
||||
# 处理静态图片
|
||||
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存到缓冲区,保持原始格式
|
||||
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
|
||||
resized_img.save(output_buffer, format='PNG', optimize=True)
|
||||
else:
|
||||
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
|
||||
|
||||
# 获取压缩后的数据并转换为base64
|
||||
compressed_data = output_buffer.getvalue()
|
||||
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
|
||||
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
|
||||
|
||||
return base64.b64encode(compressed_data).decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return base64_data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user