feat: 超大型重构

This commit is contained in:
tcmofashi
2025-03-09 10:27:54 +08:00
parent 8754571afb
commit fe3684736a
11 changed files with 1209 additions and 566 deletions

View File

@@ -7,10 +7,10 @@ from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from ..memory_system.memory import hippocampus from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器 from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config from .config import global_config
from .cq_code import CQCode # 导入CQCode模块 from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器 from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator from .llm_generator import ResponseGenerator
from .message import ( from .message_cq import (
Message, Message,
Message_Sending, Message_Sending,
Message_Thinking, # 导入 Message_Thinking 类 Message_Thinking, # 导入 Message_Thinking 类
@@ -180,7 +180,7 @@ class ChatBot:
if emoji_raw != None: if emoji_raw != None:
emoji_path,discription = emoji_raw emoji_path,discription = emoji_raw
emoji_cq = CQCode.create_emoji_cq(emoji_path) emoji_cq = cq_code_tool.create_emoji_cq(emoji_path)
if random() < 0.5: if random() < 0.5:
bot_response_time = tinking_time_point - 1 bot_response_time = tinking_time_point - 1

View File

@@ -3,7 +3,7 @@ import html
import os import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional from typing import Dict, Optional, List, Union
import requests import requests
@@ -12,12 +12,14 @@ import requests
import urllib3 import urllib3
from nonebot import get_driver from nonebot import get_driver
from urllib3.util import create_urllib3_context from urllib3.util import create_urllib3_context
from loguru import logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .mapper import emojimapper from .mapper import emojimapper
from .utils_image import storage_emoji, storage_image from .utils_image import image_manager
from .utils_user import get_user_nickname from .utils_user import get_user_nickname
from .message_base import Seg
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -48,16 +50,15 @@ class CQCode:
type: CQ码类型'image', 'at', 'face'等) type: CQ码类型'image', 'at', 'face'等)
params: CQ码的参数字典 params: CQ码的参数字典
raw_code: 原始CQ码字符串 raw_code: 原始CQ码字符串
translated_plain_text: 经过处理如AI翻译后的文本表示 translated_segments: 经过处理后的Seg对象列表
""" """
type: str type: str
params: Dict[str, str] params: Dict[str, str]
# raw_code: str
group_id: int group_id: int
user_id: int user_id: int
group_name: str = "" group_name: str = ""
user_nickname: str = "" user_nickname: str = ""
translated_plain_text: Optional[str] = None translated_segments: Optional[Union[Seg, List[Seg]]] = None
reply_message: Dict = None # 存储回复消息 reply_message: Dict = None # 存储回复消息
image_base64: Optional[str] = None image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None _llm: Optional[LLM_request] = None
@@ -66,31 +67,72 @@ class CQCode:
"""初始化LLM实例""" """初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
async def translate(self): def translate(self):
"""根据CQ码类型进行相应的翻译处理""" """根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == 'text': if self.type == 'text':
self.translated_plain_text = self.params.get('text', '') self.translated_segments = Seg(
type='text',
data=self.params.get('text', '')
)
elif self.type == 'image': elif self.type == 'image':
base64_data = self.translate_image()
if base64_data:
if self.params.get('sub_type') == '0': if self.params.get('sub_type') == '0':
self.translated_plain_text = await self.translate_image() self.translated_segments = Seg(
type='image',
data=base64_data
)
else: else:
self.translated_plain_text = await self.translate_emoji() self.translated_segments = Seg(
type='emoji',
data=base64_data
)
else:
self.translated_segments = Seg(
type='text',
data='[图片]'
)
elif self.type == 'at': elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', '')) user_nickname = get_user_nickname(self.params.get('qq', ''))
if user_nickname: self.translated_segments = Seg(
self.translated_plain_text = f"[@{user_nickname}]" type='text',
else: data=f"[@{user_nickname or '某人'}]"
self.translated_plain_text = "@某人" )
elif self.type == 'reply': elif self.type == 'reply':
self.translated_plain_text = await self.translate_reply() reply_segments = self.translate_reply()
if reply_segments:
self.translated_segments = Seg(
type='seglist',
data=reply_segments
)
else:
self.translated_segments = Seg(
type='text',
data='[回复某人消息]'
)
elif self.type == 'face': elif self.type == 'face':
face_id = self.params.get('id', '') face_id = self.params.get('id', '')
# self.translated_plain_text = f"[表情{face_id}]" self.translated_segments = Seg(
self.translated_plain_text = f"[{emojimapper.get(int(face_id), '表情')}]" type='text',
data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == 'forward': elif self.type == 'forward':
self.translated_plain_text = await self.translate_forward() forward_segments = self.translate_forward()
if forward_segments:
self.translated_segments = Seg(
type='seglist',
data=forward_segments
)
else: else:
self.translated_plain_text = f"[{self.type}]" self.translated_segments = Seg(
type='text',
data='[转发消息]'
)
else:
self.translated_segments = Seg(
type='text',
data=f"[{self.type}]"
)
def get_img(self): def get_img(self):
''' '''
@@ -160,155 +202,101 @@ class CQCode:
return None return None
async def translate_emoji(self) -> str:
"""处理表情包类型的CQ码""" def translate_image(self) -> Optional[str]:
"""处理图片类型的CQ码返回base64字符串"""
if 'url' not in self.params: if 'url' not in self.params:
return '[表情包]' return None
base64_str = self.get_img() return self.get_img()
if base64_str:
# 将 base64 字符串转换为字节类型
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
return await self.get_emoji_description(base64_str)
else:
return '[表情包]'
async def translate_image(self) -> str: def translate_forward(self) -> Optional[List[Seg]]:
"""处理图片类型的CQ码区分普通图片和表情包""" """处理转发消息返回Seg列表"""
# 没有url直接返回默认文本
if 'url' not in self.params:
return '[图片]'
base64_str = self.get_img()
if base64_str:
image_bytes = base64.b64decode(base64_str)
storage_image(image_bytes)
return await self.get_image_description(base64_str)
else:
return '[图片]'
async def get_emoji_description(self, image_base64: str) -> str:
"""调用AI接口获取表情包描述"""
try:
prompt = "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
"""调用AI接口获取普通图片描述"""
try:
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[图片]"
async def translate_forward(self) -> str:
"""处理转发消息"""
try: try:
if 'content' not in self.params: if 'content' not in self.params:
return '[转发消息]' return None
# 解析content内容需要先反转义
content = self.unescape(self.params['content']) content = self.unescape(self.params['content'])
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
# 将字符串形式的列表转换为Python对象
import ast import ast
try: try:
messages = ast.literal_eval(content) messages = ast.literal_eval(content)
except ValueError as e: except ValueError as e:
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") logger.error(f"解析转发消息内容失败: {str(e)}")
return '[转发消息]' return None
# 处理每条消息 formatted_segments = []
formatted_messages = []
for msg in messages: for msg in messages:
sender = msg.get('sender', {}) sender = msg.get('sender', {})
nickname = sender.get('card') or sender.get('nickname', '未知用户') nickname = sender.get('card') or sender.get('nickname', '未知用户')
# 获取消息内容并使用Message类处理
raw_message = msg.get('raw_message', '') raw_message = msg.get('raw_message', '')
message_array = msg.get('message', []) message_array = msg.get('message', [])
if message_array and isinstance(message_array, list): if message_array and isinstance(message_array, list):
# 检查是否包含嵌套的转发消息
for message_part in message_array: for message_part in message_array:
if message_part.get('type') == 'forward': if message_part.get('type') == 'forward':
content = '[转发消息]' content_seg = Seg(type='text', data='[转发消息]')
break break
else: else:
# 处理普通消息
if raw_message: if raw_message:
from .message import Message from .message_cq import MessageRecvCQ
message_obj = Message( message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0), user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0), message_id=msg.get('message_id', 0),
raw_message=raw_message, raw_message=raw_message,
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get('group_id', 0)
) )
await message_obj.initialize() content_seg = Seg(type='seglist', data=message_obj.message_segments)
content = message_obj.processed_plain_text
else: else:
content = '[空消息]' content_seg = Seg(type='text', data='[空消息]')
else: else:
# 处理普通消息
if raw_message: if raw_message:
from .message import Message from .message_cq import MessageRecvCQ
message_obj = Message( message_obj = MessageRecvCQ(
user_id=msg.get('user_id', 0), user_id=msg.get('user_id', 0),
message_id=msg.get('message_id', 0), message_id=msg.get('message_id', 0),
raw_message=raw_message, raw_message=raw_message,
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get('group_id', 0)
) )
await message_obj.initialize() content_seg = Seg(type='seglist', data=message_obj.message_segments)
content = message_obj.processed_plain_text
else: else:
content = '[空消息]' content_seg = Seg(type='text', data='[空消息]')
formatted_msg = f"{nickname}: {content}" formatted_segments.append(Seg(type='text', data=f"{nickname}: "))
formatted_messages.append(formatted_msg) formatted_segments.append(content_seg)
formatted_segments.append(Seg(type='text', data='\n'))
# 合并所有消息 return formatted_segments
combined_messages = '\n'.join(formatted_messages)
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") logger.error(f"处理转发消息失败: {str(e)}")
return '[转发消息]' return None
async def translate_reply(self) -> str: def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码""" """处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
# 创建Message对象 if self.reply_message is None:
from .message import Message return None
if self.reply_message == None:
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]'
if self.reply_message.sender.user_id: if self.reply_message.sender.user_id:
message_obj = Message( message_obj = MessageRecvCQ(
user_id=self.reply_message.sender.user_id, user_id=self.reply_message.sender.user_id,
message_id=self.reply_message.message_id, message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message), raw_message=str(self.reply_message.message),
group_id=self.group_id group_id=self.group_id
) )
await message_obj.initialize()
if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
else:
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
segments = []
if message_obj.user_id == global_config.BOT_QQ:
segments.append(Seg(type='text', data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "))
else: else:
print("\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空") segments.append(Seg(type='text', data=f"[回复 {self.reply_message.sender.nickname} 的消息: "))
return '[回复某人消息]'
segments.append(Seg(type='seglist', data=message_obj.message_segments))
segments.append(Seg(type='text', data="]"))
return segments
else:
return None
@staticmethod @staticmethod
def unescape(text: str) -> str: def unescape(text: str) -> str:
@@ -318,29 +306,12 @@ class CQCode:
.replace('&#93;', ']') \ .replace('&#93;', ']') \
.replace('&amp;', '&') .replace('&amp;', '&')
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
class CQCode_tool: class CQCode_tool:
@staticmethod @staticmethod
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
""" """
将CQ码字典转换为CQCode对象 将CQ码字典转换为CQCode对象
@@ -369,7 +340,7 @@ class CQCode_tool:
) )
# 进行翻译处理 # 进行翻译处理
await instance.translate() instance.translate()
return instance return instance
@staticmethod @staticmethod
@@ -383,5 +354,26 @@ class CQCode_tool:
""" """
return f"[CQ:reply,id={message_id}]" return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \
.replace(']', '&#93;') \
.replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
cq_code_tool = CQCode_tool() cq_code_tool = CQCode_tool()

