diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 0bffaed19..a62343d0c 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -18,6 +18,7 @@ from .config import global_config from .emoji_manager import emoji_manager from .relationship_manager import relationship_manager from .willing_manager import willing_manager +from .chat_stream import chat_manager # 创建LLM统计实例 llm_stats = LLMStatistics("llm_statistics.txt") @@ -101,6 +102,8 @@ async def _(bot: Bot): asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) print("\033[1;38;5;208m-----------开始偷表情包!-----------\033[0m") + asyncio.create_task(chat_manager._initialize()) + asyncio.create_task(chat_manager._auto_save_task()) @group_msg.handle() async def _(bot: Bot, event: GroupMessageEvent, state: T_State): diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index a2fdab873..a5f4ac476 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -18,7 +18,7 @@ from .chat_stream import chat_manager from .message_sender import message_manager # 导入新的消息管理器 from .relationship_manager import relationship_manager from .storage import MessageStorage -from .utils import calculate_typing_time, is_mentioned_bot_in_txt +from .utils import calculate_typing_time, is_mentioned_bot_in_message from .utils_image import image_path_to_base64 from .willing_manager import willing_manager # 导入意愿管理器 from .message_base import UserInfo, GroupInfo, Seg @@ -45,8 +45,8 @@ class ChatBot: self.bot = bot # 更新 bot 实例 - group_info = await bot.get_group_info(group_id=event.group_id) - sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) + # group_info = await bot.get_group_info(group_id=event.group_id) + # sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True) # 白名单设定由nontbot侧完成 if event.group_id: @@ -54,19 +54,32 @@ class ChatBot: return if event.user_id in global_config.ban_user_id: return + + user_info=UserInfo( + user_id=event.user_id, + user_nickname=event.sender.nickname, + user_cardname=event.sender.card or None, + platform='qq' + ) + + group_info=GroupInfo( + group_id=event.group_id, + group_name=None, + platform='qq' + ) message_cq=MessageRecvCQ( message_id=event.message_id, - user_id=event.user_id, + user_info=user_info, raw_message=str(event.original_message), - group_id=event.group_id, + group_info=group_info, reply_message=event.reply, platform='qq' ) message_json=message_cq.to_dict() # 进入maimbot - message=MessageRecv(**message_json) + message=MessageRecv(message_json) groupinfo=message.message_info.group_info userinfo=message.message_info.user_info @@ -75,6 +88,7 @@ class ChatBot: # 消息过滤,涉及到config有待更新 chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo) + message.update_chat_stream(chat) await relationship_manager.update_relationship(chat_stream=chat,) await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5) @@ -99,7 +113,7 @@ class ChatBot: await self.storage.store_message(message,chat, topic[0] if topic else None) - is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text) + is_mentioned = is_mentioned_bot_in_message(message) reply_probability = await willing_manager.change_reply_willing_received( chat_stream=chat, topic=topic[0] if topic else None, diff --git a/src/plugins/chat/chat_stream.py b/src/plugins/chat/chat_stream.py index 36c97bed0..bee679173 100644 --- a/src/plugins/chat/chat_stream.py +++ b/src/plugins/chat/chat_stream.py @@ -1,6 +1,7 @@ import asyncio import hashlib import time +import copy from typing import Dict, Optional from loguru import logger @@ -86,9 +87,9 @@ class ChatManager: self._ensure_collection() self._initialized = True # 在事件循环中启动初始化 - asyncio.create_task(self._initialize()) - # 启动自动保存任务 - asyncio.create_task(self._auto_save_task()) + # asyncio.create_task(self._initialize()) + # # 启动自动保存任务 + # asyncio.create_task(self._auto_save_task()) async def _initialize(self): """异步初始化""" @@ -122,12 +123,18 @@ class ChatManager: self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None ) -> str: """生成聊天流唯一ID""" - # 组合关键信息 - components = [ - platform, - str(user_info.user_id), - str(group_info.group_id) if group_info else "private", - ] + if group_info: + # 组合关键信息 + components = [ + platform, + str(group_info.group_id) + ] + else: + components = [ + platform, + str(user_info.user_id), + "private" + ] # 使用MD5生成唯一ID key = "_".join(components) @@ -153,10 +160,11 @@ class ChatManager: if stream_id in self.streams: stream = self.streams[stream_id] # 更新用户信息和群组信息 + stream.update_active_time() + stream=copy.deepcopy(stream) stream.user_info = user_info if group_info: stream.group_info = group_info - stream.update_active_time() return stream # 检查数据库中是否存在 @@ -180,7 +188,7 @@ class ChatManager: # 保存到内存和数据库 self.streams[stream_id] = stream await self._save_stream(stream) - return stream + return copy.deepcopy(stream) def get_stream(self, stream_id: str) -> Optional[ChatStream]: """通过stream_id获取聊天流""" diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 7581f8a33..6030b893f 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -59,6 +59,7 @@ class CQCode: params: Dict[str, str] group_id: int user_id: int + user_nickname: str group_name: str = "" user_nickname: str = "" translated_segments: Optional[Union[Seg, List[Seg]]] = None @@ -68,9 +69,7 @@ class CQCode: def __post_init__(self): """初始化LLM实例""" - self._llm = LLM_request( - model=global_config.vlm, temperature=0.4, max_tokens=300 - ) + pass def translate(self): """根据CQ码类型进行相应的翻译处理,转换为Seg对象""" @@ -225,8 +224,7 @@ class CQCode: group_id=msg.get("group_id", 0), ) content_seg = Seg( - type="seglist", data=message_obj.message_segments - ) + type="seglist", data=message_obj.message_segment ) else: content_seg = Seg(type="text", data="[空消息]") else: @@ -241,7 +239,7 @@ class CQCode: group_id=msg.get("group_id", 0), ) content_seg = Seg( - type="seglist", data=message_obj.message_segments + type="seglist", data=message_obj.message_segment ) else: content_seg = Seg(type="text", data="[空消息]") @@ -272,7 +270,7 @@ class CQCode: ) segments = [] - if message_obj.user_id == global_config.BOT_QQ: + if message_obj.message_info.user_info.user_id == global_config.BOT_QQ: segments.append( Seg( type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: " @@ -286,7 +284,7 @@ class CQCode: ) ) - segments.append(Seg(type="seglist", data=message_obj.message_segments)) + segments.append(Seg(type="seglist", data=message_obj.message_segment)) segments.append(Seg(type="text", data="]")) return segments else: @@ -305,12 +303,13 @@ class CQCode: class CQCode_tool: @staticmethod - def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: + def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode: """ 将CQ码字典转换为CQCode对象 Args: cq_code: CQ码字典 + msg: MessageCQ对象 reply: 回复消息的字典(可选) Returns: @@ -326,7 +325,13 @@ class CQCode_tool: params = cq_code.get("data", {}) instance = CQCode( - type=cq_type, params=params, group_id=0, user_id=0, reply_message=reply + type=cq_type, + params=params, + group_id=msg.message_info.group_info.group_id, + user_id=msg.message_info.user_info.user_id, + user_nickname=msg.message_info.user_info.user_nickname, + group_name=msg.message_info.group_info.group_name, + reply_message=reply ) # 进行翻译处理 @@ -383,6 +388,25 @@ class CQCode_tool: ) # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]" + + @staticmethod + def create_image_cq_base64(base64_data: str) -> str: + """ + 创建表情包CQ码 + Args: + base64_data: base64编码的表情包数据 + Returns: + 表情包CQ码字符串 + """ + # 转义base64数据 + escaped_base64 = ( + base64_data.replace("&", "&") + .replace("[", "[") + .replace("]", "]") + .replace(",", ",") + ) + # 生成CQ码,设置sub_type=1表示这是表情包 + return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]" cq_code_tool = CQCode_tool() diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 837ee245d..f3728ce92 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -239,7 +239,7 @@ class EmojiManager: # 即使表情包已存在,也检查是否需要同步到images集合 description = existing_emoji.get('discription') # 检查是否在images集合中存在 - existing_image = await image_manager.db.db.images.find_one({'hash': image_hash}) + existing_image = image_manager.db.db.images.find_one({'hash': image_hash}) if not existing_image: # 同步到images集合 image_doc = { @@ -249,7 +249,7 @@ class EmojiManager: 'description': description, 'timestamp': int(time.time()) } - await image_manager.db.db.images.update_one( + image_manager.db.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -260,7 +260,7 @@ class EmojiManager: continue # 检查是否在images集合中已有描述 - existing_description = await image_manager._get_description_from_db(image_hash, 'emoji') + existing_description = image_manager._get_description_from_db(image_hash, 'emoji') if existing_description: description = existing_description @@ -302,13 +302,13 @@ class EmojiManager: 'description': description, 'timestamp': int(time.time()) } - await image_manager.db.db.images.update_one( + image_manager.db.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True ) # 保存描述到image_descriptions集合 - await image_manager._save_description_to_db(image_hash, description, 'emoji') + image_manager._save_description_to_db(image_hash, description, 'emoji') logger.success(f"同步保存到images集合: {filename}") else: logger.warning(f"跳过表情包: {filename}") diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index dc019038e..bfd5eec2e 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -7,7 +7,7 @@ from nonebot import get_driver from ...common.database import Database from ..models.utils_model import LLM_request from .config import global_config -from .message import MessageRecv, MessageThinking, MessageSending +from .message import MessageRecv, MessageThinking, MessageSending,Message from .prompt_builder import prompt_builder from .relationship_manager import relationship_manager from .utils import process_llm_response @@ -144,7 +144,7 @@ class ResponseGenerator: # content: str, content_check: str, reasoning_content: str, reasoning_content_check: str): def _save_to_db( self, - message: Message, + message: MessageRecv, sender_name: str, prompt: str, prompt_check: str, @@ -155,7 +155,7 @@ class ResponseGenerator: self.db.db.reasoning_logs.insert_one( { "time": time.time(), - "group_id": message.group_id, + "chat_id": message.chat_stream.stream_id, "user": sender_name, "message": message.processed_plain_text, "model": self.current_model_type, diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 408937fad..5eb93d700 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -7,7 +7,7 @@ from loguru import logger from .utils_image import image_manager from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase -from .chat_stream import ChatStream +from .chat_stream import ChatStream, chat_manager # 禁用SSL警告 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) @@ -25,8 +25,8 @@ class MessageRecv(MessageBase): Args: message_dict: MessageCQ序列化后的字典 """ - message_info = BaseMessageInfo(**message_dict.get('message_info', {})) - message_segment = Seg(**message_dict.get('message_segment', {})) + message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {})) + message_segment = Seg.from_dict(message_dict.get('message_segment', {})) raw_message = message_dict.get('raw_message') super().__init__( @@ -39,7 +39,9 @@ class MessageRecv(MessageBase): self.processed_plain_text = "" # 初始化为空字符串 self.detailed_plain_text = "" # 初始化为空字符串 self.is_emoji=False - + def update_chat_stream(self,chat_stream:ChatStream): + self.chat_stream=chat_stream + async def process(self) -> None: """处理消息内容,生成纯文本和详细文本 @@ -83,12 +85,12 @@ class MessageRecv(MessageBase): return seg.data elif seg.type == 'image': # 如果是base64图片数据 - if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): + if isinstance(seg.data, str): return await image_manager.get_image_description(seg.data) return '[图片]' elif seg.type == 'emoji': self.is_emoji=True - if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): + if isinstance(seg.data, str): return await image_manager.get_emoji_description(seg.data) return '[表情]' else: @@ -217,11 +219,11 @@ class MessageProcessBase(Message): return seg.data elif seg.type == 'image': # 如果是base64图片数据 - if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): + if isinstance(seg.data, str): return await image_manager.get_image_description(seg.data) return '[图片]' elif seg.type == 'emoji': - if isinstance(seg.data, str) and seg.data.startswith(('data:', 'base64:')): + if isinstance(seg.data, str): return await image_manager.get_emoji_description(seg.data) return '[表情]' elif seg.type == 'at': @@ -296,10 +298,15 @@ class MessageSending(MessageProcessBase): self.reply_to_message_id = reply.message_info.message_id if reply else None self.is_head = is_head self.is_emoji = is_emoji - if is_head: + + def set_reply(self, reply: Optional['MessageRecv']) -> None: + """设置回复消息""" + if reply: + self.reply = reply + self.reply_to_message_id = self.reply.message_info.message_id self.message_segment = Seg(type='seglist', data=[ - Seg(type='reply', data=reply.message_info.message_id), - self.message_segment + Seg(type='reply', data=reply.message_info.message_id), + self.message_segment ]) async def process(self) -> None: @@ -329,7 +336,7 @@ class MessageSending(MessageProcessBase): def to_dict(self): ret= super().to_dict() - ret['mesage_info']['user_info']=self.chat_stream.user_info.to_dict() + ret['message_info']['user_info']=self.chat_stream.user_info.to_dict() return ret @dataclass diff --git a/src/plugins/chat/message_base.py b/src/plugins/chat/message_base.py index 7b76403de..d17c2c357 100644 --- a/src/plugins/chat/message_base.py +++ b/src/plugins/chat/message_base.py @@ -2,7 +2,7 @@ from dataclasses import dataclass, asdict from typing import List, Optional, Union, Any, Dict @dataclass -class Seg(dict): +class Seg: """消息片段类,用于表示消息的不同部分 Attributes: @@ -15,47 +15,25 @@ class Seg(dict): """ type: str data: Union[str, List['Seg']] - translated_data: Optional[str] = None + - def __init__(self, type: str, data: Union[str, List['Seg']], translated_data: Optional[str] = None): - """初始化实例,确保字典和属性同步""" - # 先初始化字典 - super().__init__(type=type, data=data) - if translated_data is not None: - self['translated_data'] = translated_data - - # 再初始化属性 - object.__setattr__(self, 'type', type) - object.__setattr__(self, 'data', data) - object.__setattr__(self, 'translated_data', translated_data) + # def __init__(self, type: str, data: Union[str, List['Seg']],): + # """初始化实例,确保字典和属性同步""" + # # 先初始化字典 + # self.type = type + # self.data = data - # 验证数据类型 - self._validate_data() - - def _validate_data(self) -> None: - """验证数据类型的正确性""" - if self.type == 'seglist' and not isinstance(self.data, list): - raise ValueError("seglist类型的data必须是列表") - elif self.type == 'text' and not isinstance(self.data, str): - raise ValueError("text类型的data必须是字符串") - elif self.type == 'image' and not isinstance(self.data, str): - raise ValueError("image类型的data必须是字符串") - - def __setattr__(self, name: str, value: Any) -> None: - """重写属性设置,同时更新字典值""" - # 更新属性 - object.__setattr__(self, name, value) - # 同步更新字典 - if name in ['type', 'data', 'translated_data']: - self[name] = value - - def __setitem__(self, key: str, value: Any) -> None: - """重写字典值设置,同时更新属性""" - # 更新字典 - super().__setitem__(key, value) - # 同步更新属性 - if key in ['type', 'data', 'translated_data']: - object.__setattr__(self, key, value) + @classmethod + def from_dict(cls, data: Dict) -> 'Seg': + """从字典创建Seg实例""" + type=data.get('type') + data=data.get('data') + if type == 'seglist': + data = [Seg.from_dict(seg) for seg in data] + return cls( + type=type, + data=data + ) def to_dict(self) -> Dict: """转换为字典格式""" @@ -64,8 +42,6 @@ class Seg(dict): result['data'] = [seg.to_dict() for seg in self.data] else: result['data'] = self.data - if self.translated_data is not None: - result['translated_data'] = self.translated_data return result @dataclass @@ -79,6 +55,7 @@ class GroupInfo: """转换为字典格式""" return {k: v for k, v in asdict(self).items() if v is not None} + @classmethod def from_dict(cls, data: Dict) -> 'GroupInfo': """从字典创建GroupInfo实例 @@ -106,6 +83,7 @@ class UserInfo: """转换为字典格式""" return {k: v for k, v in asdict(self).items() if v is not None} + @classmethod def from_dict(cls, data: Dict) -> 'UserInfo': """从字典创建UserInfo实例 @@ -126,7 +104,7 @@ class UserInfo: class BaseMessageInfo: """消息信息类""" platform: Optional[str] = None - message_id: Optional[int,str] = None + message_id: Union[str,int,None] = None time: Optional[int] = None group_info: Optional[GroupInfo] = None user_info: Optional[UserInfo] = None @@ -141,6 +119,25 @@ class BaseMessageInfo: else: result[field] = value return result + @classmethod + def from_dict(cls, data: Dict) -> 'BaseMessageInfo': + """从字典创建BaseMessageInfo实例 + + Args: + data: 包含必要字段的字典 + + Returns: + BaseMessageInfo: 新的实例 + """ + group_info = GroupInfo(**data.get('group_info', {})) + user_info = UserInfo(**data.get('user_info', {})) + return cls( + platform=data.get('platform'), + message_id=data.get('message_id'), + time=data.get('time'), + group_info=group_info, + user_info=user_info + ) @dataclass class MessageBase: diff --git a/src/plugins/chat/message_cq.py b/src/plugins/chat/message_cq.py index 4d7489bbf..6bfa47c3f 100644 --- a/src/plugins/chat/message_cq.py +++ b/src/plugins/chat/message_cq.py @@ -27,27 +27,10 @@ class MessageCQ(MessageBase): def __init__( self, message_id: int, - user_id: int, - group_id: Optional[int] = None, + user_info: UserInfo, + group_info: Optional[GroupInfo] = None, platform: str = "qq" ): - # 构造用户信息 - user_info = UserInfo( - platform=platform, - user_id=user_id, - user_nickname=get_user_nickname(user_id), - user_cardname=get_user_cardname(user_id) if group_id else None - ) - - # 构造群组信息(如果有) - group_info = None - if group_id: - group_info = GroupInfo( - platform=platform, - group_id=group_id, - group_name=get_groupname(group_id) - ) - # 构造基础消息信息 message_info = BaseMessageInfo( platform=platform, @@ -56,7 +39,6 @@ class MessageCQ(MessageBase): group_info=group_info, user_info=user_info ) - # 调用父类初始化,message_segment 由子类设置 super().__init__( message_info=message_info, @@ -71,14 +53,17 @@ class MessageRecvCQ(MessageCQ): def __init__( self, message_id: int, - user_id: int, + user_info: UserInfo, raw_message: str, - group_id: Optional[int] = None, + group_info: Optional[GroupInfo] = None, + platform: str = "qq", reply_message: Optional[Dict] = None, - platform: str = "qq" ): # 调用父类初始化 - super().__init__(message_id, user_id, group_id, platform) + super().__init__(message_id, user_info, group_info, platform) + + if group_info.group_name is None: + group_info.group_name = get_groupname(group_info.group_id) # 解析消息段 self.message_segment = self._parse_message(raw_message, reply_message) @@ -117,7 +102,7 @@ class MessageRecvCQ(MessageCQ): # 转换CQ码为Seg对象 for code_item in cq_code_dict_list: - message_obj = cq_code_tool.cq_from_dict_to_class(code_item, reply=reply_message) + message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message) if message_obj.translated_segments: segments.append(message_obj.translated_segments) @@ -142,13 +127,14 @@ class MessageSendCQ(MessageCQ): data: Dict ): # 调用父类初始化 - message_info = BaseMessageInfo(**data.get('message_info', {})) - message_segment = Seg(**data.get('message_segment', {})) + message_info = BaseMessageInfo.from_dict(data.get('message_info', {})) + message_segment = Seg.from_dict(data.get('message_segment', {})) super().__init__( message_info.message_id, - message_info.user_info.user_id, - message_info.group_info.group_id if message_info.group_info else None, - message_info.platform) + message_info.user_info, + message_info.group_info if message_info.group_info else None, + message_info.platform + ) self.message_segment = message_segment self.raw_message = self._generate_raw_message() @@ -171,11 +157,9 @@ class MessageSendCQ(MessageCQ): if seg.type == 'text': return str(seg.data) elif seg.type == 'image': - # 如果是base64图片数据 - if seg.data.startswith(('data:', 'base64:')): - return cq_code_tool.create_emoji_cq_base64(seg.data) - # 如果是表情包(本地文件) - return cq_code_tool.create_emoji_cq(seg.data) + return cq_code_tool.create_image_cq_base64(seg.data) + elif seg.type == 'emoji': + return cq_code_tool.create_emoji_cq_base64(seg.data) elif seg.type == 'at': return f"[CQ:at,qq={seg.data}]" elif seg.type == 'reply': diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index ed91b614e..2c3880bb8 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -41,10 +41,10 @@ class Message_Sender: message=message_send.raw_message, auto_escape=False ) - print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功") + print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功") except Exception as e: print(f"发生错误 {e}") - print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败") + print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败") else: try: await self._current_bot.send_private_msg( @@ -52,10 +52,10 @@ class Message_Sender: message=message_send.raw_message, auto_escape=False ) - print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功") + print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}成功") except Exception as e: print(f"发生错误 {e}") - print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败") + print(f"\033[1;34m[调试]\033[0m 发送消息{message.processed_plain_text}失败") class MessageContainer: @@ -137,11 +137,7 @@ class MessageManager: return self.containers[chat_id] def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None: - chat_stream = chat_manager.get_stream_by_info( - platform=message.message_info.platform, - user_info=message.message_info.user_info, - group_info=message.message_info.group_info - ) + chat_stream = message.chat_stream if not chat_stream: raise ValueError("无法找到对应的聊天流") container = self.get_container(chat_stream.stream_id) @@ -165,13 +161,14 @@ class MessageManager: else: print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中") if message_earliest.is_head and message_earliest.update_thinking_time() > 30: - await message_sender.send_message(message_earliest) + await message_sender.send_message(message_earliest.set_reply()) else: await message_sender.send_message(message_earliest) - if message_earliest.is_emoji: - message_earliest.processed_plain_text = "[表情包]" - await self.storage.store_message(message_earliest, None) + # if message_earliest.is_emoji: + # message_earliest.processed_plain_text = "[表情包]" + await message_earliest.process() + await self.storage.store_message(message_earliest, message_earliest.chat_stream,None) container.remove_message(message_earliest) @@ -184,13 +181,14 @@ class MessageManager: try: if msg.is_head and msg.update_thinking_time() > 30: - await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id) + await message_sender.send_message(msg.set_reply()) else: - await message_sender.send_group_message(chat_id, msg.processed_plain_text, auto_escape=False) + await message_sender.send_message(msg) - if msg.is_emoji: - msg.processed_plain_text = "[表情包]" - await self.storage.store_message(msg, None) + # if msg.is_emoji: + # msg.processed_plain_text = "[表情包]" + await msg.process() + await self.storage.store_message(msg,msg.chat_stream, None) if not container.remove_message(msg): print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息") diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index c08b962ed..5552aee8c 100644 --- a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -23,12 +23,12 @@ class Relationship: saved = False def __init__(self, chat:ChatStream=None,data:dict=None): - self.user_id=chat.user_info.user_id if chat.user_info else data.get('user_id',0) - self.platform=chat.platform if chat.user_info else data.get('platform','') - self.nickname=chat.user_info.user_nickname if chat.user_info else data.get('nickname','') - self.relationship_value=data.get('relationship_value',0) - self.age=data.get('age',0) - self.gender=data.get('gender','') + self.user_id=chat.user_info.user_id if chat else data.get('user_id',0) + self.platform=chat.platform if chat else data.get('platform','') + self.nickname=chat.user_info.user_nickname if chat else data.get('nickname','') + self.relationship_value=data.get('relationship_value',0) if data else 0 + self.age=data.get('age',0) if data else 0 + self.gender=data.get('gender','') if data else '' class RelationshipManager: diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index aba09714c..ac3ff5ac4 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -59,7 +59,7 @@ class ImageManager: self.db.db.image_descriptions.create_index([('hash', 1)], unique=True) self.db.db.image_descriptions.create_index([('type', 1)]) - async def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: + def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 Args: @@ -69,13 +69,13 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result = await self.db.db.image_descriptions.find_one({ + result= self.db.db.image_descriptions.find_one({ 'hash': image_hash, 'type': description_type }) return result['description'] if result else None - async def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None: + def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None: """保存图片描述到数据库 Args: @@ -83,7 +83,7 @@ class ImageManager: description: 描述文本 description_type: 描述类型 ('emoji' 或 'image') """ - await self.db.db.image_descriptions.update_one( + self.db.db.image_descriptions.update_one( {'hash': image_hash, 'type': description_type}, { '$set': { @@ -253,8 +253,9 @@ class ImageManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 查询缓存的描述 - cached_description = await self._get_description_from_db(image_hash, 'emoji') + cached_description = self._get_description_from_db(image_hash, 'emoji') if cached_description: + logger.info(f"缓存表情包描述: {cached_description}") return f"[表情包:{cached_description}]" # 调用AI获取描述 @@ -281,7 +282,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - await self.db.db.images.update_one( + self.db.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -291,7 +292,7 @@ class ImageManager: logger.error(f"保存表情包文件失败: {str(e)}") # 保存描述到数据库 - await self._save_description_to_db(image_hash, description, 'emoji') + self._save_description_to_db(image_hash, description, 'emoji') return f"[表情包:{description}]" except Exception as e: @@ -306,7 +307,7 @@ class ImageManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 查询缓存的描述 - cached_description = await self._get_description_from_db(image_hash, 'image') + cached_description = self._get_description_from_db(image_hash, 'image') if cached_description: return f"[图片:{cached_description}]" @@ -334,7 +335,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - await self.db.db.images.update_one( + self.db.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -357,80 +358,6 @@ class ImageManager: image_manager = ImageManager() -def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: - """压缩base64格式的图片到指定大小 - Args: - base64_data: base64编码的图片数据 - target_size: 目标文件大小(字节),默认0.8MB - Returns: - str: 压缩后的base64图片数据 - """ - try: - # 将base64转换为字节数据 - image_data = base64.b64decode(base64_data) - - # 如果已经小于目标大小,直接返回原图 - if len(image_data) <= 2*1024*1024: - return base64_data - - # 将字节数据转换为图片对象 - img = Image.open(io.BytesIO(image_data)) - - # 获取原始尺寸 - original_width, original_height = img.size - - # 计算缩放比例 - scale = min(1.0, (target_size / len(image_data)) ** 0.5) - - # 计算新的尺寸 - new_width = int(original_width * scale) - new_height = int(original_height * scale) - - # 创建内存缓冲区 - output_buffer = io.BytesIO() - - # 如果是GIF,处理所有帧 - if getattr(img, "is_animated", False): - frames = [] - for frame_idx in range(img.n_frames): - img.seek(frame_idx) - new_frame = img.copy() - new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折 - frames.append(new_frame) - - # 保存到缓冲区 - frames[0].save( - output_buffer, - format='GIF', - save_all=True, - append_images=frames[1:], - optimize=True, - duration=img.info.get('duration', 100), - loop=img.info.get('loop', 0) - ) - else: - # 处理静态图片 - resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - # 保存到缓冲区,保持原始格式 - if img.format == 'PNG' and img.mode in ('RGBA', 'LA'): - resized_img.save(output_buffer, format='PNG', optimize=True) - else: - resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True) - - # 获取压缩后的数据并转换为base64 - compressed_data = output_buffer.getvalue() - logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}") - logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB") - - return base64.b64encode(compressed_data).decode('utf-8') - - except Exception as e: - logger.error(f"压缩图片失败: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - return base64_data - def image_path_to_base64(image_path: str) -> str: """将图片路径转换为base64编码 Args: diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index c70c26ff9..56ed80693 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -7,10 +7,11 @@ from typing import Tuple, Union import aiohttp from loguru import logger from nonebot import get_driver - +import base64 +from PIL import Image +import io from ...common.database import Database from ..chat.config import global_config -from ..chat.utils_image import compress_base64_image_by_scale driver = get_driver() config = driver.config @@ -405,3 +406,77 @@ class LLM_request: ) return embedding +def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: + """压缩base64格式的图片到指定大小 + Args: + base64_data: base64编码的图片数据 + target_size: 目标文件大小(字节),默认0.8MB + Returns: + str: 压缩后的base64图片数据 + """ + try: + # 将base64转换为字节数据 + image_data = base64.b64decode(base64_data) + + # 如果已经小于目标大小,直接返回原图 + if len(image_data) <= 2*1024*1024: + return base64_data + + # 将字节数据转换为图片对象 + img = Image.open(io.BytesIO(image_data)) + + # 获取原始尺寸 + original_width, original_height = img.size + + # 计算缩放比例 + scale = min(1.0, (target_size / len(image_data)) ** 0.5) + + # 计算新的尺寸 + new_width = int(original_width * scale) + new_height = int(original_height * scale) + + # 创建内存缓冲区 + output_buffer = io.BytesIO() + + # 如果是GIF,处理所有帧 + if getattr(img, "is_animated", False): + frames = [] + for frame_idx in range(img.n_frames): + img.seek(frame_idx) + new_frame = img.copy() + new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折 + frames.append(new_frame) + + # 保存到缓冲区 + frames[0].save( + output_buffer, + format='GIF', + save_all=True, + append_images=frames[1:], + optimize=True, + duration=img.info.get('duration', 100), + loop=img.info.get('loop', 0) + ) + else: + # 处理静态图片 + resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) + + # 保存到缓冲区,保持原始格式 + if img.format == 'PNG' and img.mode in ('RGBA', 'LA'): + resized_img.save(output_buffer, format='PNG', optimize=True) + else: + resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True) + + # 获取压缩后的数据并转换为base64 + compressed_data = output_buffer.getvalue() + logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}") + logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB") + + return base64.b64encode(compressed_data).decode('utf-8') + + except Exception as e: + logger.error(f"压缩图片失败: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return base64_data +