feat: 超大型重构

This commit is contained in:
tcmofashi
2025-03-09 10:27:54 +08:00
parent 8754571afb
commit fe3684736a
11 changed files with 1209 additions and 566 deletions

View File

@@ -7,10 +7,10 @@ from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config
from .cq_code import CQCode # 导入CQCode模块
from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator
from .message import (
from .message_cq import (
Message,
Message_Sending,
Message_Thinking, # 导入 Message_Thinking 类
@@ -180,7 +180,7 @@ class ChatBot:
if emoji_raw != None:
emoji_path,discription = emoji_raw
emoji_cq = CQCode.create_emoji_cq(emoji_path)
emoji_cq = cq_code_tool.create_emoji_cq(emoji_path)
if random() < 0.5:
bot_response_time = tinking_time_point - 1

View File

@@ -3,7 +3,7 @@ import html
import os
import time
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Dict, Optional, List, Union
import requests
@@ -12,12 +12,14 @@ import requests
import urllib3
from nonebot import get_driver
from urllib3.util import create_urllib3_context
from loguru import logger
from ..models.utils_model import LLM_request
from .config import global_config
from .mapper import emojimapper
from .utils_image import storage_emoji, storage_image
from .utils_image import image_manager
from .utils_user import get_user_nickname
from .message_base import Seg
driver = get_driver()
config = driver.config
@@ -48,16 +50,15 @@ class CQCode:
type: CQ码类型'image', 'at', 'face'等)
params: CQ码的参数字典
raw_code: 原始CQ码字符串
translated_plain_text: 经过处理如AI翻译后的文本表示
translated_segments: 经过处理后的Seg对象列表
"""
type: str
params: Dict[str, str]
# raw_code: str
group_id: int
user_id: int
group_name: str = ""
user_nickname: str = ""
translated_plain_text: Optional[str] = None
translated_segments: Optional[Union[Seg, List[Seg]]] = None
reply_message: Dict = None # 存储回复消息
image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None
@@ -66,31 +67,72 @@ class CQCode:
"""初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
async def translate(self):
"""根据CQ码类型进行相应的翻译处理"""
def translate(self):
"""根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == 'text':
self.translated_plain_text = self.params.get('text', '')
self.translated_segments = Seg(
type='text',
data=self.params.get('text', '')
)
elif self.type == 'image':
if self.params.get('sub_type') == '0':
self.translated_plain_text = await self.translate_image()
base64_data = self.translate_image()
if base64_data:
if self.params.get('sub_type') == '0':
self.translated_segments = Seg(
type='image',
data=base64_data
)
else:
self.translated_segments = Seg(
type='emoji',
data=base64_data
)
else:
self.translated_plain_text = await self.translate_emoji()
self.translated_segments = Seg(
type='text',
data='[图片]'
)
elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', ''))
if user_nickname:
self.translated_plain_text = f"[@{user_nickname}]"
else:
self.translated_plain_text = "@某人"
self.translated_segments = Seg(
type='text',
data=f"[@{user_nickname or '某人'}]"
)
elif self.type == 'reply':
self.translated_plain_text = await self.translate_reply()
reply_segments = self.translate_reply()
if reply_segments:
self.translated_segments = Seg(
type='seglist',
data=reply_segments
)
else:
self.translated_segments = Seg(
type='text',
data='[回复某人消息]'
)
elif self.type == 'face':
face_id = self.params.get('id', '')
# self.translated_plain_text = f"[表情{face_id}]"
self.translated_plain_text = f"[{emojimapper.get(int(face_id), '表情')}]"
self.translated_segments = Seg(
type='text',
data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == 'forward':
self.translated_plain_text = await self.translate_forward()
forward_segments = self.translate_forward()
if forward_segments:
self.translated_segments = Seg(
type='seglist',
data=forward_segments
)
else:
self.translated_segments = Seg(
type='text',
data='[转发消息]'
)
else:
self.translated_plain_text = f"[{self.type}]"
self.translated_segments = Seg(
type='text',
data=f"[{self.type}]"
)
def get_img(self):
'''
@@ -160,155 +202,101 @@ class CQCode:
return None
async def translate_emoji(self) -> str:
"""处理表情包类型的CQ码"""
def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串"""
if 'url' not in self.params:
return '[表情包]'
base64_str = self.get_img()
if base64_str:
# 将 base64 字符串转换为字节类型
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
return await self.get_emoji_description(base64_str)
else:
return '[表情包]'
return None
return self.get_img()
async def translate_image(self) -> str:
"""处理图片类型的CQ码区分普通图片和表情包"""
# 没有url直接返回默认文本
if 'url' not in self.params:
return '[图片]'
base64_str = self.get_img()
if base64_str:
image_bytes = base64.b64decode(base64_str)
storage_image(image_bytes)
return await self.get_image_description(base64_str)
else:
return '[图片]'
async def get_emoji_description(self, image_base64: str) -> str:
"""调用AI接口获取表情包描述"""
try:
prompt = "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
"""调用AI接口获取普通图片描述"""
try:
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[图片]"
async def translate_forward(self) -> str:
"""处理转发消息"""
def translate_forward(self) -> Optional[List[Seg]]:
"""处理转发消息返回Seg列表"""
try:
if 'content' not in self.params:
return '[转发消息]'
return None
# 解析content内容需要先反转义
content = self.unescape(self.params['content'])
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
# 将字符串形式的列表转换为Python对象
import ast
try:
messages = ast.literal_eval(content)
except ValueError as e:
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}")
return '[转发消息]'
logger.error(f"解析转发消息内容失败: {str(e)}")
return None
# 处理每条消息
formatted_messages = []
formatted_segments = []
for msg in messages:
sender = msg.get('sender', {})
nickname = sender.get('card') or sender.get('nickname', '未知用户')
# 获取消息内容并使用Message类处理
raw_message = msg.get('raw_message', '')
message_array = msg.get('message', [])
if message_array and isinstance(message_array, list):
# 检查是否包含嵌套的转发消息
for message_part in message_array:
if message_part.get('type') == 'forward':
content = '[转发消息]'
content_seg = Seg(type='text', data='[转发消息]')
break
else:
# 处理普通消息
if raw_message:
from .message import Message
message_obj = Message(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
)
await message_obj.initialize()
content = message_obj.processed_plain_text
else:
content = '[空消息]'
if raw_message:
from .message_cq import MessageRecvCQ
message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
)
content_seg = Seg(type='seglist', data=message_obj.message_segments)
else:
content_seg = Seg(type='text', data='[空消息]')
else:
# 处理普通消息
if raw_message:
from .message import Message
message_obj = Message(
from .message_cq import MessageRecvCQ
message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0),
raw_message=raw_message,
plain_text=raw_message,
group_id=msg.get('group_id', 0)
)
await message_obj.initialize()
content = message_obj.processed_plain_text
content_seg = Seg(type='seglist', data=message_obj.message_segments)
else:
content = '[空消息]'
content_seg = Seg(type='text', data='[空消息]')
formatted_msg = f"{nickname}: {content}"
formatted_messages.append(formatted_msg)
formatted_segments.append(Seg(type='text', data=f"{nickname}: "))
formatted_segments.append(content_seg)
formatted_segments.append(Seg(type='text', data='\n'))
# 合并所有消息
combined_messages = '\n'.join(formatted_messages)
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]"
return formatted_segments
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
return '[转发消息]'
logger.error(f"处理转发消息失败: {str(e)}")
return None
async def translate_reply(self) -> str:
"""处理回复类型的CQ码"""
# 创建Message对象
from .message import Message
if self.reply_message == None:
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]'
def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
if self.reply_message is None:
return None
if self.reply_message.sender.user_id:
message_obj = Message(
message_obj = MessageRecvCQ(
user_id=self.reply_message.sender.user_id,
message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message),
group_id=self.group_id
)
await message_obj.initialize()
segments = []
if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
segments.append(Seg(type='text', data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
else:
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
segments.append(Seg(type='text', data=f"[回复 {self.reply_message.sender.nickname} 的消息: "))
segments.append(Seg(type='seglist', data=message_obj.message_segments))
segments.append(Seg(type='text', data="]"))
return segments
else:
print("\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空")
return '[回复某人消息]'
return None
@staticmethod
def unescape(text: str) -> str:
@@ -318,29 +306,12 @@ class CQCode:
.replace('&#93;', ']') \
.replace('&amp;', '&')
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
class CQCode_tool:
@staticmethod
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
"""
将CQ码字典转换为CQCode对象
@@ -369,7 +340,7 @@ class CQCode_tool:
)
# 进行翻译处理
await instance.translate()
instance.translate()
return instance
@staticmethod
@@ -382,6 +353,27 @@ class CQCode_tool:
回复CQ码字符串
"""
return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
cq_code_tool = CQCode_tool()

View File

@@ -4,6 +4,8 @@ import random
import time
import traceback
from typing import Optional
import base64
import hashlib
from loguru import logger
from nonebot import get_driver
@@ -13,9 +15,11 @@ from ..chat.config import global_config
from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64
from ..models.utils_model import LLM_request
from ..chat.utils_image import ImageManager
driver = get_driver()
config = driver.config
image_manager = ImageManager()
class EmojiManager:
@@ -142,14 +146,14 @@ class EmojiManager:
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3]
top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_3_emojis:
if not top_10_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis)
selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
@@ -172,13 +176,13 @@ class EmojiManager:
return None
async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签"""
"""获取表情包的标签使用image_manager的描述生成功能"""
try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}")
return content
# 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
# 去掉[表情包xxx]的格式,只保留描述内容
description = description.strip('[]').replace('表情包:', '')
return description
except Exception as e:
logger.error(f"获取标签失败: {str(e)}")
@@ -220,42 +224,94 @@ class EmojiManager:
for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji:
continue
# 压缩图片并获取base64编码
# 获取图片的base64编码和哈希值
image_base64 = image_path_to_base64(image_path)
if image_base64 is None:
os.remove(image_path)
continue
# 获取表情包的描述
discription = await self._get_emoji_discription(image_base64)
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = await image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
await image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
await image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步已存在的表情包到images集合: {filename}")
continue
# 检查是否在images集合中已有描述
existing_description = await image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description:
description = existing_description
else:
# 获取表情包的描述
description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64)
if '' not in check:
os.remove(image_path)
logger.info(f"描述: {discription}")
logger.info(f"描述: {description}")
logger.info(f"其不满足过滤规则,被剔除 {check}")
continue
logger.info(f"check通过 {check}")
embedding = await get_embedding(discription)
if discription is not None:
if description is not None:
embedding = await get_embedding(description)
# 准备数据库记录
emoji_record = {
'filename': filename,
'path': image_path,
'embedding':embedding,
'discription': discription,
'embedding': embedding,
'discription': description,
'hash': image_hash,
'timestamp': int(time.time())
}
# 保存到数据库
# 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {discription}")
logger.info(f"描述: {description}")
# 保存到images数据库
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
await image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
await image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}")
else:
logger.warning(f"跳过表情包: {filename}")

View File

@@ -7,7 +7,7 @@ from nonebot import get_driver
from ...common.database import Database
from ..models.utils_model import LLM_request
from .config import global_config
from .message import Message
from .message_cq import Message
from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager
from .utils import process_llm_response

View File

@@ -1,231 +1,360 @@
import time
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional
from typing import Dict, ForwardRef, List, Optional, Union
import urllib3
from loguru import logger
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
Message = ForwardRef('Message') # 添加这行
from .utils_image import image_manager
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageRecv(MessageBase):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: Dict):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
message_info = BaseMessageInfo(**message_dict.get('message_info', {}))
message_segment = Seg(**message_dict.get('message_segment', {}))
raw_message = message_dict.get('raw_message')
super().__init__(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)
# 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname!=''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class Message:
"""消息数据类"""
message_id: int = None
time: float = None
group_id: int = None
group_name: str = None # 群名称
user_id: int = None
user_nickname: str = None # 用户昵称
user_cardname: str = None # 用户群昵称
raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本
reply_message: Dict = None # 存储 回复的 源消息
# 延迟初始化字段
_initialized: bool = False
message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
# 状态标志
is_emoji: bool = False
has_emoji: bool = False
translate_cq: bool = True
async def initialize(self):
"""显式异步初始化方法(必须调用)"""
if self._initialized:
return
# 异步获取补充信息
self.group_name = self.group_name or get_groupname(self.group_id)
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
# 消息解析
if self.raw_message:
if not isinstance(self,Message_Sending):
self.message_segments = await self.parse_message_segments(self.raw_message)
self.processed_plain_text = ' '.join(
seg.translated_plain_text
for seg in self.message_segments
)
# 构建详细文本
if self.time is None:
self.time = int(time.time())
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
name = (
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
if self.user_cardname
else f"{self.user_nickname or f'用户{self.user_id}'}"
)
if isinstance(self,Message_Sending) and self.is_emoji:
self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n"
else:
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
class MessageProcessBase(MessageBase):
"""消息处理基类,用于处理中和发送中的消息"""
async def parse_message_segments(self, message: str) -> List[CQCode]:
"""
将消息解析为片段列表包括纯文本和CQ码
返回的列表中每个元素都是字典,包含:
- cq_code_list:分割出的聊天对象包括文本和CQ码
- trans_list:翻译后的对象列表
"""
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = []
trans_list = []
start = 0
while True:
# 查找下一个CQ码的开始位置
cq_start = message.find('[CQ:', start)
#如果没有cq码直接返回文本内容
if cq_start == -1:
# 如果没有找到更多CQ码添加剩余文本
if start < len(message):
text = message[start:].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
break
# 添加CQ码前的文本
if cq_start > start:
text = message[start:cq_start].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
# 查找CQ码的结束位置
cq_end = message.find(']', cq_start)
if cq_end == -1:
# CQ码未闭合作为普通文本处理
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
#将cq_code解析成字典
cq_code_dict_list.append(parse_cq_code(cq_code))
# 更新start位置到当前CQ码之后
start = cq_end + 1
# print(f"\033[1;34m[调试信息]\033[0m 提取的消息对象:列表: {cq_code_dict_list}")
#判定是否是表情包消息,以及是否含有表情包
if len(cq_code_dict_list) == 1 and cq_code_dict_list[0]['type'] == 'image':
self.is_emoji = True
self.has_emoji_emoji = True
else:
for segment in cq_code_dict_list:
if segment['type'] == 'image' and segment['data'].get('sub_type') == '1':
self.has_emoji_emoji = True
break
#翻译作为字典的CQ码
for _code_item in cq_code_dict_list:
message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
trans_list.append(message_obj)
return trans_list
def __init__(
self,
message_id: str,
user_id: int,
group_id: Optional[int] = None,
platform: str = "qq",
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
):
# 构造用户信息
user_info = UserInfo(
platform=platform,
user_id=user_id,
user_nickname=get_user_nickname(user_id),
user_cardname=get_user_cardname(user_id) if group_id else None
)
class Message_Thinking:
"""消息思考类"""
def __init__(self, message: Message,message_id: str):
# 复制原始消息的基本属性
self.group_id = message.group_id
self.user_id = message.user_id
self.user_nickname = message.user_nickname
self.user_cardname = message.user_cardname
self.group_name = message.group_name
self.message_id = message_id
# 思考状态相关属性
# 构造群组信息(如果有)
group_info = None
if group_id:
group_info = GroupInfo(
platform=platform,
group_id=group_id,
group_name=get_groupname(group_id)
)
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
)
# 调用父类初始化
super().__init__(
message_info=message_info,
message_segment=message_segment,
raw_message=None
)
# 处理状态相关属性
self.thinking_start_time = int(time.time())
self.thinking_time = 0
self.interupt=False
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
# 文本处理相关属性
self.processed_plain_text = ""
self.detailed_plain_text = ""
# 回复消息
self.reply = reply
@dataclass
class Message_Sending(Message):
"""发送中的消息类"""
thinking_start_time: float = None # 思考开始时间
thinking_time: float = None # 思考时间
reply_message_id: int = None # 存储 回复的 源消息ID
is_head: bool = False # 是否是头部消息
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
def update_thinking_time(self) -> float:
"""更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
return self.thinking_time
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
elif seg.type == 'at':
return f"[@{seg.data}]"
elif seg.type == 'reply':
if self.reply and hasattr(self.reply, 'processed_plain_text'):
return f"[回复:{self.reply.processed_plain_text}]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
"""思考状态的消息类"""
def __init__(
self,
message_id: str,
user_id: int,
group_id: Optional[int] = None,
platform: str = "qq",
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
user_id=user_id,
group_id=group_id,
platform=platform,
message_segment=None, # 思考状态不需要消息段
reply=reply
)
# 思考状态特有属性
self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
user_id: int,
message_segment: Seg,
group_id: Optional[int] = None,
reply: Optional['MessageRecv'] = None,
platform: str = "qq",
is_head: bool = False
):
# 调用父类初始化
super().__init__(
message_id=message_id,
user_id=user_id,
group_id=group_id,
platform=platform,
message_segment=message_segment,
reply=reply
)
# 发送状态特有属性
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
@classmethod
def from_thinking(
cls,
thinking: MessageThinking,
message_segment: Seg,
reply: Optional['MessageRecv'] = None,
is_head: bool = False
) -> 'MessageSending':
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
user_id=thinking.message_info.user_info.user_id,
message_segment=message_segment,
group_id=thinking.message_info.group_info.group_id if thinking.message_info.group_info else None,
reply=reply or thinking.reply,
platform=thinking.message_info.platform,
is_head=is_head
)
@dataclass
class MessageSet:
"""消息集合类,可以存储多个发送消息"""
def __init__(self, group_id: int, user_id: int, message_id: str):
self.group_id = group_id
self.user_id = user_id
self.message_id = message_id
self.messages: List[Message_Sending] = [] # 修改类型标注
self.messages: List[MessageSending] = []
self.time = round(time.time(), 2)
def add_message(self, message: Message_Sending) -> None:
"""添加消息到集合只接受Message_Sending类型"""
if not isinstance(message, Message_Sending):
raise TypeError("MessageSet只能添加Message_Sending类型的消息")
def add_message(self, message: MessageSending) -> None:
"""添加消息到集合"""
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
# 按时间排序
self.messages.sort(key=lambda x: x.time)
self.messages.sort(key=lambda x: x.message_info.time)
def get_message_by_index(self, index: int) -> Optional[Message_Sending]:
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]:
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
if not self.messages:
return None
# 使用二分查找找到最接近的消息
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].time < target_time:
if self.messages[mid].message_info.time < target_time:
left = mid + 1
else:
right = mid
return self.messages[left]
def clear_messages(self) -> None:
"""清空所有消息"""
self.messages.clear()
def remove_message(self, message: Message_Sending) -> bool:
def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息"""
if message in self.messages:
self.messages.remove(message)

View File

@@ -0,0 +1,158 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any, Dict
@dataclass
class Seg(dict):
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型,可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data 是 base64 字符串
- 对于 seglist 类型data 是 Seg 列表
translated_data: 经过翻译处理的数据(可选)
"""
type: str
data: Union[str, List['Seg']]
translated_data: Optional[str] = None
def __init__(self, type: str, data: Union[str, List['Seg']], translated_data: Optional[str] = None):
"""初始化实例,确保字典和属性同步"""
# 先初始化字典
super().__init__(type=type, data=data)
if translated_data is not None:
self['translated_data'] = translated_data
# 再初始化属性
object.__setattr__(self, 'type', type)
object.__setattr__(self, 'data', data)
object.__setattr__(self, 'translated_data', translated_data)
# 验证数据类型
self._validate_data()
def _validate_data(self) -> None:
"""验证数据类型的正确性"""
if self.type == 'seglist' and not isinstance(self.data, list):
raise ValueError("seglist类型的data必须是列表")
elif self.type == 'text' and not isinstance(self.data, str):
raise ValueError("text类型的data必须是字符串")
elif self.type == 'image' and not isinstance(self.data, str):
raise ValueError("image类型的data必须是字符串")
def __setattr__(self, name: str, value: Any) -> None:
"""重写属性设置,同时更新字典值"""
# 更新属性
object.__setattr__(self, name, value)
# 同步更新字典
if name in ['type', 'data', 'translated_data']:
self[name] = value
def __setitem__(self, key: str, value: Any) -> None:
"""重写字典值设置,同时更新属性"""
# 更新字典
super().__setitem__(key, value)
# 同步更新属性
if key in ['type', 'data', 'translated_data']:
object.__setattr__(self, key, value)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {'type': self.type}
if self.type == 'seglist':
result['data'] = [seg.to_dict() for seg in self.data]
else:
result['data'] = self.data
if self.translated_data is not None:
result['translated_data'] = self.translated_data
return result
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Optional[int,str] = None
time: Optional[int] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {
'message_info': self.message_info.to_dict(),
'message_segment': self.message_segment.to_dict()
}
if self.raw_message is not None:
result['raw_message'] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> 'MessageBase':
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message')
return cls(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)