View File

@@ -4,6 +4,8 @@ import random
import time import time
import traceback import traceback
from typing import Optional from typing import Optional
import base64
import hashlib
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
@@ -13,9 +15,11 @@ from ..chat.config import global_config
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64 from ..chat.utils_image import image_path_to_base64
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..chat.utils_image import ImageManager
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
image_manager = ImageManager()
class EmojiManager: class EmojiManager:
@@ -142,14 +146,14 @@ class EmojiManager:
emoji_similarities.sort(key=lambda x: x[1], reverse=True) emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包 # 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3] top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_3_emojis: if not top_10_emojis:
logger.warning("未找到匹配的表情包") logger.warning("未找到匹配的表情包")
return None return None
# 从前3个中随机选择一个 # 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis) selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and 'path' in selected_emoji: if selected_emoji and 'path' in selected_emoji:
# 更新使用次数 # 更新使用次数
@@ -172,13 +176,13 @@ class EmojiManager:
return None return None
async def _get_emoji_discription(self, image_base64: str) -> str: async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签""" """获取表情包的标签使用image_manager的描述生成功能"""
try: try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感' # 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) # 去掉[表情包xxx]的格式,只保留描述内容
logger.debug(f"输出描述: {content}") description = description.strip('[]').replace('表情包:', '')
return content return description
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
@@ -220,42 +224,94 @@ class EmojiManager:
for filename in files_to_process: for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename) image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过 # 获取图片的base64编码和哈希值
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji:
continue
# 压缩图片并获取base64编码
image_base64 = image_path_to_base64(image_path) image_base64 = image_path_to_base64(image_path)
if image_base64 is None: if image_base64 is None:
os.remove(image_path) os.remove(image_path)
continue continue
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = await image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
await image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
await image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步已存在的表情包到images集合: {filename}")
continue
# 检查是否在images集合中已有描述
existing_description = await image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description:
description = existing_description
else:
# 获取表情包的描述 # 获取表情包的描述
discription = await self._get_emoji_discription(image_base64) description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK: if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64) check = await self._check_emoji(image_base64)
if '' not in check: if '' not in check:
os.remove(image_path) os.remove(image_path)
logger.info(f"描述: {discription}") logger.info(f"描述: {description}")
logger.info(f"其不满足过滤规则,被剔除 {check}") logger.info(f"其不满足过滤规则,被剔除 {check}")
continue continue
logger.info(f"check通过 {check}") logger.info(f"check通过 {check}")
embedding = await get_embedding(discription)
if discription is not None: if description is not None:
embedding = await get_embedding(description)
# 准备数据库记录 # 准备数据库记录
emoji_record = { emoji_record = {
'filename': filename, 'filename': filename,
'path': image_path, 'path': image_path,
'embedding':embedding, 'embedding': embedding,
'discription': discription, 'discription': description,
'hash': image_hash,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
# 保存到数据库 # 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record) self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}") logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {discription}") logger.info(f"描述: {description}")
# 保存到images数据库
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
await image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
await image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}")
else: else:
logger.warning(f"跳过表情包: {filename}") logger.warning(f"跳过表情包: {filename}")

View File

@@ -7,7 +7,7 @@ from nonebot import get_driver
from ...common.database import Database from ...common.database import Database
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .message import Message from .message_cq import Message
from .prompt_builder import prompt_builder from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .utils import process_llm_response from .utils import process_llm_response

View File

@@ -1,14 +1,15 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional from typing import Dict, ForwardRef, List, Optional, Union
import urllib3 import urllib3
from loguru import logger
from .cq_code import CQCode, cq_code_tool from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .utils_image import image_manager
Message = ForwardRef('Message') # 添加这行 from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告 # 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -16,216 +17,344 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 #它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。 #它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageRecv(MessageBase):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: Dict):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
message_info = BaseMessageInfo(**message_dict.get('message_info', {}))
message_segment = Seg(**message_dict.get('message_segment', {}))
raw_message = message_dict.get('raw_message')
super().__init__(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)
# 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname!=''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass @dataclass
class Message: class MessageProcessBase(MessageBase):
"""消息数据类""" """消息处理基类,用于处理中和发送中的消息"""
message_id: int = None
time: float = None
group_id: int = None def __init__(
group_name: str = None # 群名称 self,
message_id: str,
user_id: int = None user_id: int,
user_nickname: str = None # 用户昵称 group_id: Optional[int] = None,
user_cardname: str = None # 用户群昵称 platform: str = "qq",
message_segment: Optional[Seg] = None,
raw_message: str = None # 原始消息包含未解析的cq码 reply: Optional['MessageRecv'] = None
plain_text: str = None # 纯文本 ):
# 构造用户信息
reply_message: Dict = None # 存储 回复的 源消息 user_info = UserInfo(
platform=platform,
# 延迟初始化字段 user_id=user_id,
_initialized: bool = False user_nickname=get_user_nickname(user_id),
message_segments: List[Dict] = None # 存储解析后的消息片段 user_cardname=get_user_cardname(user_id) if group_id else None
processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本
# 状态标志
is_emoji: bool = False
has_emoji: bool = False
translate_cq: bool = True
async def initialize(self):
"""显式异步初始化方法(必须调用)"""
if self._initialized:
return
# 异步获取补充信息
self.group_name = self.group_name or get_groupname(self.group_id)
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
# 消息解析
if self.raw_message:
if not isinstance(self,Message_Sending):
self.message_segments = await self.parse_message_segments(self.raw_message)
self.processed_plain_text = ' '.join(
seg.translated_plain_text
for seg in self.message_segments
) )
# 构建详细文本 # 构造群组信息(如果有)
if self.time is None: group_info = None
self.time = int(time.time()) if group_id:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time)) group_info = GroupInfo(
name = ( platform=platform,
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" group_id=group_id,
if self.user_cardname group_name=get_groupname(group_id)
else f"{self.user_nickname or f'用户{self.user_id}'}"
) )
if isinstance(self,Message_Sending) and self.is_emoji:
self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n"
else:
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True # 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
)
async def parse_message_segments(self, message: str) -> List[CQCode]: # 调用父类初始化
""" super().__init__(
将消息解析为片段列表包括纯文本和CQ码 message_info=message_info,
返回的列表中每个元素都是字典,包含: message_segment=message_segment,
- cq_code_list:分割出的聊天对象包括文本和CQ码 raw_message=None
- trans_list:翻译后的对象列表 )
"""
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = []
trans_list = []
start = 0 # 处理状态相关属性
while True:
# 查找下一个CQ码的开始位置
cq_start = message.find('[CQ:', start)
#如果没有cq码直接返回文本内容
if cq_start == -1:
# 如果没有找到更多CQ码添加剩余文本
if start < len(message):
text = message[start:].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
break
# 添加CQ码前的文本
if cq_start > start:
text = message[start:cq_start].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
# 查找CQ码的结束位置
cq_end = message.find(']', cq_start)
if cq_end == -1:
# CQ码未闭合作为普通文本处理
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
#将cq_code解析成字典
cq_code_dict_list.append(parse_cq_code(cq_code))
# 更新start位置到当前CQ码之后
start = cq_end + 1
# print(f"\033[1;34m[调试信息]\033[0m 提取的消息对象:列表: {cq_code_dict_list}")
#判定是否是表情包消息,以及是否含有表情包
if len(cq_code_dict_list) == 1 and cq_code_dict_list[0]['type'] == 'image':
self.is_emoji = True
self.has_emoji_emoji = True
else:
for segment in cq_code_dict_list:
if segment['type'] == 'image' and segment['data'].get('sub_type') == '1':
self.has_emoji_emoji = True
break
#翻译作为字典的CQ码
for _code_item in cq_code_dict_list:
message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
trans_list.append(message_obj)
return trans_list
class Message_Thinking:
"""消息思考类"""
def __init__(self, message: Message,message_id: str):
# 复制原始消息的基本属性
self.group_id = message.group_id
self.user_id = message.user_id
self.user_nickname = message.user_nickname
self.user_cardname = message.user_cardname
self.group_name = message.group_name
self.message_id = message_id
# 思考状态相关属性
self.thinking_start_time = int(time.time()) self.thinking_start_time = int(time.time())
self.thinking_time = 0 self.thinking_time = 0
self.interupt=False
def update_thinking_time(self): # 文本处理相关属性
self.thinking_time = round(time.time(), 2) - self.thinking_start_time self.processed_plain_text = ""
self.detailed_plain_text = ""
# 回复消息
self.reply = reply
@dataclass def update_thinking_time(self) -> float:
class Message_Sending(Message): """更新思考时间"""
"""发送中的消息类""" self.thinking_time = round(time.time() - self.thinking_start_time, 2)
thinking_start_time: float = None # 思考开始时间
thinking_time: float = None # 思考时间
reply_message_id: int = None # 存储 回复的 源消息ID
is_head: bool = False # 是否是头部消息
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
return self.thinking_time return self.thinking_time
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
elif seg.type == 'at':
return f"[@{seg.data}]"
elif seg.type == 'reply':
if self.reply and hasattr(self.reply, 'processed_plain_text'):
return f"[回复:{self.reply.processed_plain_text}]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
"""思考状态的消息类"""
def __init__(
self,
message_id: str,
user_id: int,
group_id: Optional[int] = None,
platform: str = "qq",
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
user_id=user_id,
group_id=group_id,
platform=platform,
message_segment=None, # 思考状态不需要消息段
reply=reply
)
# 思考状态特有属性
self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
user_id: int,
message_segment: Seg,
group_id: Optional[int] = None,
reply: Optional['MessageRecv'] = None,
platform: str = "qq",
is_head: bool = False
):
# 调用父类初始化
super().__init__(
message_id=message_id,
user_id=user_id,
group_id=group_id,
platform=platform,
message_segment=message_segment,
reply=reply
)
# 发送状态特有属性
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
@classmethod
def from_thinking(
cls,
thinking: MessageThinking,
message_segment: Seg,
reply: Optional['MessageRecv'] = None,
is_head: bool = False
) -> 'MessageSending':
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
user_id=thinking.message_info.user_info.user_id,
message_segment=message_segment,
group_id=thinking.message_info.group_info.group_id if thinking.message_info.group_info else None,
reply=reply or thinking.reply,
platform=thinking.message_info.platform,
is_head=is_head
)
@dataclass
class MessageSet: class MessageSet:
"""消息集合类,可以存储多个发送消息""" """消息集合类,可以存储多个发送消息"""
def __init__(self, group_id: int, user_id: int, message_id: str): def __init__(self, group_id: int, user_id: int, message_id: str):
self.group_id = group_id self.group_id = group_id
self.user_id = user_id self.user_id = user_id
self.message_id = message_id self.message_id = message_id
self.messages: List[Message_Sending] = [] # 修改类型标注 self.messages: List[MessageSending] = []
self.time = round(time.time(), 2) self.time = round(time.time(), 2)
def add_message(self, message: Message_Sending) -> None: def add_message(self, message: MessageSending) -> None:
"""添加消息到集合只接受Message_Sending类型""" """添加消息到集合"""
if not isinstance(message, Message_Sending): if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加Message_Sending类型的消息") raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message) self.messages.append(message)
# 按时间排序 self.messages.sort(key=lambda x: x.message_info.time)
self.messages.sort(key=lambda x: x.time)
def get_message_by_index(self, index: int) -> Optional[Message_Sending]: def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息""" """通过索引获取消息"""
if 0 <= index < len(self.messages): if 0 <= index < len(self.messages):
return self.messages[index] return self.messages[index]
return None return None
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]: def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息""" """获取最接近指定时间的消息"""
if not self.messages: if not self.messages:
return None return None
# 使用二分查找找到最接近的消息
left, right = 0, len(self.messages) - 1 left, right = 0, len(self.messages) - 1
while left < right: while left < right:
mid = (left + right) // 2 mid = (left + right) // 2
if self.messages[mid].time < target_time: if self.messages[mid].message_info.time < target_time:
left = mid + 1 left = mid + 1
else: else:
right = mid right = mid
return self.messages[left] return self.messages[left]
def clear_messages(self) -> None: def clear_messages(self) -> None:
"""清空所有消息""" """清空所有消息"""
self.messages.clear() self.messages.clear()
def remove_message(self, message: Message_Sending) -> bool: def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息""" """移除指定消息"""
if message in self.messages: if message in self.messages:
self.messages.remove(message) self.messages.remove(message)

View File

@@ -0,0 +1,158 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any, Dict
@dataclass
class Seg(dict):
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型,可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data 是 base64 字符串
- 对于 seglist 类型data 是 Seg 列表
translated_data: 经过翻译处理的数据(可选)
"""
type: str
data: Union[str, List['Seg']]
translated_data: Optional[str] = None
def __init__(self, type: str, data: Union[str, List['Seg']], translated_data: Optional[str] = None):
"""初始化实例,确保字典和属性同步"""
# 先初始化字典
super().__init__(type=type, data=data)
if translated_data is not None:
self['translated_data'] = translated_data
# 再初始化属性
object.__setattr__(self, 'type', type)
object.__setattr__(self, 'data', data)
object.__setattr__(self, 'translated_data', translated_data)
# 验证数据类型
self._validate_data()
def _validate_data(self) -> None:
"""验证数据类型的正确性"""
if self.type == 'seglist' and not isinstance(self.data, list):
raise ValueError("seglist类型的data必须是列表")
elif self.type == 'text' and not isinstance(self.data, str):
raise ValueError("text类型的data必须是字符串")
elif self.type == 'image' and not isinstance(self.data, str):
raise ValueError("image类型的data必须是字符串")
def __setattr__(self, name: str, value: Any) -> None:
"""重写属性设置,同时更新字典值"""
# 更新属性
object.__setattr__(self, name, value)
# 同步更新字典
if name in ['type', 'data', 'translated_data']:
self[name] = value
def __setitem__(self, key: str, value: Any) -> None:
"""重写字典值设置,同时更新属性"""
# 更新字典
super().__setitem__(key, value)
# 同步更新属性
if key in ['type', 'data', 'translated_data']:
object.__setattr__(self, key, value)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {'type': self.type}
if self.type == 'seglist':
result['data'] = [seg.to_dict() for seg in self.data]
else:
result['data'] = self.data
if self.translated_data is not None:
result['translated_data'] = self.translated_data
return result
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Optional[int,str] = None
time: Optional[int] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {
'message_info': self.message_info.to_dict(),
'message_segment': self.message_segment.to_dict()
}
if self.raw_message is not None:
result['raw_message'] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> 'MessageBase':
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message')
return cls(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)

View File

@@ -0,0 +1,188 @@
import time
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional, Union
import urllib3
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
"""QQ消息基类继承自MessageBase
最小必要参数:
- message_id: 消息ID
- user_id: 发送者/接收者ID
- platform: 平台标识(默认为"qq"
"""
def __init__(
self,
message_id: int,
user_id: int,
group_id: Optional[int] = None,
platform: str = "qq"
):
# 构造用户信息
user_info = UserInfo(
platform=platform,
user_id=user_id,
user_nickname=get_user_nickname(user_id),
user_cardname=get_user_cardname(user_id) if group_id else None
)
# 构造群组信息(如果有)
group_info = None
if group_id:
group_info = GroupInfo(
platform=platform,
group_id=group_id,
group_name=get_groupname(group_id)
)
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(
message_info=message_info,
message_segment=None,
raw_message=None
)
@dataclass
class MessageRecvCQ(MessageCQ):
"""QQ接收消息类用于解析raw_message到Seg对象"""
def __init__(
self,
message_id: int,
user_id: int,
raw_message: str,
group_id: Optional[int] = None,
reply_message: Optional[Dict] = None,
platform: str = "qq"
):
# 调用父类初始化
super().__init__(message_id, user_id, group_id, platform)
# 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message)
self.raw_message = raw_message
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""解析消息内容为Seg对象"""
cq_code_dict_list = []
segments = []
start = 0
while True:
cq_start = message.find('[CQ:', start)
if cq_start == -1:
if start < len(message):
text = message[start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
if cq_start > start:
text = message[start:cq_start].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find(']', cq_start)
if cq_end == -1:
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item, reply=reply_message)
if message_obj.translated_segments:
segments.append(message_obj.translated_segments)
# 如果只有一个segment直接返回
if len(segments) == 1:
return segments[0]
# 否则返回seglist类型的Seg
return Seg(type='seglist', data=segments)
def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict()
return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__(
self,
message_id: int,
user_id: int,
message_segment: Seg,
group_id: Optional[int] = None,
reply_to_message_id: Optional[int] = None,
platform: str = "qq"
):
# 调用父类初始化
super().__init__(message_id, user_id, group_id, platform)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message(reply_to_message_id)
def _generate_raw_message(self, reply_to_message_id: Optional[int] = None) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 添加回复消息
if reply_to_message_id:
segments.append(cq_code_tool.create_reply_cq(reply_to_message_id))
# 处理消息段
if self.message_segment.type == 'seglist':
for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg))
else:
segments.append(self._seg_to_cq_code(self.message_segment))
return ''.join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串"""
if seg.type == 'text':
return str(seg.data)
elif seg.type == 'image':
# 如果是base64图片数据
if seg.data.startswith(('data:', 'base64:')):
return f"[CQ:image,file=base64://{seg.data}]"
# 如果是表情包(本地文件)
return cq_code_tool.create_emoji_cq(seg.data)
elif seg.type == 'at':
return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply':
return cq_code_tool.create_reply_cq(int(seg.data))
else:
return f"[{seg.data}]"

View File

@@ -5,7 +5,7 @@ from typing import Dict, List, Optional, Union
from nonebot.adapters.onebot.v11 import Bot from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool from .cq_code import cq_code_tool
from .message import Message, Message_Sending, Message_Thinking, MessageSet from .message_cq import Message, Message_Sending, Message_Thinking, MessageSet
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time from .utils import calculate_typing_time
from .config import global_config from .config import global_config

View File

@@ -1,7 +1,7 @@
from typing import Optional from typing import Optional
from ...common.database import Database from ...common.database import Database
from .message import Message from .message_cq import Message
class MessageStorage: class MessageStorage:

View File

@@ -11,31 +11,12 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config from .config import global_config
from .message import Message from .message_cq import Message
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
def combine_messages(messages: List[Message]) -> str:
"""将消息列表组合成格式化的字符串
Args:
messages: Message对象列表
Returns:
str: 格式化后的消息字符串
"""
result = ""
for message in messages:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
name = message.user_nickname or f"用户{message.user_id}"
content = message.processed_plain_text or message.plain_text
result += f"[{time_str}] {name}: {content}\n"
return result
def db_message_to_str(message_dict: Dict) -> str: def db_message_to_str(message_dict: Dict) -> str:
print(f"message_dict: {message_dict}") print(f"message_dict: {message_dict}")
@@ -159,7 +140,7 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return [] return []
# 转换为 Message对象列表 # 转换为 Message对象列表
from .message import Message from .message_cq import Message
message_objects = [] message_objects = []
for msg_data in recent_messages: for msg_data in recent_messages:
try: try:

View File

@@ -2,7 +2,11 @@ import base64
import io import io
import os import os
import time import time
import zlib # 用于 CRC32 import zlib
import aiohttp
import hashlib
from typing import Optional, Tuple, Union
from urllib.parse import urlparse
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
@@ -10,213 +14,348 @@ from PIL import Image
from ...common.database import Database from ...common.database import Database
from ..chat.config import global_config from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
class ImageManager:
_instance = None
IMAGE_DIR = "data" # 图像存储根目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
async def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
"""
压缩base64格式的图片到指定大小单位KB并在数据库中记录图片信息
Args: Args:
base64_data: base64编码的图片数据 image_hash: 图片哈希值
max_size: 最大文件大小KB description_type: 描述类型 ('emoji''image')
Returns: Returns:
str: 压缩后的base64图片数据 Optional[str]: 描述文本如果不存在则返回None
""" """
try: result = await self.db.db.image_descriptions.find_one({
# 将base64转换为字节数据 'hash': image_hash,
image_data = base64.b64decode(base64_data) 'type': description_type
})
return result['description'] if result else None
# 使用 CRC32 计算哈希值 async def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') """保存图片描述到数据库
# 确保图片目录存在 Args:
images_dir = "data/images" image_hash: 图片哈希值
os.makedirs(images_dir, exist_ok=True) description: 描述文本
description_type: 描述类型 ('emoji''image')
# 连接数据库 """
db = Database( await self.db.db.image_descriptions.update_one(
host=config.mongodb_host, {'hash': image_hash, 'type': description_type},
port=int(config.mongodb_port), {
db_name=config.database_name, '$set': {
username=config.mongodb_username, 'description': description,
password=config.mongodb_password, 'timestamp': int(time.time())
auth_source=config.mongodb_auth_source }
},
upsert=True
) )
# 检查是否已存在相同哈希值的图片 async def save_image(self,
collection = db.db['images'] image_data: Union[str, bytes],
existing_image = collection.find_one({'hash': hash_value}) url: str = None,
description: str = None,
if existing_image: is_base64: bool = False) -> Optional[str]:
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}") """保存图像
return base64_data Args:
image_data: 图像数据(base64字符串或字节)
# 将字节数据转换为图片对象 url: 图像URL
img = Image.open(io.BytesIO(image_data)) description: 图像描述
is_base64: image_data是否为base64格式
# 如果是动图,直接返回原图 Returns:
if getattr(img, 'is_animated', False): str: 保存后的文件路径,失败返回None
return base64_data """
try:
# 计算当前大小KB # 转换为字节格式
current_size = len(image_data) / 1024 if is_base64:
if isinstance(image_data, str):
# 如果已经小于目标大小,直接使用原图 image_bytes = base64.b64decode(image_data)
if current_size <= max_size:
compressed_data = image_data
else: else:
# 压缩逻辑 return None
# 先缩放到50% else:
new_width = int(img.width * 0.5) if isinstance(image_data, bytes):
new_height = int(img.height * 0.5) image_bytes = image_data
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) else:
return None
# 如果缩放后的最大边长仍然大于400继续缩放 # 计算哈希值
max_dimension = 400 image_hash = hashlib.md5(image_bytes).hexdigest()
max_current = max(new_width, new_height)
if max_current > max_dimension:
ratio = max_dimension / max_current
new_width = int(new_width * ratio)
new_height = int(new_height * ratio)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 转换为RGB模式去除透明通道 # 查重
if img.mode in ('RGBA', 'P'): existing = self.db.db.images.find_one({'hash': image_hash})
img = img.convert('RGB') if existing:
return existing['path']
# 使用固定质量参数压缩 # 生成文件名和路径
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
compressed_data = output.getvalue()
# 生成文件名(使用时间戳和哈希值确保唯一性)
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg" filename = f"{timestamp}_{image_hash[:8]}.jpg"
image_path = os.path.join(images_dir, filename) file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件 # 保存文件
with open(image_path, "wb") as f: with open(file_path, "wb") as f:
f.write(compressed_data) f.write(image_bytes)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}") # 保存到数据库
image_doc = {
try: 'hash': image_hash,
# 准备数据库记录 'path': file_path,
image_record = { 'url': url,
'filename': filename, 'description': description,
'path': image_path, 'timestamp': timestamp
'size': len(compressed_data) / 1024,
'timestamp': timestamp,
'width': img.width,
'height': img.height,
'description': '',
'tags': [],
'type': 'image',
'hash': hash_value
} }
self.db.db.images.insert_one(image_doc)
# 保存记录 return file_path
collection.insert_one(image_record)
print("\033[1;32m[成功]\033[0m 保存图片记录到数据库")
except Exception as db_error:
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
# 将压缩后的数据转换为base64
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
return compressed_base64
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}") logger.error(f"保存图像失败: {str(e)}")
import traceback return None
print(traceback.format_exc())
return base64_data
def storage_emoji(image_data: bytes) -> bytes: async def get_image_by_url(self, url: str) -> Optional[str]:
""" """根据URL获取图像路径(带查重)
存储表情包到本地文件夹
Args: Args:
image_data: 图片字节数据 url: 图像URL
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns: Returns:
bytes: 原始图片数据 str: 本地文件路径,不存在返回None
"""
if not global_config.EMOJI_SAVE:
return image_data
try:
# 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保表情包目录存在
emoji_dir = "data/emoji"
os.makedirs(emoji_dir, exist_ok=True)
# 检查是否已存在相同哈希值的文件
for filename in os.listdir(emoji_dir):
if hash_value in filename:
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
emoji_path = os.path.join(emoji_dir, filename)
# 直接保存原始文件
with open(emoji_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
return image_data
def storage_image(image_data: bytes) -> bytes:
"""
存储图片到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
""" """
try: try:
# 使用 CRC32 计算哈希值 # 先查找是否已存在
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') existing = self.db.db.images.find_one({'url': url})
if existing:
return existing['path']
# 确保表情包目录存在 # 下载图像
image_dir = "data/image" async with aiohttp.ClientSession() as session:
os.makedirs(image_dir, exist_ok=True) async with session.get(url) as resp:
if resp.status == 200:
# 检查是否已存在相同哈希值的文件 image_bytes = await resp.read()
for filename in os.listdir(image_dir): return await self.save_image(image_bytes, url=url)
if hash_value in filename: return None
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名
timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(image_dir, filename)
# 直接保存原始文件
with open(image_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
return image_data
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}") logger.error(f"获取图像失败: {str(e)}")
return image_data return None
async def get_base64_by_url(self, url: str) -> Optional[str]:
"""根据URL获取base64(带查重)
Args:
url: 图像URL
Returns:
str: base64字符串,失败返回None
"""
try:
image_path = await self.get_image_by_url(url)
if not image_path:
return None
with open(image_path, 'rb') as f:
image_bytes = f.read()
return base64.b64encode(image_bytes).decode('utf-8')
except Exception as e:
logger.error(f"获取base64失败: {str(e)}")
return None
async def save_base64_image(self, base64_str: str, description: str = None) -> Optional[str]:
"""保存base64图像(带查重)
Args:
base64_str: base64字符串
description: 图像描述
Returns:
str: 保存路径,失败返回None
"""
return await self.save_image(base64_str, description=description, is_base64=True)
def check_url_exists(self, url: str) -> bool:
"""检查URL是否已存在
Args:
url: 图像URL
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
Args:
image_data: 图像数据(base64或字节)
is_base64: 是否为base64格式
Returns:
bool: 是否存在
"""
try:
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return False
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
return False
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'emoji')
if cached_description:
return f"[表情包:{cached_description}]"
# 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"emoji_{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'emoji',
'description': description,
'timestamp': timestamp
}
await self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存表情包: {file_path}")
except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}")
# 保存描述到数据库
await self._save_description_to_db(image_hash, description, 'emoji')
return f"[表情包:{description}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = await self._get_description_from_db(image_hash, 'image')
if cached_description:
return f"[图片:{cached_description}]"
# 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"image_{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'image',
'description': description,
'timestamp': timestamp
}
await self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存图片: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
# 保存描述到数据库
await self._save_description_to_db(image_hash, description, 'image')
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
# 创建全局单例
image_manager = ImageManager()
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小 """压缩base64格式的图片到指定大小