View File

@@ -0,0 +1,188 @@
import time
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional, Union
import urllib3
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
"""QQ消息基类继承自MessageBase
最小必要参数:
- message_id: 消息ID
- user_id: 发送者/接收者ID
- platform: 平台标识(默认为"qq"
"""
def __init__(
self,
message_id: int,
user_id: int,
group_id: Optional[int] = None,
platform: str = "qq"
):
# 构造用户信息
user_info = UserInfo(
platform=platform,
user_id=user_id,
user_nickname=get_user_nickname(user_id),
user_cardname=get_user_cardname(user_id) if group_id else None
)
# 构造群组信息(如果有)
group_info = None
if group_id:
group_info = GroupInfo(
platform=platform,
group_id=group_id,
group_name=get_groupname(group_id)
)
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(
message_info=message_info,
message_segment=None,
raw_message=None
)
@dataclass
class MessageRecvCQ(MessageCQ):
"""QQ接收消息类用于解析raw_message到Seg对象"""
def __init__(
self,
message_id: int,
user_id: int,
raw_message: str,
group_id: Optional[int] = None,
reply_message: Optional[Dict] = None,
platform: str = "qq"
):
# 调用父类初始化
super().__init__(message_id, user_id, group_id, platform)
# 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message)
self.raw_message = raw_message
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""解析消息内容为Seg对象"""
cq_code_dict_list = []
segments = []
start = 0
while True:
cq_start = message.find('[CQ:', start)
if cq_start == -1:
if start < len(message):
text = message[start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
if cq_start > start:
text = message[start:cq_start].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find(']', cq_start)
if cq_end == -1:
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, reply=reply_message)
if message_obj.translated_segments:
segments.append(message_obj.translated_segments)
# 如果只有一个segment直接返回
if len(segments) == 1:
return segments[0]
# 否则返回seglist类型的Seg
return Seg(type='seglist', data=segments)
def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict()
return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__(
self,
message_id: int,
user_id: int,
message_segment: Seg,
group_id: Optional[int] = None,
reply_to_message_id: Optional[int] = None,
platform: str = "qq"
):
# 调用父类初始化
super().__init__(message_id, user_id, group_id, platform)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message(reply_to_message_id)
def _generate_raw_message(self, reply_to_message_id: Optional[int] = None) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 添加回复消息
if reply_to_message_id:
segments.append(cq_code_tool.create_reply_cq(reply_to_message_id))
# 处理消息段
if self.message_segment.type == 'seglist':
for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg))
else:
segments.append(self._seg_to_cq_code(self.message_segment))
return ''.join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串"""
if seg.type == 'text':
return str(seg.data)
elif seg.type == 'image':
# 如果是base64图片数据
if seg.data.startswith(('data:', 'base64:')):
return f"[CQ:image,file=base64://{seg.data}]"
# 如果是表情包(本地文件)
return cq_code_tool.create_emoji_cq(seg.data)
elif seg.type == 'at':
return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply':
return cq_code_tool.create_reply_cq(int(seg.data))
else:
return f"[{seg.data}]"

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Union
from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool
from .message import Message, Message_Sending, Message_Thinking, MessageSet
from .message_cq import Message, Message_Sending, Message_Thinking, MessageSet
from .storage import MessageStorage
from .utils import calculate_typing_time
from .config import global_config

View File

@@ -1,7 +1,7 @@
from typing import Optional
from ...common.database import Database
from .message import Message
from .message_cq import Message
class MessageStorage:

View File

@@ -11,31 +11,12 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config
from .message import Message
from .message_cq import Message
driver = get_driver()
config = driver.config
def combine_messages(messages: List[Message]) -> str:
"""将消息列表组合成格式化的字符串
Args:
messages: Message对象列表
Returns:
str: 格式化后的消息字符串
"""
result = ""
for message in messages:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
name = message.user_nickname or f"用户{message.user_id}"
content = message.processed_plain_text or message.plain_text
result += f"[{time_str}] {name}: {content}\n"
return result
def db_message_to_str(message_dict: Dict) -> str:
print(f"message_dict: {message_dict}")
@@ -159,7 +140,7 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return []
# 转换为 Message对象列表
from .message import Message
from .message_cq import Message
message_objects = []
for msg_data in recent_messages:
try:

View File

@@ -2,7 +2,11 @@ import base64
import io
import os
import time
import zlib # 用于 CRC32
import zlib
import aiohttp
import hashlib
from typing import Optional, Tuple, Union
from urllib.parse import urlparse
from loguru import logger
from nonebot import get_driver
@@ -10,213 +14,348 @@ from PIL import Image
from ...common.database import Database
from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
"""
压缩base64格式的图片到指定大小单位KB并在数据库中记录图片信息
Args:
base64_data: base64编码的图片数据
max_size: 最大文件大小KB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保图片目录存在
images_dir = "data/images"
os.makedirs(images_dir, exist_ok=True)
# 连接数据库
db = Database(
host=config.mongodb_host,
port=int(config.mongodb_port),
db_name=config.database_name,
username=config.mongodb_username,
password=config.mongodb_password,
auth_source=config.mongodb_auth_source
)
# 检查是否已存在相同哈希值的图片
collection = db.db['images']
existing_image = collection.find_one({'hash': hash_value})
if existing_image:
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 如果是动图,直接返回原图
if getattr(img, 'is_animated', False):
return base64_data
# 计算当前大小KB
current_size = len(image_data) / 1024
# 如果已经小于目标大小,直接使用原图
if current_size <= max_size:
compressed_data = image_data
else:
# 压缩逻辑
# 先缩放到50%
new_width = int(img.width * 0.5)
new_height = int(img.height * 0.5)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 如果缩放后的最大边长仍然大于400继续缩放
max_dimension = 400
max_current = max(new_width, new_height)
if max_current > max_dimension:
ratio = max_dimension / max_current
new_width = int(new_width * ratio)
new_height = int(new_height * ratio)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 转换为RGB模式去除透明通道
if img.mode in ('RGBA', 'P'):
img = img.convert('RGB')
# 使用固定质量参数压缩
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
compressed_data = output.getvalue()
# 生成文件名(使用时间戳和哈希值确保唯一性)
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(images_dir, filename)
# 保存文件
with open(image_path, "wb") as f:
f.write(compressed_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
try:
# 准备数据库记录
image_record = {
'filename': filename,
'path': image_path,
'size': len(compressed_data) / 1024,
'timestamp': timestamp,
'width': img.width,
'height': img.height,
'description': '',
'tags': [],
'type': 'image',
'hash': hash_value
}
# 保存记录
collection.insert_one(image_record)
print("\033[1;32m[成功]\033[0m 保存图片记录到数据库")
except Exception as db_error:
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
# 将压缩后的数据转换为base64
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
return compressed_base64
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
import traceback
print(traceback.format_exc())
return base64_data
def storage_emoji(image_data: bytes) -> bytes:
"""
存储表情包到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
if not global_config.EMOJI_SAVE:
return image_data
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
emoji_dir = "data/emoji"
os.makedirs(emoji_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(emoji_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
emoji_path = os.path.join(emoji_dir, filename)
# 直接保存原始文件
with open(emoji_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
return image_data
class ImageManager:
_instance = None
IMAGE_DIR = "data" # 图像存储根目录
def storage_image(image_data: bytes) -> bytes:
"""
存储图片到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
image_dir = "data/image"
os.makedirs(image_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(image_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(image_dir, filename)
# 直接保存原始文件
with open(image_path, "wb") as f:
f.write(image_data)
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
return image_data
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
async def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
Args:
image_hash: 图片哈希值
description_type: 描述类型 ('emoji''image')
Returns:
Optional[str]: 描述文本如果不存在则返回None
"""
result = await self.db.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
return result['description'] if result else None
async def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
image_hash: 图片哈希值
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
await self.db.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
'description': description,
'timestamp': int(time.time())
}
},
upsert=True
)
async def save_image(self,
image_data: Union[str, bytes],
url: str = None,
description: str = None,
is_base64: bool = False) -> Optional[str]:
"""保存图像
Args:
image_data: 图像数据(base64字符串或字节)
url: 图像URL
description: 图像描述
is_base64: image_data是否为base64格式
Returns:
str: 保存后的文件路径,失败返回None
"""
try:
# 转换为字节格式
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return None
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return None
# 计算哈希值
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查重
existing = self.db.db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'url': url,
'description': description,
'timestamp': timestamp
}
self.db.db.images.insert_one(image_doc)
return file_path
except Exception as e:
logger.error(f"保存图像失败: {str(e)}")
return None
async def get_image_by_url(self, url: str) -> Optional[str]:
"""根据URL获取图像路径(带查重)
Args:
url: 图像URL
Returns:
str: 本地文件路径,不存在返回None
"""
try:
# 先查找是否已存在
existing = self.db.db.images.find_one({'url': url})
if existing:
return existing['path']
# 下载图像
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
image_bytes = await resp.read()
return await self.save_image(image_bytes, url=url)
return None
except Exception as e:
logger.error(f"获取图像失败: {str(e)}")
return None
async def get_base64_by_url(self, url: str) -> Optional[str]:
"""根据URL获取base64(带查重)
Args:
url: 图像URL
Returns:
str: base64字符串,失败返回None
"""
try:
image_path = await self.get_image_by_url(url)
if not image_path:
return None
with open(image_path, 'rb') as f:
image_bytes = f.read()
return base64.b64encode(image_bytes).decode('utf-8')
except Exception as e:
logger.error(f"获取base64失败: {str(e)}")
return None
async def save_base64_image(self, base64_str: str, description: str = None) -> Optional[str]:
"""保存base64图像(带查重)
Args:
base64_str: base64字符串
description: 图像描述
Returns:
str: 保存路径,失败返回None
"""
return await self.save_image(base64_str, description=description, is_base64=True)
def check_url_exists(self, url: str) -> bool:
"""检查URL是否已存在
Args:
url: 图像URL
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
Args:
image_data: 图像数据(base64或字节)
is_base64: 是否为base64格式
Returns:
bool: 是否存在
"""
try:
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return False
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
return False
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'emoji')
if cached_description:
return f"[表情包:{cached_description}]"
# 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"emoji_{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'emoji',
'description': description,
'timestamp': timestamp
}
await self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存表情包: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
# 保存描述到数据库
await self._save_description_to_db(image_hash, description, 'emoji')
return f"[表情包:{description}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'image')
if cached_description:
return f"[图片:{cached_description}]"
# 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"image_{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'image',
'description': description,
'timestamp': timestamp
}
await self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存图片: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
# 保存描述到数据库
await self._save_description_to_db(image_hash, description, 'image')
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
# 创建全局单例
image_manager = ImageManager()
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小