Merge pull request #181 from tcmofashi/refractor

Refractor: 史上最好的消息流重构和图片管理系统
This commit is contained in:
tcmofashi
2025-03-11 17:22:04 +08:00
committed by GitHub
19 changed files with 2326 additions and 1221 deletions

View File

@@ -1,3 +0,0 @@
{
"editor.formatOnSave": true
}

View File

@@ -16,6 +16,7 @@ from .config import global_config
from .emoji_manager import emoji_manager from .emoji_manager import emoji_manager
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .willing_manager import willing_manager from .willing_manager import willing_manager
from .chat_stream import chat_manager
from ..memory_system.memory import hippocampus, memory_graph from ..memory_system.memory import hippocampus, memory_graph
from .bot import ChatBot from .bot import ChatBot
from .message_sender import message_manager, message_sender from .message_sender import message_manager, message_sender
@@ -98,6 +99,8 @@ async def _(bot: Bot):
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL)) asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
logger.success("-----------开始偷表情包!-----------") logger.success("-----------开始偷表情包!-----------")
asyncio.create_task(chat_manager._initialize())
asyncio.create_task(chat_manager._auto_save_task())
@group_msg.handle() @group_msg.handle()

View File

@@ -1,28 +1,28 @@
import re import re
import time import time
from random import random from random import random
from loguru import logger from loguru import logger
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
from ..memory_system.memory import hippocampus from ..memory_system.memory import hippocampus
from ..moods.moods import MoodManager # 导入情绪管理器 from ..moods.moods import MoodManager # 导入情绪管理器
from .config import global_config from .config import global_config
from .cq_code import CQCode # 导入CQCode模块 from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
from .emoji_manager import emoji_manager # 导入表情包管理器 from .emoji_manager import emoji_manager # 导入表情包管理器
from .llm_generator import ResponseGenerator from .llm_generator import ResponseGenerator
from .message import ( from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
Message, from .message_cq import (
Message_Sending, MessageRecvCQ,
Message_Thinking, # 导入 Message_Thinking 类
MessageSet,
) )
from .chat_stream import chat_manager
from .message_sender import message_manager # 导入新的消息管理器 from .message_sender import message_manager # 导入新的消息管理器
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_txt from .utils import calculate_typing_time, is_mentioned_bot_in_message
from .utils_image import image_path_to_base64
from .willing_manager import willing_manager # 导入意愿管理器 from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg
class ChatBot: class ChatBot:
def __init__(self): def __init__(self):
@@ -44,35 +44,61 @@ class ChatBot:
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None: async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
"""处理收到的群消息""" """处理收到的群消息"""
if event.group_id not in global_config.talk_allowed_groups:
return
self.bot = bot # 更新 bot 实例 self.bot = bot # 更新 bot 实例
# group_info = await bot.get_group_info(group_id=event.group_id)
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
# 白名单设定由nontbot侧完成
if event.group_id:
if event.group_id not in global_config.talk_allowed_groups:
return
if event.user_id in global_config.ban_user_id: if event.user_id in global_config.ban_user_id:
return return
group_info = await bot.get_group_info(group_id=event.group_id) user_info=UserInfo(
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
await relationship_manager.update_relationship(user_id=event.user_id, data=sender_info)
await relationship_manager.update_relationship_value(user_id=event.user_id, relationship_value=0.5)
message = Message(
group_id=event.group_id,
user_id=event.user_id, user_id=event.user_id,
message_id=event.message_id, user_nickname=event.sender.nickname,
user_cardname=sender_info['card'], user_cardname=event.sender.card or None,
raw_message=str(event.original_message), platform='qq'
plain_text=event.get_plaintext(),
reply_message=event.reply,
) )
await message.initialize()
group_info=GroupInfo(
group_id=event.group_id,
group_name=None,
platform='qq'
)
message_cq=MessageRecvCQ(
message_id=event.message_id,
user_info=user_info,
raw_message=str(event.original_message),
group_info=group_info,
reply_message=event.reply,
platform='qq'
)
message_json=message_cq.to_dict()
# 进入maimbot
message=MessageRecv(message_json)
groupinfo=message.message_info.group_info
userinfo=message.message_info.user_info
messageinfo=message.message_info
# 消息过滤涉及到config有待更新
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
message.update_chat_stream(chat)
await relationship_manager.update_relationship(chat_stream=chat,)
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
await message.process()
# 过滤词 # 过滤词
for word in global_config.ban_words: for word in global_config.ban_words:
if word in message.detailed_plain_text: if word in message.processed_plain_text:
logger.info( logger.info(
f"[{message.group_name}]{message.user_nickname}:{message.processed_plain_text}") f"[{groupinfo.group_name}]{userinfo.user_nickname}:{message.processed_plain_text}")
logger.info(f"[过滤词识别]消息中含有{word}filtered") logger.info(f"[过滤词识别]消息中含有{word}filtered")
return return
@@ -84,7 +110,9 @@ class ChatBot:
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered") logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text) # topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
topic = '' topic = ''
@@ -94,46 +122,59 @@ class ChatBot:
f"的激活度:{interested_rate}") f"的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") # logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, topic[0] if topic else None) await self.storage.store_message(message,chat, topic[0] if topic else None)
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text) is_mentioned = is_mentioned_bot_in_message(message)
reply_probability = willing_manager.change_reply_willing_received( reply_probability = await willing_manager.change_reply_willing_received(
event.group_id, chat_stream=chat,
topic[0] if topic else None, topic=topic[0] if topic else None,
is_mentioned, is_mentioned_bot=is_mentioned,
global_config, config=global_config,
event.user_id, is_emoji=message.is_emoji,
message.is_emoji, interested_rate=interested_rate
interested_rate
) )
current_willing = willing_manager.get_willing(event.group_id) current_willing = willing_manager.get_willing(chat_stream=chat)
logger.info( logger.info(
f"[{current_time}][{message.group_name}]{message.user_nickname}:" f"[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]") f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]"
)
response = "" response = None
if random() < reply_probability: if random() < reply_probability:
bot_user_info=UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform
)
tinking_time_point = round(time.time(), 2) tinking_time_point = round(time.time(), 2)
think_id = 'mt' + str(tinking_time_point) think_id = 'mt' + str(tinking_time_point)
thinking_message = Message_Thinking(message=message, message_id=think_id) thinking_message = MessageThinking(
message_id=think_id,
chat_stream=chat,
bot_user_info=bot_user_info,
reply=message
)
message_manager.add_message(thinking_message) message_manager.add_message(thinking_message)
willing_manager.change_reply_willing_sent(thinking_message.group_id) willing_manager.change_reply_willing_sent(chat)
response,raw_content = await self.gpt.generate_response(message) response,raw_content = await self.gpt.generate_response(message)
# print(f"response: {response}")
if response: if response:
container = message_manager.get_container(event.group_id) # print(f"有response: {response}")
container = message_manager.get_container(chat.stream_id)
thinking_message = None thinking_message = None
# 找到message,删除 # 找到message,删除
# print(f"开始找思考消息")
for msg in container.messages: for msg in container.messages:
if isinstance(msg, Message_Thinking) and msg.message_id == think_id: if isinstance(msg, MessageThinking) and msg.message_info.message_id == think_id:
# print(f"找到思考消息: {msg}")
thinking_message = msg thinking_message = msg
container.messages.remove(msg) container.messages.remove(msg)
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
break break
# 如果找不到思考消息,直接返回 # 如果找不到思考消息,直接返回
@@ -143,41 +184,38 @@ class ChatBot:
# 记录开始思考的时间,避免从思考到回复的时间太久 # 记录开始思考的时间,避免从思考到回复的时间太久
thinking_start_time = thinking_message.thinking_start_time thinking_start_time = thinking_message.thinking_start_time
message_set = MessageSet(event.group_id, global_config.BOT_QQ, message_set = MessageSet(chat, think_id)
think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
#计算打字时间1是为了模拟打字2是避免多条回复乱序 #计算打字时间1是为了模拟打字2是避免多条回复乱序
accu_typing_time = 0 accu_typing_time = 0
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
mark_head = False mark_head = False
for msg in response: for msg in response:
# print(f"\033[1;32m[回复内容]\033[0m {msg}") # print(f"\033[1;32m[回复内容]\033[0m {msg}")
# 通过时间改变时间戳 # 通过时间改变时间戳
typing_time = calculate_typing_time(msg) typing_time = calculate_typing_time(msg)
print(f"typing_time: {typing_time}")
accu_typing_time += typing_time accu_typing_time += typing_time
timepoint = tinking_time_point + accu_typing_time timepoint = tinking_time_point + accu_typing_time
message_segment = Seg(type='text', data=msg)
bot_message = Message_Sending( print(f"message_segment: {message_segment}")
group_id=event.group_id, bot_message = MessageSending(
user_id=global_config.BOT_QQ,
message_id=think_id, message_id=think_id,
raw_message=msg, chat_stream=chat,
plain_text=msg, bot_user_info=bot_user_info,
processed_plain_text=msg, message_segment=message_segment,
user_nickname=global_config.BOT_NICKNAME, reply=message,
group_name=message.group_name, is_head=not mark_head,
time=timepoint, # 记录了回复生成的时间 is_emoji=False
thinking_start_time=thinking_start_time, # 记录了思考开始的时间
reply_message_id=message.message_id
) )
await bot_message.initialize() print(f"bot_message: {bot_message}")
if not mark_head: if not mark_head:
bot_message.is_head = True
mark_head = True mark_head = True
print(f"添加消息到message_set: {bot_message}")
message_set.add_message(bot_message) message_set.add_message(bot_message)
# message_set 可以直接加入 message_manager # message_set 可以直接加入 message_manager
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器") # print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
print(f"添加message_set到message_manager")
message_manager.add_message(message_set) message_manager.add_message(message_set)
bot_response_time = tinking_time_point bot_response_time = tinking_time_point
@@ -189,31 +227,25 @@ class ChatBot:
if emoji_raw != None: if emoji_raw != None:
emoji_path, description = emoji_raw emoji_path, description = emoji_raw
emoji_cq = CQCode.create_emoji_cq(emoji_path) emoji_cq = image_path_to_base64(emoji_path)
if random() < 0.5: if random() < 0.5:
bot_response_time = tinking_time_point - 1 bot_response_time = tinking_time_point - 1
else: else:
bot_response_time = bot_response_time + 1 bot_response_time = bot_response_time + 1
bot_message = Message_Sending( message_segment = Seg(type='emoji', data=emoji_cq)
group_id=event.group_id, bot_message = MessageSending(
user_id=global_config.BOT_QQ, message_id=think_id,
message_id=0, chat_stream=chat,
raw_message=emoji_cq, bot_user_info=bot_user_info,
plain_text=emoji_cq, message_segment=message_segment,
processed_plain_text=emoji_cq, reply=message,
detailed_plain_text=description, is_head=False,
user_nickname=global_config.BOT_NICKNAME, is_emoji=True
group_name=message.group_name,
time=bot_response_time,
is_emoji=True,
translate_cq=False,
thinking_start_time=thinking_start_time,
# reply_message_id=message.message_id
) )
await bot_message.initialize()
message_manager.add_message(bot_message) message_manager.add_message(bot_message)
emotion = await self.gpt._get_emotion_tags(raw_content) emotion = await self.gpt._get_emotion_tags(raw_content)
logger.debug(f"'{response}' 获取到的情感标签为:{emotion}") logger.debug(f"'{response}' 获取到的情感标签为:{emotion}")
valuedict = { valuedict = {
@@ -225,12 +257,13 @@ class ChatBot:
'fearful': -0.7, 'fearful': -0.7,
'neutral': 0.1 'neutral': 0.1
} }
await relationship_manager.update_relationship_value(message.user_id, await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=valuedict[emotion[0]])
relationship_value=valuedict[emotion[0]])
# 使用情绪管理器更新情绪 # 使用情绪管理器更新情绪
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor) self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
# willing_manager.change_reply_willing_after_sent(event.group_id) # willing_manager.change_reply_willing_after_sent(
# chat_stream=chat
# )
# 创建全局ChatBot实例 # 创建全局ChatBot实例

View File

@@ -0,0 +1,226 @@
import asyncio
import hashlib
import time
import copy
from typing import Dict, Optional
from loguru import logger
from ...common.database import Database
from .message_base import GroupInfo, UserInfo
class ChatStream:
"""聊天流对象,存储一个完整的聊天上下文"""
def __init__(
self,
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
self.group_info = group_info
self.create_time = (
data.get("create_time", int(time.time())) if data else int(time.time())
)
self.last_active_time = (
data.get("last_active_time", self.create_time) if data else self.create_time
)
self.saved = False
def to_dict(self) -> dict:
"""转换为字典格式"""
result = {
"stream_id": self.stream_id,
"platform": self.platform,
"user_info": self.user_info.to_dict() if self.user_info else None,
"group_info": self.group_info.to_dict() if self.group_info else None,
"create_time": self.create_time,
"last_active_time": self.last_active_time,
}
return result
@classmethod
def from_dict(cls, data: dict) -> "ChatStream":
"""从字典创建实例"""
user_info = (
UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
)
group_info = (
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
)
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
group_info=group_info,
data=data,
)
def update_active_time(self):
"""更新最后活跃时间"""
self.last_active_time = int(time.time())
self.saved = False
class ChatManager:
"""聊天管理器,管理所有聊天流"""
_instance = None
_initialized = False
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.db = Database.get_instance()
self._ensure_collection()
self._initialized = True
# 在事件循环中启动初始化
# asyncio.create_task(self._initialize())
# # 启动自动保存任务
# asyncio.create_task(self._auto_save_task())
async def _initialize(self):
"""异步初始化"""
try:
await self.load_all_streams()
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
while True:
await asyncio.sleep(300) # 每5分钟保存一次
try:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
def _ensure_collection(self):
"""确保数据库集合存在并创建索引"""
if "chat_streams" not in self.db.db.list_collection_names():
self.db.db.create_collection("chat_streams")
# 创建索引
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
self.db.db.chat_streams.create_index(
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
)
def _generate_stream_id(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> str:
"""生成聊天流唯一ID"""
if group_info:
# 组合关键信息
components = [
platform,
str(group_info.group_id)
]
else:
components = [
platform,
str(user_info.user_id),
"private"
]
# 使用MD5生成唯一ID
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> ChatStream:
"""获取或创建聊天流
Args:
platform: 平台标识
user_info: 用户信息
group_info: 群组信息(可选)
Returns:
ChatStream: 聊天流对象
"""
# 生成stream_id
stream_id = self._generate_stream_id(platform, user_info, group_info)
# 检查内存中是否存在
if stream_id in self.streams:
stream = self.streams[stream_id]
# 更新用户信息和群组信息
stream.update_active_time()
stream=copy.deepcopy(stream)
stream.user_info = user_info
if group_info:
stream.group_info = group_info
return stream
# 检查数据库中是否存在
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
if data:
stream = ChatStream.from_dict(data)
# 更新用户信息和群组信息
stream.user_info = user_info
if group_info:
stream.group_info = group_info
stream.update_active_time()
else:
# 创建新的聊天流
stream = ChatStream(
stream_id=stream_id,
platform=platform,
user_info=user_info,
group_info=group_info,
)
# 保存到内存和数据库
self.streams[stream_id] = stream
await self._save_stream(stream)
return copy.deepcopy(stream)
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
"""通过stream_id获取聊天流"""
return self.streams.get(stream_id)
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
async def _save_stream(self, stream: ChatStream):
"""保存聊天流到数据库"""
if not stream.saved:
self.db.db.chat_streams.update_one(
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
)
stream.saved = True
async def _save_all_streams(self):
"""保存所有聊天流"""
for stream in self.streams.values():
await self._save_stream(stream)
async def load_all_streams(self):
"""从数据库加载所有聊天流"""
all_streams = self.db.db.chat_streams.find({})
for data in all_streams:
stream = ChatStream.from_dict(data)
self.streams[stream.stream_id] = stream
# 创建全局单例
chat_manager = ChatManager()

View File

@@ -2,22 +2,23 @@ import base64
import html import html
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional from typing import Dict, List, Optional, Union
from loguru import logger
import requests import requests
# 解析各种CQ码 # 解析各种CQ码
# 包含CQ码类 # 包含CQ码类
import urllib3 import urllib3
from loguru import logger
from nonebot import get_driver from nonebot import get_driver
from urllib3.util import create_urllib3_context from urllib3.util import create_urllib3_context
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .mapper import emojimapper from .mapper import emojimapper
from .utils_image import image_path_to_base64, storage_emoji, storage_image from .message_base import Seg
from .utils_user import get_user_nickname from .utils_user import get_user_nickname,get_groupname
from .message_base import GroupInfo, UserInfo
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -35,8 +36,11 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False): def init_poolmanager(self, connections, maxsize, block=False):
self.poolmanager = urllib3.poolmanager.PoolManager( self.poolmanager = urllib3.poolmanager.PoolManager(
num_pools=connections, maxsize=maxsize, num_pools=connections,
block=block, ssl_context=self.ssl_context) maxsize=maxsize,
block=block,
ssl_context=self.ssl_context,
)
@dataclass @dataclass
@@ -48,52 +52,64 @@ class CQCode:
type: CQ码类型'image', 'at', 'face'等) type: CQ码类型'image', 'at', 'face'等)
params: CQ码的参数字典 params: CQ码的参数字典
raw_code: 原始CQ码字符串 raw_code: 原始CQ码字符串
translated_plain_text: 经过处理如AI翻译后的文本表示 translated_segments: 经过处理后的Seg对象列表
""" """
type: str type: str
params: Dict[str, str] params: Dict[str, str]
# raw_code: str group_info: Optional[GroupInfo] = None
group_id: int user_info: Optional[UserInfo] = None
user_id: int translated_segments: Optional[Union[Seg, List[Seg]]] = None
group_name: str = ""
user_nickname: str = ""
translated_plain_text: Optional[str] = None
reply_message: Dict = None # 存储回复消息 reply_message: Dict = None # 存储回复消息
image_base64: Optional[str] = None image_base64: Optional[str] = None
_llm: Optional[LLM_request] = None _llm: Optional[LLM_request] = None
def __post_init__(self): def __post_init__(self):
"""初始化LLM实例""" """初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) pass
async def translate(self): def translate(self):
"""根据CQ码类型进行相应的翻译处理""" """根据CQ码类型进行相应的翻译处理转换为Seg对象"""
if self.type == 'text': if self.type == "text":
self.translated_plain_text = self.params.get('text', '') self.translated_segments = Seg(
elif self.type == 'image': type="text", data=self.params.get("text", "")
if self.params.get('sub_type') == '0': )
self.translated_plain_text = await self.translate_image() elif self.type == "image":
base64_data = self.translate_image()
if base64_data:
if self.params.get("sub_type") == "0":
self.translated_segments = Seg(type="image", data=base64_data)
else: else:
self.translated_plain_text = await self.translate_emoji() self.translated_segments = Seg(type="emoji", data=base64_data)
elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', ''))
if user_nickname:
self.translated_plain_text = f"[@{user_nickname}]"
else: else:
self.translated_plain_text = "@某人" self.translated_segments = Seg(type="text", data="[图片]")
elif self.type == 'reply': elif self.type == "at":
self.translated_plain_text = await self.translate_reply() user_nickname = get_user_nickname(self.params.get("qq", ""))
elif self.type == 'face': self.translated_segments = Seg(
face_id = self.params.get('id', '') type="text", data=f"[@{user_nickname or '某人'}]"
# self.translated_plain_text = f"[表情{face_id}]" )
self.translated_plain_text = f"[{emojimapper.get(int(face_id), '表情')}]" elif self.type == "reply":
elif self.type == 'forward': reply_segments = self.translate_reply()
self.translated_plain_text = await self.translate_forward() if reply_segments:
self.translated_segments = Seg(type="seglist", data=reply_segments)
else: else:
self.translated_plain_text = f"[{self.type}]" self.translated_segments = Seg(type="text", data="[回复某人消息]")
elif self.type == "face":
face_id = self.params.get("id", "")
self.translated_segments = Seg(
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
)
elif self.type == "forward":
forward_segments = self.translate_forward()
if forward_segments:
self.translated_segments = Seg(type="seglist", data=forward_segments)
else:
self.translated_segments = Seg(type="text", data="[转发消息]")
else:
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
def get_img(self): def get_img(self):
''' """
headers = { headers = {
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0', 'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
'Accept': 'image/*;q=0.8', 'Accept': 'image/*;q=0.8',
@@ -102,18 +118,18 @@ class CQCode:
'Cache-Control': 'no-cache', 'Cache-Control': 'no-cache',
'Pragma': 'no-cache' 'Pragma': 'no-cache'
} }
''' """
# 腾讯专用请求头配置 # 腾讯专用请求头配置
headers = { headers = {
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36', "User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
'Accept': 'text/html, application/xhtml xml, */*', "Accept": "text/html, application/xhtml xml, */*",
'Accept-Encoding': 'gbk, GB2312', "Accept-Encoding": "gbk, GB2312",
'Accept-Language': 'zh-cn', "Accept-Language": "zh-cn",
'Content-Type': 'application/x-www-form-urlencoded', "Content-Type": "application/x-www-form-urlencoded",
'Cache-Control': 'no-cache' "Cache-Control": "no-cache",
} }
url = html.unescape(self.params['url']) url = html.unescape(self.params["url"])
if not url.startswith(('http://', 'https://')): if not url.startswith(("http://", "https://")):
return None return None
# 创建专用会话 # 创建专用会话
@@ -129,23 +145,23 @@ class CQCode:
headers=headers, headers=headers,
timeout=15, timeout=15,
allow_redirects=True, allow_redirects=True,
stream=True # 流式传输避免大内存问题 stream=True, # 流式传输避免大内存问题
) )
# 腾讯服务器特殊状态码处理 # 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
return None return None
if response.status_code != 200: if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
# 验证内容类型 # 验证内容类型
content_type = response.headers.get('Content-Type', '') content_type = response.headers.get("Content-Type", "")
if not content_type.startswith('image/'): if not content_type.startswith("image/"):
raise ValueError(f"非图片内容类型: {content_type}") raise ValueError(f"非图片内容类型: {content_type}")
# 转换为Base64 # 转换为Base64
image_base64 = base64.b64encode(response.content).decode('utf-8') image_base64 = base64.b64encode(response.content).decode("utf-8")
self.image_base64 = image_base64 self.image_base64 = image_base64
return image_base64 return image_base64
@@ -160,187 +176,157 @@ class CQCode:
return None return None
async def translate_emoji(self) -> str: def translate_image(self) -> Optional[str]:
"""处理表情包类型的CQ码""" """处理图片类型的CQ码返回base64字符串"""
if 'url' not in self.params: if "url" not in self.params:
return '[表情包]' return None
base64_str = self.get_img() return self.get_img()
if base64_str:
# 将 base64 字符串转换为字节类型
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
return await self.get_emoji_description(base64_str)
else:
return '[表情包]'
async def translate_image(self) -> str: def translate_forward(self) -> Optional[List[Seg]]:
"""处理图片类型的CQ码区分普通图片和表情包""" """处理转发消息返回Seg列表"""
# 没有url直接返回默认文本
if 'url' not in self.params:
return '[图片]'
base64_str = self.get_img()
if base64_str:
image_bytes = base64.b64decode(base64_str)
storage_image(image_bytes)
return await self.get_image_description(base64_str)
else:
return '[图片]'
async def get_emoji_description(self, image_base64: str) -> str:
"""调用AI接口获取表情包描述"""
try: try:
prompt = "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。" if "content" not in self.params:
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) return None
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]"
except Exception as e:
logger.exception(f"AI接口调用失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str: content = self.unescape(self.params["content"])
"""调用AI接口获取普通图片描述"""
try:
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]"
except Exception as e:
logger.exception(f"AI接口调用失败: {str(e)}")
return "[图片]"
async def translate_forward(self) -> str:
"""处理转发消息"""
try:
if 'content' not in self.params:
return '[转发消息]'
# 解析content内容需要先反转义
content = self.unescape(self.params['content'])
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
# 将字符串形式的列表转换为Python对象
import ast import ast
try: try:
messages = ast.literal_eval(content) messages = ast.literal_eval(content)
except ValueError as e: except ValueError as e:
logger.error(f"解析转发消息内容失败: {str(e)}") logger.error(f"解析转发消息内容失败: {str(e)}")
return '[转发消息]' return None
# 处理每条消息 formatted_segments = []
formatted_messages = []
for msg in messages: for msg in messages:
sender = msg.get('sender', {}) sender = msg.get("sender", {})
nickname = sender.get('card') or sender.get('nickname', '未知用户') nickname = sender.get("card") or sender.get("nickname", "未知用户")
raw_message = msg.get("raw_message", "")
# 获取消息内容并使用Message类处理 message_array = msg.get("message", [])
raw_message = msg.get('raw_message', '')
message_array = msg.get('message', [])
if message_array and isinstance(message_array, list): if message_array and isinstance(message_array, list):
# 检查是否包含嵌套的转发消息
for message_part in message_array: for message_part in message_array:
if message_part.get('type') == 'forward': if message_part.get("type") == "forward":
content = '[转发消息]' content_seg = Seg(type="text", data="[转发消息]")
break break
else: else:
# 处理普通消息
if raw_message: if raw_message:
from .message import Message from .message_cq import MessageRecvCQ
message_obj = Message( user_info=UserInfo(
user_id=msg.get('user_id', 0), platform='qq',
message_id=msg.get('message_id', 0), user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message, raw_message=raw_message,
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_info=group_info,
) )
await message_obj.initialize() content_seg = Seg(
content = message_obj.processed_plain_text type="seglist", data=message_obj.message_segment )
else: else:
content = '[空消息]' content_seg = Seg(type="text", data="[空消息]")
else: else:
# 处理普通消息
if raw_message: if raw_message:
from .message import Message from .message_cq import MessageRecvCQ
message_obj = Message(
user_id=msg.get('user_id', 0), user_info=UserInfo(
message_id=msg.get('message_id', 0), platform='qq',
user_id=msg.get("user_id", 0),
user_nickname=nickname,
)
group_info=GroupInfo(
platform='qq',
group_id=msg.get("group_id", 0),
group_name=get_groupname(msg.get("group_id", 0))
)
message_obj = MessageRecvCQ(
message_id=msg.get("message_id", 0),
user_info=user_info,
raw_message=raw_message, raw_message=raw_message,
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_info=group_info,
)
content_seg = Seg(
type="seglist", data=message_obj.message_segment
) )
await message_obj.initialize()
content = message_obj.processed_plain_text
else: else:
content = '[空消息]' content_seg = Seg(type="text", data="[空消息]")
formatted_msg = f"{nickname}: {content}" formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
formatted_messages.append(formatted_msg) formatted_segments.append(content_seg)
formatted_segments.append(Seg(type="text", data="\n"))
# 合并所有消息 return formatted_segments
combined_messages = '\n'.join(formatted_messages)
logger.debug(f"合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]"
except Exception: except Exception as e:
logger.exception("处理转发消息失败") logger.error(f"处理转发消息失败: {str(e)}")
return '[转发消息]' return None
async def translate_reply(self) -> str: def translate_reply(self) -> Optional[List[Seg]]:
"""处理回复类型的CQ码""" """处理回复类型的CQ码返回Seg列表"""
from .message_cq import MessageRecvCQ
# 创建Message对象 if self.reply_message is None:
from .message import Message return None
if self.reply_message == None:
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]'
if self.reply_message.sender.user_id: if self.reply_message.sender.user_id:
message_obj = Message(
user_id=self.reply_message.sender.user_id, message_obj = MessageRecvCQ(
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.get("nickname",None)),
message_id=self.reply_message.message_id, message_id=self.reply_message.message_id,
raw_message=str(self.reply_message.message), raw_message=str(self.reply_message.message),
group_id=self.group_id group_info=GroupInfo(group_id=self.reply_message.group_id),
) )
await message_obj.initialize()
if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
else:
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
segments = []
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
segments.append(
Seg(
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
)
)
else: else:
logger.error("回复消息的sender.user_id为空") segments.append(
return '[回复某人消息]' Seg(
type="text",
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
)
)
segments.append(Seg(type="seglist", data=message_obj.message_segment))
segments.append(Seg(type="text", data="]"))
return segments
else:
return None
@staticmethod @staticmethod
def unescape(text: str) -> str: def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符""" """反转义CQ码中的特殊字符"""
return text.replace('&#44;', ',') \ return (
.replace('&#91;', '[') \ text.replace("&#44;", ",")
.replace('&#93;', ']') \ .replace("&#91;", "[")
.replace('&amp;', '&') .replace("&#93;", "]")
.replace("&amp;", "&")
@staticmethod )
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
base64_content = image_path_to_base64(file_path)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{base64_content},sub_type=1]"
class CQCode_tool: class CQCode_tool:
@staticmethod @staticmethod
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
""" """
将CQ码字典转换为CQCode对象 将CQ码字典转换为CQCode对象
Args: Args:
cq_code: CQ码字典 cq_code: CQ码字典
msg: MessageCQ对象
reply: 回复消息的字典(可选) reply: 回复消息的字典(可选)
Returns: Returns:
@@ -348,23 +334,23 @@ class CQCode_tool:
""" """
# 处理字典形式的CQ码 # 处理字典形式的CQ码
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text' # 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
cq_type = cq_code.get('type', 'text') cq_type = cq_code.get("type", "text")
params = {} params = {}
if cq_type == 'text': if cq_type == "text":
params['text'] = cq_code.get('data', {}).get('text', '') params["text"] = cq_code.get("data", {}).get("text", "")
else: else:
params = cq_code.get('data', {}) params = cq_code.get("data", {})
instance = CQCode( instance = CQCode(
type=cq_type, type=cq_type,
params=params, params=params,
group_id=0, group_info=msg.message_info.group_info,
user_id=0, user_info=msg.message_info.user_info,
reply_message=reply reply_message=reply
) )
# 进行翻译处理 # 进行翻译处理
await instance.translate() instance.translate()
return instance return instance
@staticmethod @staticmethod
@@ -378,5 +364,64 @@ class CQCode_tool:
""" """
return f"[CQ:reply,id={message_id}]" return f"[CQ:reply,id={message_id}]"
@staticmethod
def create_emoji_cq(file_path: str) -> str:
"""
创建表情包CQ码
Args:
file_path: 本地表情包文件路径
Returns:
表情包CQ码字符串
"""
# 确保使用绝对路径
abs_path = os.path.abspath(file_path)
# 转义特殊字符
escaped_path = (
abs_path.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
@staticmethod
def create_emoji_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
@staticmethod
def create_image_cq_base64(base64_data: str) -> str:
"""
创建表情包CQ码
Args:
base64_data: base64编码的表情包数据
Returns:
表情包CQ码字符串
"""
# 转义base64数据
escaped_base64 = (
base64_data.replace("&", "&amp;")
.replace("[", "&#91;")
.replace("]", "&#93;")
.replace(",", "&#44;")
)
# 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
cq_code_tool = CQCode_tool() cq_code_tool = CQCode_tool()

View File

@@ -1,9 +1,11 @@
import asyncio import asyncio
import base64
import hashlib
import os import os
import random import random
import time import time
import traceback import traceback
from typing import Optional from typing import Optional, Tuple
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
@@ -11,11 +13,12 @@ from nonebot import get_driver
from ...common.database import Database from ...common.database import Database
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
from ..chat.utils_image import image_path_to_base64 from ..chat.utils_image import ImageManager, image_path_to_base64
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
image_manager = ImageManager()
class EmojiManager: class EmojiManager:
@@ -76,7 +79,6 @@ class EmojiManager:
if 'emoji' not in self.db.db.list_collection_names(): if 'emoji' not in self.db.db.list_collection_names():
self.db.db.create_collection('emoji') self.db.db.create_collection('emoji')
self.db.db.emoji.create_index([('embedding', '2dsphere')]) self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True) self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str): def record_usage(self, emoji_id: str):
@@ -87,10 +89,10 @@ class EmojiManager:
{'_id': emoji_id}, {'_id': emoji_id},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
except Exception: except Exception as e:
logger.exception("记录表情使用失败") logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text: str) -> Optional[str]: async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
"""根据文本内容获取相关表情包 """根据文本内容获取相关表情包
Args: Args:
text: 输入文本 text: 输入文本
@@ -144,14 +146,14 @@ class EmojiManager:
emoji_similarities.sort(key=lambda x: x[1], reverse=True) emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包 # 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3] top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_3_emojis: if not top_10_emojis:
logger.warning("未找到匹配的表情包") logger.warning("未找到匹配的表情包")
return None return None
# 从前3个中随机选择一个 # 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis) selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and 'path' in selected_emoji: if selected_emoji and 'path' in selected_emoji:
# 更新使用次数 # 更新使用次数
@@ -174,14 +176,14 @@ class EmojiManager:
logger.error(f"获取表情包失败: {str(e)}") logger.error(f"获取表情包失败: {str(e)}")
return None return None
async def _get_emoji_description(self, image_base64: str) -> str: async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签""" """获取表情包的标签使用image_manager的描述生成功能"""
try: try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感' # 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64)
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) # 去掉[表情包xxx]的格式,只保留描述内容
logger.debug(f"输出描述: {content}") description = description.strip('[]').replace('表情包:', '')
return content return description
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
@@ -224,28 +226,65 @@ class EmojiManager:
for filename in files_to_process: for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename) image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过 # 获取图片的base64编码和哈希值
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji:
continue
# 压缩图片并获取base64编码
image_base64 = image_path_to_base64(image_path) image_base64 = image_path_to_base64(image_path)
if image_base64 is None: if image_base64 is None:
os.remove(image_path) os.remove(image_path)
continue continue
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
description = None
if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription')
# 检查是否在images集合中存在
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
if not existing_image:
# 同步到images集合
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步已存在的表情包到images集合: {filename}")
continue
# 检查是否在images集合中已有描述
existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
if existing_description:
description = existing_description
else:
# 获取表情包的描述 # 获取表情包的描述
description = await self._get_emoji_description(image_base64) description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK: if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64) check = await self._check_emoji(image_base64)
if '' not in check: if '' not in check:
os.remove(image_path) os.remove(image_path)
logger.info(f"描述: {description}") logger.info(f"描述: {description}")
logger.info(f"描述: {description}")
logger.info(f"其不满足过滤规则,被剔除 {check}") logger.info(f"其不满足过滤规则,被剔除 {check}")
continue continue
logger.info(f"check通过 {check}") logger.info(f"check通过 {check}")
if description is not None:
embedding = await get_embedding(description)
if description is not None: if description is not None:
embedding = await get_embedding(description) embedding = await get_embedding(description)
# 准备数据库记录 # 准备数据库记录
@@ -253,14 +292,32 @@ class EmojiManager:
'filename': filename, 'filename': filename,
'path': image_path, 'path': image_path,
'embedding': embedding, 'embedding': embedding,
'description': description, 'discription': description,
'hash': image_hash,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
# 保存到数据库 # 保存到emoji数据库
self.db.db['emoji'].insert_one(emoji_record) self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}") logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}") logger.info(f"描述: {description}")
# 保存到images数据库
image_doc = {
'hash': image_hash,
'path': image_path,
'type': 'emoji',
'description': description,
'timestamp': int(time.time())
}
image_manager.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji')
logger.success(f"同步保存到images集合: {filename}")
else: else:
logger.warning(f"跳过表情包: {filename}") logger.warning(f"跳过表情包: {filename}")

View File

@@ -8,7 +8,7 @@ from loguru import logger
from ...common.database import Database from ...common.database import Database
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from .message import Message from .message import MessageRecv, MessageThinking, MessageSending,Message
from .prompt_builder import prompt_builder from .prompt_builder import prompt_builder
from .relationship_manager import relationship_manager from .relationship_manager import relationship_manager
from .utils import process_llm_response from .utils import process_llm_response
@@ -19,48 +19,79 @@ config = driver.config
class ResponseGenerator: class ResponseGenerator:
def __init__(self): def __init__(self):
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True) self.model_r1 = LLM_request(
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000) model=global_config.llm_reasoning,
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000) temperature=0.7,
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000) max_tokens=1000,
stream=True,
)
self.model_v3 = LLM_request(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
)
self.model_r1_distill = LLM_request(
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
)
self.model_v25 = LLM_request(
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
)
self.db = Database.get_instance() self.db = Database.get_instance()
self.current_model_type = 'r1' # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]: async def generate_response(
self, message: MessageThinking
) -> Optional[Union[str, List[str]]]:
"""根据当前模型类型选择对应的生成函数""" """根据当前模型类型选择对应的生成函数"""
# 从global_config中获取模型概率值并选择模型 # 从global_config中获取模型概率值并选择模型
rand = random.random() rand = random.random()
if rand < global_config.MODEL_R1_PROBABILITY: if rand < global_config.MODEL_R1_PROBABILITY:
self.current_model_type = 'r1' self.current_model_type = "r1"
current_model = self.model_r1 current_model = self.model_r1
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY: elif (
self.current_model_type = 'v3' rand
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
):
self.current_model_type = "v3"
current_model = self.model_v3 current_model = self.model_v3
else: else:
self.current_model_type = 'r1_distill' self.current_model_type = "r1_distill"
current_model = self.model_r1_distill current_model = self.model_r1_distill
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中") logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model(message, current_model) model_response = await self._generate_response_with_model(
message, current_model
)
raw_content = model_response raw_content = model_response
# print(f"raw_content: {raw_content}")
# print(f"model_response: {model_response}")
if model_response: if model_response:
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response = await self._process_response(model_response) model_response = await self._process_response(model_response)
if model_response: if model_response:
return model_response, raw_content return model_response, raw_content
return None, raw_content return None, raw_content
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]: async def _generate_response_with_model(
self, message: MessageThinking, model: LLM_request
) -> Optional[str]:
"""使用指定的模型生成回复""" """使用指定的模型生成回复"""
sender_name = message.user_nickname or f"用户{message.user_id}" sender_name = (
if message.user_cardname: message.chat_stream.user_info.user_nickname
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}" or f"用户{message.chat_stream.user_info.user_id}"
)
if message.chat_stream.user_info.user_cardname:
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
# 获取关系值 # 获取关系值
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0 relationship_value = (
relationship_manager.get_relationship(
message.chat_stream
).relationship_value
if relationship_manager.get_relationship(message.chat_stream)
else 0.0
)
if relationship_value != 0.0: if relationship_value != 0.0:
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}") # print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
pass pass
@@ -70,7 +101,7 @@ class ResponseGenerator:
message_txt=message.processed_plain_text, message_txt=message.processed_plain_text,
sender_name=sender_name, sender_name=sender_name,
relationship_value=relationship_value, relationship_value=relationship_value,
group_id=message.group_id stream_id=message.chat_stream.stream_id,
) )
# 读空气模块 简化逻辑,先停用 # 读空气模块 简化逻辑,先停用
@@ -113,40 +144,57 @@ class ResponseGenerator:
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str, # def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str): # content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str, def _save_to_db(
content: str, reasoning_content: str,): self,
message: MessageRecv,
sender_name: str,
prompt: str,
prompt_check: str,
content: str,
reasoning_content: str,
):
"""保存对话记录到数据库""" """保存对话记录到数据库"""
self.db.db.reasoning_logs.insert_one({ self.db.db.reasoning_logs.insert_one(
'time': time.time(), {
'group_id': message.group_id, "time": time.time(),
'user': sender_name, "chat_id": message.chat_stream.stream_id,
'message': message.processed_plain_text, "user": sender_name,
'model': self.current_model_type, "message": message.processed_plain_text,
"model": self.current_model_type,
# 'reasoning_check': reasoning_content_check, # 'reasoning_check': reasoning_content_check,
# 'response_check': content_check, # 'response_check': content_check,
'reasoning': reasoning_content, "reasoning": reasoning_content,
'response': content, "response": content,
'prompt': prompt, "prompt": prompt,
'prompt_check': prompt_check "prompt_check": prompt_check,
}) }
)
async def _get_emotion_tags(self, content: str) -> List[str]: async def _get_emotion_tags(self, content: str) -> List[str]:
"""提取情感标签""" """提取情感标签"""
try: try:
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出 prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
只输出标签就好,不要输出其他内容: 只输出标签就好,不要输出其他内容:
内容:{content} 内容:{content}
输出: 输出:
''' """
content, _ = await self.model_v25.generate_response(prompt) content, _ = await self.model_v25.generate_response(prompt)
content = content.strip() content = content.strip()
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']: if content in [
"happy",
"angry",
"sad",
"surprised",
"disgusted",
"fearful",
"neutral",
]:
return [content] return [content]
else: else:
return ["neutral"] return ["neutral"]
except Exception: except Exception as e:
logger.exception("获取情感标签时出错") print(f"获取情感标签时出错: {e}")
return ["neutral"] return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]: async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
@@ -156,6 +204,8 @@ class ResponseGenerator:
processed_response = process_llm_response(content) processed_response = process_llm_response(content)
# print(f"得到了处理后的llm返回{processed_response}")
return processed_response return processed_response

View File

@@ -1,14 +1,13 @@
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional from typing import Dict, ForwardRef, List, Optional, Union
import urllib3 import urllib3
from loguru import logger
from .cq_code import CQCode, cq_code_tool from .utils_image import image_manager
from .utils_cq import parse_cq_code from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
from .utils_user import get_groupname, get_user_cardname, get_user_nickname from .chat_stream import ChatStream, chat_manager
Message = ForwardRef('Message') # 添加这行
# 禁用SSL警告 # 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -16,216 +15,372 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 #它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。 #它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageRecv(MessageBase):
"""接收消息类用于处理从MessageCQ序列化的消息"""
def __init__(self, message_dict: Dict):
"""从MessageCQ的字典初始化
Args:
message_dict: MessageCQ序列化后的字典
"""
message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
raw_message = message_dict.get('raw_message')
super().__init__(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)
# 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False
def update_chat_stream(self,chat_stream:ChatStream):
self.chat_stream=chat_stream
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。
"""
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
self.is_emoji=True
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname!=''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class Message(MessageBase):
chat_stream: ChatStream=None
reply: Optional['Message'] = None
detailed_plain_text: str = ""
processed_plain_text: str = ""
def __init__(
self,
message_id: str,
time: int,
chat_stream: ChatStream,
user_info: UserInfo,
message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None,
detailed_plain_text: str = "",
processed_plain_text: str = "",
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=chat_stream.platform,
message_id=message_id,
time=time,
group_info=chat_stream.group_info,
user_info=user_info
)
# 调用父类初始化
super().__init__(
message_info=message_info,
message_segment=message_segment,
raw_message=None
)
self.chat_stream = chat_stream
# 文本处理相关属性
self.processed_plain_text = detailed_plain_text
self.detailed_plain_text = processed_plain_text
# 回复消息
self.reply = reply
@dataclass @dataclass
class Message: class MessageProcessBase(Message):
"""消息数据类""" """消息处理基类,用于处理中和发送中的消息"""
message_id: int = None
time: float = None
group_id: int = None def __init__(
group_name: str = None # 群名称 self,
message_id: str,
user_id: int = None chat_stream: ChatStream,
user_nickname: str = None # 用户昵称 bot_user_info: UserInfo,
user_cardname: str = None # 用户群昵称 message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None
raw_message: str = None # 原始消息包含未解析的cq码 ):
plain_text: str = None # 纯文本 # 调用父类初始化
super().__init__(
reply_message: Dict = None # 存储 回复的 源消息 message_id=message_id,
time=int(time.time()),
# 延迟初始化字段 chat_stream=chat_stream,
_initialized: bool = False user_info=bot_user_info,
message_segments: List[Dict] = None # 存储解析后的消息片段 message_segment=message_segment,
processed_plain_text: str = None # 用于存储处理后的plain_text reply=reply
detailed_plain_text: str = None # 用于存储详细可读文本
# 状态标志
is_emoji: bool = False
has_emoji: bool = False
translate_cq: bool = True
async def initialize(self):
"""显式异步初始化方法(必须调用)"""
if self._initialized:
return
# 异步获取补充信息
self.group_name = self.group_name or get_groupname(self.group_id)
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
# 消息解析
if self.raw_message:
if not isinstance(self,Message_Sending):
self.message_segments = await self.parse_message_segments(self.raw_message)
self.processed_plain_text = ' '.join(
seg.translated_plain_text
for seg in self.message_segments
) )
# 构建详细文本 # 处理状态相关属性
if self.time is None:
self.time = int(time.time())
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
name = (
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
if self.user_cardname
else f"{self.user_nickname or f'用户{self.user_id}'}"
)
if isinstance(self,Message_Sending) and self.is_emoji:
self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n"
else:
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
async def parse_message_segments(self, message: str) -> List[CQCode]:
"""
将消息解析为片段列表包括纯文本和CQ码
返回的列表中每个元素都是字典,包含:
- cq_code_list:分割出的聊天对象包括文本和CQ码
- trans_list:翻译后的对象列表
"""
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
cq_code_dict_list = []
trans_list = []
start = 0
while True:
# 查找下一个CQ码的开始位置
cq_start = message.find('[CQ:', start)
#如果没有cq码直接返回文本内容
if cq_start == -1:
# 如果没有找到更多CQ码添加剩余文本
if start < len(message):
text = message[start:].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
break
# 添加CQ码前的文本
if cq_start > start:
text = message[start:cq_start].strip()
if text: # 只添加非空文本
cq_code_dict_list.append(parse_cq_code(text))
# 查找CQ码的结束位置
cq_end = message.find(']', cq_start)
if cq_end == -1:
# CQ码未闭合作为普通文本处理
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
#将cq_code解析成字典
cq_code_dict_list.append(parse_cq_code(cq_code))
# 更新start位置到当前CQ码之后
start = cq_end + 1
# print(f"\033[1;34m[调试信息]\033[0m 提取的消息对象:列表: {cq_code_dict_list}")
#判定是否是表情包消息,以及是否含有表情包
if len(cq_code_dict_list) == 1 and cq_code_dict_list[0]['type'] == 'image':
self.is_emoji = True
self.has_emoji_emoji = True
else:
for segment in cq_code_dict_list:
if segment['type'] == 'image' and segment['data'].get('sub_type') == '1':
self.has_emoji_emoji = True
break
#翻译作为字典的CQ码
for _code_item in cq_code_dict_list:
message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
trans_list.append(message_obj)
return trans_list
class Message_Thinking:
"""消息思考类"""
def __init__(self, message: Message,message_id: str):
# 复制原始消息的基本属性
self.group_id = message.group_id
self.user_id = message.user_id
self.user_nickname = message.user_nickname
self.user_cardname = message.user_cardname
self.group_name = message.group_name
self.message_id = message_id
# 思考状态相关属性
self.thinking_start_time = int(time.time()) self.thinking_start_time = int(time.time())
self.thinking_time = 0 self.thinking_time = 0
self.interupt=False
def update_thinking_time(self): def update_thinking_time(self) -> float:
self.thinking_time = round(time.time(), 2) - self.thinking_start_time """更新思考时间"""
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
@dataclass
class Message_Sending(Message):
"""发送中的消息类"""
thinking_start_time: float = None # 思考开始时间
thinking_time: float = None # 思考时间
reply_message_id: int = None # 存储 回复的 源消息ID
is_head: bool = False # 是否是头部消息
def update_thinking_time(self):
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
return self.thinking_time return self.thinking_time
async def _process_message_segments(self, segment: Seg) -> str:
"""递归处理消息段,转换为文字描述
Args:
segment: 要处理的消息段
Returns:
str: 处理后的文本
"""
if segment.type == 'seglist':
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
if processed:
segments_text.append(processed)
return ' '.join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
async def _process_single_segment(self, seg: Seg) -> str:
"""处理单个消息段
Args:
seg: 要处理的消息段
Returns:
str: 处理后的文本
"""
try:
if seg.type == 'text':
return seg.data
elif seg.type == 'image':
# 如果是base64图片数据
if isinstance(seg.data, str):
return await image_manager.get_image_description(seg.data)
return '[图片]'
elif seg.type == 'emoji':
if isinstance(seg.data, str):
return await image_manager.get_emoji_description(seg.data)
return '[表情]'
elif seg.type == 'at':
return f"[@{seg.data}]"
elif seg.type == 'reply':
if self.reply and hasattr(self.reply, 'processed_plain_text'):
return f"[回复:{self.reply.processed_plain_text}]"
else:
return f"[{seg.type}:{str(seg.data)}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息"""
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
user_info = self.message_info.user_info
name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
if user_info.user_cardname != ''
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
)
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
@dataclass
class MessageThinking(MessageProcessBase):
"""思考状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
reply: Optional['MessageRecv'] = None
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=None, # 思考状态不需要消息段
reply=reply
)
# 思考状态特有属性
self.interrupt = False
@dataclass
class MessageSending(MessageProcessBase):
"""发送状态的消息类"""
def __init__(
self,
message_id: str,
chat_stream: ChatStream,
bot_user_info: UserInfo,
message_segment: Seg,
reply: Optional['MessageRecv'] = None,
is_head: bool = False,
is_emoji: bool = False
):
# 调用父类初始化
super().__init__(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
message_segment=message_segment,
reply=reply
)
# 发送状态特有属性
self.reply_to_message_id = reply.message_info.message_id if reply else None
self.is_head = is_head
self.is_emoji = is_emoji
def set_reply(self, reply: Optional['MessageRecv']) -> None:
"""设置回复消息"""
if reply:
self.reply = reply
self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg(type='seglist', data=[
Seg(type='reply', data=reply.message_info.message_id),
self.message_segment
])
async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本"""
if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.detailed_plain_text = self._generate_detailed_text()
@classmethod
def from_thinking(
cls,
thinking: MessageThinking,
message_segment: Seg,
is_head: bool = False,
is_emoji: bool = False
) -> 'MessageSending':
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji
)
def to_dict(self):
ret= super().to_dict()
ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
return ret
@dataclass
class MessageSet: class MessageSet:
"""消息集合类,可以存储多个发送消息""" """消息集合类,可以存储多个发送消息"""
def __init__(self, group_id: int, user_id: int, message_id: str): def __init__(self, chat_stream: ChatStream, message_id: str):
self.group_id = group_id self.chat_stream = chat_stream
self.user_id = user_id
self.message_id = message_id self.message_id = message_id
self.messages: List[Message_Sending] = [] # 修改类型标注 self.messages: List[MessageSending] = []
self.time = round(time.time(), 2) self.time = round(time.time(), 2)
def add_message(self, message: Message_Sending) -> None: def add_message(self, message: MessageSending) -> None:
"""添加消息到集合只接受Message_Sending类型""" """添加消息到集合"""
if not isinstance(message, Message_Sending): if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加Message_Sending类型的消息") raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message) self.messages.append(message)
# 按时间排序 self.messages.sort(key=lambda x: x.message_info.time)
self.messages.sort(key=lambda x: x.time)
def get_message_by_index(self, index: int) -> Optional[Message_Sending]: def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息""" """通过索引获取消息"""
if 0 <= index < len(self.messages): if 0 <= index < len(self.messages):
return self.messages[index] return self.messages[index]
return None return None
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]: def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息""" """获取最接近指定时间的消息"""
if not self.messages: if not self.messages:
return None return None
# 使用二分查找找到最接近的消息
left, right = 0, len(self.messages) - 1 left, right = 0, len(self.messages) - 1
while left < right: while left < right:
mid = (left + right) // 2 mid = (left + right) // 2
if self.messages[mid].time < target_time: if self.messages[mid].message_info.time < target_time:
left = mid + 1 left = mid + 1
else: else:
right = mid right = mid
return self.messages[left] return self.messages[left]
def clear_messages(self) -> None: def clear_messages(self) -> None:
"""清空所有消息""" """清空所有消息"""
self.messages.clear() self.messages.clear()
def remove_message(self, message: Message_Sending) -> bool: def remove_message(self, message: MessageSending) -> bool:
"""移除指定消息""" """移除指定消息"""
if message in self.messages: if message in self.messages:
self.messages.remove(message) self.messages.remove(message)

View File

@@ -0,0 +1,186 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Any, Dict
@dataclass
class Seg:
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型,可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data 是 base64 字符串
- 对于 seglist 类型data 是 Seg 列表
translated_data: 经过翻译处理的数据(可选)
"""
type: str
data: Union[str, List['Seg']]
# def __init__(self, type: str, data: Union[str, List['Seg']],):
# """初始化实例,确保字典和属性同步"""
# # 先初始化字典
# self.type = type
# self.data = data
@classmethod
def from_dict(cls, data: Dict) -> 'Seg':
"""从字典创建Seg实例"""
type=data.get('type')
data=data.get('data')
if type == 'seglist':
data = [Seg.from_dict(seg) for seg in data]
return cls(
type=type,
data=data
)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {'type': self.type}
if self.type == 'seglist':
result['data'] = [seg.to_dict() for seg in self.data]
else:
result['data'] = self.data
return result
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'GroupInfo':
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
group_id=data.get('group_id'),
group_name=data.get('group_name',None)
)
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> 'UserInfo':
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get('platform'),
user_id=data.get('user_id'),
user_nickname=data.get('user_nickname',None),
user_cardname=data.get('user_cardname',None)
)
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Union[str,int,None] = None
time: Optional[int] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@classmethod
def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo(**data.get('group_info', {}))
user_info = UserInfo(**data.get('user_info', {}))
return cls(
platform=data.get('platform'),
message_id=data.get('message_id'),
time=data.get('time'),
group_info=group_info,
user_info=user_info
)
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {
'message_info': self.message_info.to_dict(),
'message_segment': self.message_segment.to_dict()
}
if self.raw_message is not None:
result['raw_message'] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> 'MessageBase':
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo(**data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message',None)
return cls(
message_info=message_info,
message_segment=message_segment,
raw_message=raw_message
)

View File

@@ -0,0 +1,169 @@
import time
from dataclasses import dataclass
from typing import Dict, ForwardRef, List, Optional, Union
import urllib3
from .cq_code import CQCode, cq_code_tool
from .utils_cq import parse_cq_code
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass
class MessageCQ(MessageBase):
"""QQ消息基类继承自MessageBase
最小必要参数:
- message_id: 消息ID
- user_id: 发送者/接收者ID
- platform: 平台标识(默认为"qq"
"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
platform: str = "qq"
):
# 构造基础消息信息
message_info = BaseMessageInfo(
platform=platform,
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
)
# 调用父类初始化message_segment 由子类设置
super().__init__(
message_info=message_info,
message_segment=None,
raw_message=None
)
@dataclass
class MessageRecvCQ(MessageCQ):
"""QQ接收消息类用于解析raw_message到Seg对象"""
def __init__(
self,
message_id: int,
user_info: UserInfo,
raw_message: str,
group_info: Optional[GroupInfo] = None,
platform: str = "qq",
reply_message: Optional[Dict] = None,
):
# 调用父类初始化
super().__init__(message_id, user_info, group_info, platform)
if group_info.group_name is None:
group_info.group_name = get_groupname(group_info.group_id)
# 解析消息段
self.message_segment = self._parse_message(raw_message, reply_message)
self.raw_message = raw_message
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
"""解析消息内容为Seg对象"""
cq_code_dict_list = []
segments = []
start = 0
while True:
cq_start = message.find('[CQ:', start)
if cq_start == -1:
if start < len(message):
text = message[start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
if cq_start > start:
text = message[start:cq_start].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find(']', cq_start)
if cq_end == -1:
text = message[cq_start:].strip()
if text:
cq_code_dict_list.append(parse_cq_code(text))
break
cq_code = message[cq_start:cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1
# 转换CQ码为Seg对象
for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
if message_obj.translated_segments:
segments.append(message_obj.translated_segments)
# 如果只有一个segment直接返回
if len(segments) == 1:
return segments[0]
# 否则返回seglist类型的Seg
return Seg(type='seglist', data=segments)
def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict()
return base_dict
@dataclass
class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__(
self,
data: Dict
):
# 调用父类初始化
message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
message_segment = Seg.from_dict(data.get('message_segment', {}))
super().__init__(
message_info.message_id,
message_info.user_info,
message_info.group_info if message_info.group_info else None,
message_info.platform
)
self.message_segment = message_segment
self.raw_message = self._generate_raw_message()
def _generate_raw_message(self, ) -> str:
"""将Seg对象转换为raw_message"""
segments = []
# 处理消息段
if self.message_segment.type == 'seglist':
for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg))
else:
segments.append(self._seg_to_cq_code(self.message_segment))
return ''.join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串"""
if seg.type == 'text':
return str(seg.data)
elif seg.type == 'image':
return cq_code_tool.create_image_cq_base64(seg.data)
elif seg.type == 'emoji':
return cq_code_tool.create_emoji_cq_base64(seg.data)
elif seg.type == 'at':
return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply':
return cq_code_tool.create_reply_cq(int(seg.data))
else:
return f"[{seg.data}]"

View File

@@ -6,10 +6,11 @@ from loguru import logger
from nonebot.adapters.onebot.v11 import Bot from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool from .cq_code import cq_code_tool
from .message import Message, Message_Sending, Message_Thinking, MessageSet from .message_cq import MessageSendCQ
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time
from .config import global_config from .config import global_config
from .chat_stream import chat_manager
class Message_Sender: class Message_Sender:
@@ -24,64 +25,57 @@ class Message_Sender:
"""设置当前bot实例""" """设置当前bot实例"""
self._current_bot = bot self._current_bot = bot
async def send_group_message( async def send_message(
self, self,
group_id: int, message: MessageSending,
send_text: str,
auto_escape: bool = False,
reply_message_id: int = None,
at_user_id: int = None
) -> None: ) -> None:
"""发送消息"""
if isinstance(message, MessageSending):
message_json = message.to_dict()
message_send=MessageSendCQ(
data=message_json
)
if not self._current_bot: if message_send.message_info.group_info:
raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例")
message = send_text
# 如果需要回复
if reply_message_id:
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
message = reply_cq + message
# 如果需要at
# if at_user_id:
# at_cq = cq_code_tool.create_at_cq(at_user_id)
# message = at_cq + " " + message
typing_time = calculate_typing_time(message)
if typing_time > 10:
typing_time = 10
await asyncio.sleep(typing_time)
# 发送消息
try: try:
await self._current_bot.send_group_msg( await self._current_bot.send_group_msg(
group_id=group_id, group_id=message.message_info.group_info.group_id,
message=message, message=message_send.raw_message,
auto_escape=auto_escape auto_escape=False
) )
logger.debug(f"发送消息{message}成功") logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
except Exception: except Exception as e:
logger.exception(f"发送消息{message}失败") logger.error(f"[调试] 发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
else:
try:
await self._current_bot.send_private_msg(
user_id=message.message_info.user_info.user_id,
message=message_send.raw_message,
auto_escape=False
)
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
except Exception as e:
logger.error(f"发生错误 {e}")
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
class MessageContainer: class MessageContainer:
"""单个的发送/思考消息容器""" """单个聊天流的发送/思考消息容器"""
def __init__(self, chat_id: str, max_size: int = 100):
def __init__(self, group_id: int, max_size: int = 100): self.chat_id = chat_id
self.group_id = group_id
self.max_size = max_size self.max_size = max_size
self.messages = [] self.messages = []
self.last_send_time = 0 self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒) self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[Message_Sending]: def get_timeout_messages(self) -> List[MessageSending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序""" """获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time() current_time = time.time()
timeout_messages = [] timeout_messages = []
for msg in self.messages: for msg in self.messages:
if isinstance(msg, Message_Sending): if isinstance(msg, MessageSending):
if current_time - msg.thinking_start_time > self.thinking_timeout: if current_time - msg.thinking_start_time > self.thinking_timeout:
timeout_messages.append(msg) timeout_messages.append(msg)
@@ -90,7 +84,7 @@ class MessageContainer:
return timeout_messages return timeout_messages
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]: def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
"""获取thinking_start_time最早的消息对象""" """获取thinking_start_time最早的消息对象"""
if not self.messages: if not self.messages:
return None return None
@@ -103,16 +97,15 @@ class MessageContainer:
earliest_message = msg earliest_message = msg
return earliest_message return earliest_message
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None: def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
"""添加消息到队列""" """添加消息到队列"""
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
if isinstance(message, MessageSet): if isinstance(message, MessageSet):
for single_message in message.messages: for single_message in message.messages:
self.messages.append(single_message) self.messages.append(single_message)
else: else:
self.messages.append(message) self.messages.append(message)
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool: def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False""" """移除消息如果消息存在则返回True否则返回False"""
try: try:
if message in self.messages: if message in self.messages:
@@ -127,41 +120,39 @@ class MessageContainer:
"""检查是否有待发送的消息""" """检查是否有待发送的消息"""
return bool(self.messages) return bool(self.messages)
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]: def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
"""获取所有消息""" """获取所有消息"""
return list(self.messages) return list(self.messages)
class MessageManager: class MessageManager:
"""管理所有的消息容器""" """管理所有聊天流的消息容器"""
def __init__(self): def __init__(self):
self.containers: Dict[int, MessageContainer] = {} self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
self.storage = MessageStorage() self.storage = MessageStorage()
self._running = True self._running = True
def get_container(self, group_id: int) -> MessageContainer: def get_container(self, chat_id: str) -> MessageContainer:
"""获取或创建的消息容器""" """获取或创建聊天流的消息容器"""
if group_id not in self.containers: if chat_id not in self.containers:
self.containers[group_id] = MessageContainer(group_id) self.containers[chat_id] = MessageContainer(chat_id)
return self.containers[group_id] return self.containers[chat_id]
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None: def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
container = self.get_container(message.group_id) chat_stream = message.chat_stream
if not chat_stream:
raise ValueError("无法找到对应的聊天流")
container = self.get_container(chat_stream.stream_id)
container.add_message(message) container.add_message(message)
async def process_group_messages(self, group_id: int): async def process_chat_messages(self, chat_id: str):
"""处理消息""" """处理聊天流消息"""
# if int(time.time() / 3) == time.time() / 3: container = self.get_container(chat_id)
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
container = self.get_container(group_id)
if container.has_messages(): if container.has_messages():
# 最早的对象,可能是思考消息,也可能是发送消息 # print(f"处理有message的容器chat_id: {chat_id}")
message_earliest = container.get_earliest_message() # 一个message_thinking or message_sending message_earliest = container.get_earliest_message()
# 如果是思考消息 if isinstance(message_earliest, MessageThinking):
if isinstance(message_earliest, Message_Thinking):
# 优先等待这条消息
message_earliest.update_thinking_time() message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time thinking_time = message_earliest.thinking_time
print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True) print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True)
@@ -170,47 +161,38 @@ class MessageManager:
if thinking_time > global_config.thinking_timeout: if thinking_time > global_config.thinking_timeout:
logger.warning(f"消息思考超时({thinking_time}秒),移除该消息") logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
container.remove_message(message_earliest) container.remove_message(message_earliest)
else: # 如果不是message_thinking就只能是message_sending
logger.debug(f"消息'{message_earliest.processed_plain_text}'正在发送中")
# 直接发,等什么呢
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
auto_escape=False,
reply_message_id=message_earliest.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
auto_escape=False) if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
# 移除消息 await message_sender.send_message(message_earliest.set_reply())
if message_earliest.is_emoji: else:
message_earliest.processed_plain_text = "[表情包]" await message_sender.send_message(message_earliest)
await self.storage.store_message(message_earliest, None) await message_earliest.process()
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
await self.storage.store_message(message_earliest, message_earliest.chat_stream,None)
container.remove_message(message_earliest) container.remove_message(message_earliest)
# 获取并处理超时消息 message_timeout = container.get_timeout_messages()
message_timeout = container.get_timeout_messages() # 也许是一堆message_sending
if message_timeout: if message_timeout:
logger.warning(f"发现{len(message_timeout)}条超时消息") logger.warning(f"发现{len(message_timeout)}条超时消息")
for msg in message_timeout: for msg in message_timeout:
if msg == message_earliest: if msg == message_earliest:
continue # 跳过已经处理过的消息 continue
try: try:
# 发送
if msg.is_head and msg.update_thinking_time() > 30: if msg.is_head and msg.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, msg.processed_plain_text, await message_sender.send_message(msg.set_reply())
auto_escape=False,
reply_message_id=msg.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, msg.processed_plain_text, await message_sender.send_message(msg)
auto_escape=False)
# 如果是表情包,则替换为"[表情包]" # if msg.is_emoji:
if msg.is_emoji: # msg.processed_plain_text = "[表情包]"
msg.processed_plain_text = "[表情包]" await msg.process()
await self.storage.store_message(msg, None) await self.storage.store_message(msg,msg.chat_stream, None)
# 安全地移除消息
if not container.remove_message(msg): if not container.remove_message(msg):
logger.warning("尝试删除不存在的消息") logger.warning("尝试删除不存在的消息")
except Exception: except Exception:
@@ -222,8 +204,8 @@ class MessageManager:
while self._running: while self._running:
await asyncio.sleep(1) await asyncio.sleep(1)
tasks = [] tasks = []
for group_id in self.containers.keys(): for chat_id in self.containers.keys():
tasks.append(self.process_group_messages(group_id)) tasks.append(self.process_chat_messages(chat_id))
await asyncio.gather(*tasks) await asyncio.gather(*tasks)

View File

@@ -9,6 +9,7 @@ from ..moods.moods import MoodManager
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
from .config import global_config from .config import global_config
from .utils import get_embedding, get_recent_group_detailed_plain_text from .utils import get_embedding, get_recent_group_detailed_plain_text
from .chat_stream import ChatStream, chat_manager
class PromptBuilder: class PromptBuilder:
@@ -17,11 +18,13 @@ class PromptBuilder:
self.activate_messages = '' self.activate_messages = ''
self.db = Database.get_instance() self.db = Database.get_instance()
async def _build_prompt(self, async def _build_prompt(self,
message_txt: str, message_txt: str,
sender_name: str = "某人", sender_name: str = "某人",
relationship_value: float = 0.0, relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]: stream_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt """构建prompt
Args: Args:
@@ -70,13 +73,19 @@ class PromptBuilder:
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}") logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文 # 获取聊天上下文
chat_in_group=True
chat_talking_prompt = '' chat_talking_prompt = ''
if group_id: if stream_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
limit=global_config.MAX_CONTEXT_SIZE, chat_stream=chat_manager.get_stream(stream_id)
combine=True) if chat_stream.group_info:
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
else:
chat_in_group=False
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 使用新的记忆获取方法 # 使用新的记忆获取方法
memory_prompt = '' memory_prompt = ''
@@ -108,15 +117,10 @@ class PromptBuilder:
# 激活prompt构建 # 激活prompt构建
activate_prompt = '' activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}" if chat_in_group:
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
# 检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中 else:
# bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人'] activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
# is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
# if is_bot:
# is_bot_prompt = '有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认'
# else:
# is_bot_prompt = ''
# 关键词检测与反应 # 关键词检测与反应
keywords_reaction_prompt = '' keywords_reaction_prompt = ''
@@ -134,6 +138,10 @@ class PromptBuilder:
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)}' prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)}'
personality_choice = random.random() personality_choice = random.random()
if chat_in_group:
prompt_in_group=f"你正在浏览{chat_stream.platform}"
else:
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt}, prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt} 现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt}

View File

@@ -1,9 +1,11 @@
import asyncio import asyncio
from typing import Optional, Union
from typing import Optional, Union
from loguru import logger from loguru import logger
from typing import Optional
from ...common.database import Database from ...common.database import Database
from .message_base import UserInfo
from .chat_stream import ChatStream
class Impression: class Impression:
traits: str = None traits: str = None
@@ -15,59 +17,66 @@ class Impression:
class Relationship: class Relationship:
user_id: int = None user_id: int = None
# impression: Impression = None platform: str = None
# group_id: int = None
# group_name: str = None
gender: str = None gender: str = None
age: int = None age: int = None
nickname: str = None nickname: str = None
relationship_value: float = None relationship_value: float = None
saved = False saved = False
def __init__(self, user_id: int, data=None, **kwargs): def __init__(self, chat:ChatStream=None,data:dict=None):
if isinstance(data, dict): self.user_id=chat.user_info.user_id if chat else data.get('user_id',0)
# 如果输入是字典,使用字典解析 self.platform=chat.platform if chat else data.get('platform','')
self.user_id = data.get('user_id') self.nickname=chat.user_info.user_nickname if chat else data.get('nickname','')
self.gender = data.get('gender') self.relationship_value=data.get('relationship_value',0) if data else 0
self.age = data.get('age') self.age=data.get('age',0) if data else 0
self.nickname = data.get('nickname') self.gender=data.get('gender','') if data else ''
self.relationship_value = data.get('relationship_value', 0.0)
self.saved = data.get('saved', False)
else:
# 如果是直接传入属性值
self.user_id = kwargs.get('user_id')
self.gender = kwargs.get('gender')
self.age = kwargs.get('age')
self.nickname = kwargs.get('nickname')
self.relationship_value = kwargs.get('relationship_value', 0.0)
self.saved = kwargs.get('saved', False)
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.relationships: dict[int, Relationship] = {} self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
async def update_relationship(self,
chat_stream:ChatStream,
data: dict = None,
**kwargs) -> Optional[Relationship]:
"""更新或创建关系
Args:
chat_stream: 聊天流对象
data: 字典格式的数据(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
if chat_stream.user_info is not None:
user_id = chat_stream.user_info.user_id
platform = chat_stream.user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
async def update_relationship(self, user_id: int, data=None, **kwargs):
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(user_id) relationship = self.relationships.get(key)
if relationship: if relationship:
# 如果存在,更新现有对象 # 如果存在,更新现有对象
if isinstance(data, dict): if isinstance(data, dict):
for key, value in data.items(): for k, value in data.items():
if hasattr(relationship, key) and value is not None: if hasattr(relationship, k) and value is not None:
setattr(relationship, key, value) setattr(relationship, k, value)
else:
for key, value in kwargs.items():
if hasattr(relationship, key) and value is not None:
setattr(relationship, key, value)
else: else:
# 如果不存在,创建新对象 # 如果不存在,创建新对象
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id, if chat_stream.user_info is not None:
**kwargs) relationship = Relationship(chat=chat_stream, **kwargs)
self.relationships[user_id] = relationship else:
raise ValueError("必须提供user_id或user_info")
# 更新 id_name_nickname_table self.relationships[key] = relationship
# self.id_name_nickname_table[user_id] = [relationship.nickname] # 别称设置为空列表
# 保存到数据库 # 保存到数据库
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
@@ -75,32 +84,86 @@ class RelationshipManager:
return relationship return relationship
async def update_relationship_value(self, user_id: int, **kwargs): async def update_relationship_value(self,
chat_stream:ChatStream,
**kwargs) -> Optional[Relationship]:
"""更新关系值
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
**kwargs: 其他参数
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 使用(user_id, platform)作为键
key = (user_id, platform)
# 检查是否在内存中已存在 # 检查是否在内存中已存在
relationship = self.relationships.get(user_id) relationship = self.relationships.get(key)
if relationship: if relationship:
for key, value in kwargs.items(): for k, value in kwargs.items():
if key == 'relationship_value': if k == 'relationship_value':
relationship.relationship_value += value relationship.relationship_value += value
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
relationship.saved = True relationship.saved = True
return relationship return relationship
else: else:
logger.warning(f"用户 {user_id} 不存在,无法更新") # 如果不存在且提供了user_info则创建新的关系
if user_info is not None:
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 不存在,无法更新")
return None return None
def get_relationship(self, user_id: int) -> Optional[Relationship]: def get_relationship(self,
"""获取用户关系对象""" chat_stream:ChatStream) -> Optional[Relationship]:
if user_id in self.relationships: """获取用户关系对象
return self.relationships[user_id] Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
Relationship: 关系对象
"""
# 确定user_id和platform
user_info = chat_stream.user_info
platform = chat_stream.user_info.platform or 'qq'
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
key = (user_id, platform)
if key in self.relationships:
return self.relationships[key]
else: else:
return 0 return 0
async def load_relationship(self, data: dict) -> Relationship: async def load_relationship(self, data: dict) -> Relationship:
"""从数据库加载或创建新的关系对象""" """从数据库加载或创建新的关系对象"""
rela = Relationship(user_id=data['user_id'], data=data) # 确保data中有platform字段如果没有则默认为'qq'
if 'platform' not in data:
data['platform'] = 'qq'
rela = Relationship(data=data)
rela.saved = True rela.saved = True
self.relationships[rela.user_id] = rela key = (rela.user_id, rela.platform)
self.relationships[key] = rela
return rela return rela
async def load_all_relationships(self): async def load_all_relationships(self):
@@ -117,10 +180,8 @@ class RelationshipManager:
all_relationships = db.db.relationships.find({}) all_relationships = db.db.relationships.find({})
# 依次加载每条记录 # 依次加载每条记录
for data in all_relationships: for data in all_relationships:
user_id = data['user_id'] await self.load_relationship(data)
relationship = await self.load_relationship(data) logger.debug(f"[关系管理] 已加载 {len(self.relationships)} 条关系记录")
self.relationships[user_id] = relationship
logger.debug(f"已加载 {len(self.relationships)} 条关系记录")
while True: while True:
logger.debug("正在自动保存关系") logger.debug("正在自动保存关系")
@@ -130,16 +191,15 @@ class RelationshipManager:
async def _save_all_relationships(self): async def _save_all_relationships(self):
"""将所有关系数据保存到数据库""" """将所有关系数据保存到数据库"""
# 保存所有关系数据 # 保存所有关系数据
for userid, relationship in self.relationships.items(): for (userid, platform), relationship in self.relationships.items():
if not relationship.saved: if not relationship.saved:
relationship.saved = True relationship.saved = True
await self.storage_relationship(relationship) await self.storage_relationship(relationship)
async def storage_relationship(self, relationship: Relationship): async def storage_relationship(self, relationship: Relationship):
""" """将关系记录存储到数据库中"""
将关系记录存储到数据库中
"""
user_id = relationship.user_id user_id = relationship.user_id
platform = relationship.platform
nickname = relationship.nickname nickname = relationship.nickname
relationship_value = relationship.relationship_value relationship_value = relationship.relationship_value
gender = relationship.gender gender = relationship.gender
@@ -148,8 +208,9 @@ class RelationshipManager:
db = Database.get_instance() db = Database.get_instance()
db.db.relationships.update_one( db.db.relationships.update_one(
{'user_id': user_id}, {'user_id': user_id, 'platform': platform},
{'$set': { {'$set': {
'platform': platform,
'nickname': nickname, 'nickname': nickname,
'relationship_value': relationship_value, 'relationship_value': relationship_value,
'gender': gender, 'gender': gender,
@@ -159,12 +220,36 @@ class RelationshipManager:
upsert=True upsert=True
) )
def get_name(self, user_id: int) -> str:
def get_name(self,
user_id: int = None,
platform: str = None,
user_info: UserInfo = None) -> str:
"""获取用户昵称
Args:
user_id: 用户ID可选如果提供user_info则不需要
platform: 平台可选如果提供user_info则不需要
user_info: 用户信息对象(可选)
Returns:
str: 用户昵称
"""
# 确定user_id和platform
if user_info is not None:
user_id = user_info.user_id
platform = user_info.platform or 'qq'
else:
platform = platform or 'qq'
if user_id is None:
raise ValueError("必须提供user_id或user_info")
# 确保user_id是整数类型 # 确保user_id是整数类型
user_id = int(user_id) user_id = int(user_id)
if user_id in self.relationships: key = (user_id, platform)
if key in self.relationships:
return self.relationships[user_id].nickname return self.relationships[key].nickname
elif user_info is not None:
return user_info.user_nickname or user_info.user_cardname or "某人"
else: else:
return "某人" return "某人"

View File

@@ -1,7 +1,10 @@
from typing import Optional from typing import Optional, Union
from typing import Optional, Union
from ...common.database import Database from ...common.database import Database
from .message import Message from .message_base import MessageBase
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from loguru import logger from loguru import logger
@@ -9,40 +12,19 @@ class MessageStorage:
def __init__(self): def __init__(self):
self.db = Database.get_instance() self.db = Database.get_instance()
async def store_message(self, message: Message, topic: Optional[str] = None) -> None: async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
"""存储消息到数据库""" """存储消息到数据库"""
try: try:
if not message.is_emoji:
message_data = { message_data = {
"group_id": message.group_id, "message_id": message.message_info.message_id,
"user_id": message.user_id, "time": message.message_info.time,
"message_id": message.message_id, "chat_id":chat_stream.stream_id,
"raw_message": message.raw_message, "chat_info": chat_stream.to_dict(),
"plain_text": message.plain_text, "user_info": message.message_info.user_info.to_dict(),
"processed_plain_text": message.processed_plain_text, "processed_plain_text": message.processed_plain_text,
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic,
"detailed_plain_text": message.detailed_plain_text, "detailed_plain_text": message.detailed_plain_text,
}
else:
message_data = {
"group_id": message.group_id,
"user_id": message.user_id,
"message_id": message.message_id,
"raw_message": message.raw_message,
"plain_text": message.plain_text,
"processed_plain_text": '[表情包]',
"time": message.time,
"user_nickname": message.user_nickname,
"user_cardname": message.user_cardname,
"group_name": message.group_name,
"topic": topic, "topic": topic,
"detailed_plain_text": message.detailed_plain_text,
} }
self.db.db.messages.insert_one(message_data) self.db.db.messages.insert_one(message_data)
except Exception: except Exception:
logger.exception("存储消息失败") logger.exception("存储消息失败")

View File

@@ -12,32 +12,15 @@ from loguru import logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
from .config import global_config from .config import global_config
from .message import Message from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
from .chat_stream import ChatStream
from ..moods.moods import MoodManager from ..moods.moods import MoodManager
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
def combine_messages(messages: List[Message]) -> str:
"""将消息列表组合成格式化的字符串
Args:
messages: Message对象列表
Returns:
str: 格式化后的消息字符串
"""
result = ""
for message in messages:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
name = message.user_nickname or f"用户{message.user_id}"
content = message.processed_plain_text or message.plain_text
result += f"[{time_str}] {name}: {content}\n"
return result
def db_message_to_str(message_dict: Dict) -> str: def db_message_to_str(message_dict: Dict) -> str:
logger.debug(f"message_dict: {message_dict}") logger.debug(f"message_dict: {message_dict}")
@@ -53,14 +36,11 @@ def db_message_to_str(message_dict: Dict) -> str:
return result return result
def is_mentioned_bot_in_txt(message: str) -> bool: def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
"""检查消息是否提到了机器人""" """检查消息是否提到了机器人"""
if global_config.BOT_NICKNAME is None: keywords = [global_config.BOT_NICKNAME]
return True for keyword in keywords:
if global_config.BOT_NICKNAME in message: if keyword in message.processed_plain_text:
return True
for keyword in global_config.BOT_ALIAS_NAMES:
if keyword in message:
return True return True
return False return False
@@ -93,46 +73,45 @@ def calculate_information_content(text):
def get_cloest_chat_from_db(db, length: int, timestamp: str): def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数 """从数据库中获取最接近指定时间戳的聊天记录
Args:
db: 数据库实例
length: 要获取的消息数量
timestamp: 时间戳
Returns: Returns:
list: 消息记录字典列表,每个字典包含消息内容和时间信息 list: 消息记录列表,每个记录包含时间和文本信息
""" """
chat_records = [] chat_records = []
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4: if closest_record:
closest_time = closest_record['time'] closest_time = closest_record['time']
group_id = closest_record['group_id'] chat_id = closest_record['chat_id'] # 获取chat_id
# 获取该时间戳之后的length条消息且groupid相同 # 获取该时间戳之后的length条消息保持相同的chat_id
records = list(db.db.messages.find( chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id} {
"time": {"$gt": closest_time},
"chat_id": chat_id # 添加chat_id过滤
}
).sort('time', 1).limit(length)) ).sort('time', 1).limit(length))
# 更新每条消息的memorized属性 # 转换记录格式
for record in records: formatted_records = []
current_memorized = record.get('memorized', 0) for record in chat_records:
if current_memorized > 3: formatted_records.append({
print("消息已读取3次跳过")
return ''
# 更新memorized值
db.db.messages.update_one(
{"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}}
)
# 添加到记录列表中
chat_records.append({
'text': record["detailed_plain_text"],
'time': record["time"], 'time': record["time"],
'group_id': record["group_id"] 'chat_id': record["chat_id"],
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
}) })
return chat_records return formatted_records
return []
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录
Args: Args:
@@ -146,35 +125,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
# 从数据库获取最近消息 # 从数据库获取最近消息
recent_messages = list(db.db.messages.find( recent_messages = list(db.db.messages.find(
{"group_id": group_id}, {"chat_id": chat_id},
# {
# "time": 1,
# "user_id": 1,
# "user_nickname": 1,
# "message_id": 1,
# "raw_message": 1,
# "processed_text": 1
# }
).sort("time", -1).limit(limit)) ).sort("time", -1).limit(limit))
if not recent_messages: if not recent_messages:
return [] return []
# 转换为 Message对象列表 # 转换为 Message对象列表
from .message import Message
message_objects = [] message_objects = []
for msg_data in recent_messages: for msg_data in recent_messages:
try: try:
chat_info=msg_data.get("chat_info",{})
chat_stream=ChatStream.from_dict(chat_info)
user_info=msg_data.get("user_info",{})
user_info=UserInfo.from_dict(user_info)
msg = Message( msg = Message(
time=msg_data["time"],
user_id=msg_data["user_id"],
user_nickname=msg_data.get("user_nickname", ""),
message_id=msg_data["message_id"], message_id=msg_data["message_id"],
raw_message=msg_data["raw_message"], chat_stream=chat_stream,
time=msg_data["time"],
user_info=user_info,
processed_plain_text=msg_data.get("processed_text", ""), processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id detailed_plain_text=msg_data.get("detailed_plain_text", "")
) )
await msg.initialize()
message_objects.append(msg) message_objects.append(msg)
except KeyError: except KeyError:
logger.warning("数据库中存在无效的消息") logger.warning("数据库中存在无效的消息")
@@ -185,13 +157,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
return message_objects return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False): def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find( recent_messages = list(db.db.messages.find(
{"group_id": group_id}, {"chat_id": chat_stream_id},
{ {
"time": 1, # 返回时间字段 "time": 1, # 返回时间字段
"user_id": 1, # 返回用户ID字段 "chat_id":1,
"user_nickname": 1, # 返回用户昵称字段 "chat_info":1,
"user_info": 1,
"message_id": 1, # 返回消息ID字段 "message_id": 1, # 返回消息ID字段
"detailed_plain_text": 1 # 返回处理后的文本字段 "detailed_plain_text": 1 # 返回处理后的文本字段
} }

View File

@@ -2,7 +2,11 @@ import base64
import io import io
import os import os
import time import time
import zlib # 用于 CRC32 import zlib
import aiohttp
import hashlib
from typing import Optional, Tuple, Union
from urllib.parse import urlparse
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
@@ -10,280 +14,353 @@ from PIL import Image
from ...common.database import Database from ...common.database import Database
from ..chat.config import global_config from ..chat.config import global_config
from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
class ImageManager:
_instance = None
IMAGE_DIR = "data" # 图像存储根目录
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance.db = None
cls._instance._initialized = False
return cls._instance
def __init__(self):
if not self._initialized:
self.db = Database.get_instance()
self._ensure_image_collection()
self._ensure_description_collection()
self._ensure_image_dir()
self._initialized = True
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
def _ensure_image_dir(self):
"""确保图像存储目录存在"""
os.makedirs(self.IMAGE_DIR, exist_ok=True)
def _ensure_image_collection(self):
"""确保images集合存在并创建索引"""
if 'images' not in self.db.db.list_collection_names():
self.db.db.create_collection('images')
# 创建索引
self.db.db.images.create_index([('hash', 1)], unique=True)
self.db.db.images.create_index([('url', 1)])
self.db.db.images.create_index([('path', 1)])
def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in self.db.db.list_collection_names():
self.db.db.create_collection('image_descriptions')
# 创建索引
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
self.db.db.image_descriptions.create_index([('type', 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
"""
压缩base64格式的图片到指定大小单位KB并在数据库中记录图片信息
Args: Args:
base64_data: base64编码的图片数据 image_hash: 图片哈希值
max_size: 最大文件大小KB description_type: 描述类型 ('emoji''image')
Returns: Returns:
str: 压缩后的base64图片数据 Optional[str]: 描述文本如果不存在则返回None
"""
result= self.db.db.image_descriptions.find_one({
'hash': image_hash,
'type': description_type
})
return result['description'] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库
Args:
image_hash: 图片哈希值
description: 描述文本
description_type: 描述类型 ('emoji''image')
"""
self.db.db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type},
{
'$set': {
'description': description,
'timestamp': int(time.time())
}
},
upsert=True
)
async def save_image(self,
image_data: Union[str, bytes],
url: str = None,
description: str = None,
is_base64: bool = False) -> Optional[str]:
"""保存图像
Args:
image_data: 图像数据(base64字符串或字节)
url: 图像URL
description: 图像描述
is_base64: image_data是否为base64格式
Returns:
str: 保存后的文件路径,失败返回None
""" """
try: try:
# 将base64转换为字节数据 # 转换为字节格式
image_data = base64.b64decode(base64_data) if is_base64:
if isinstance(image_data, str):
# 使用 CRC32 计算哈希值 image_bytes = base64.b64decode(image_data)
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
# 确保图片目录存在
images_dir = "data/images"
os.makedirs(images_dir, exist_ok=True)
# 连接数据库
db = Database.get_instance()
# 检查是否已存在相同哈希值的图片
collection = db.db['images']
existing_image = collection.find_one({'hash': hash_value})
if existing_image:
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 如果是动图,直接返回原图
if getattr(img, 'is_animated', False):
return base64_data
# 计算当前大小KB
current_size = len(image_data) / 1024
# 如果已经小于目标大小,直接使用原图
if current_size <= max_size:
compressed_data = image_data
else: else:
# 压缩逻辑 return None
# 先缩放到50% else:
new_width = int(img.width * 0.5) if isinstance(image_data, bytes):
new_height = int(img.height * 0.5) image_bytes = image_data
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) else:
return None
# 如果缩放后的最大边长仍然大于400继续缩放 # 计算哈希值
max_dimension = 400 image_hash = hashlib.md5(image_bytes).hexdigest()
max_current = max(new_width, new_height)
if max_current > max_dimension:
ratio = max_dimension / max_current
new_width = int(new_width * ratio)
new_height = int(new_height * ratio)
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 转换为RGB模式去除透明通道 # 查重
if img.mode in ('RGBA', 'P'): existing = self.db.db.images.find_one({'hash': image_hash})
img = img.convert('RGB') if existing:
return existing['path']
# 使用固定质量参数压缩 # 生成文件名和路径
output = io.BytesIO()
img.save(output, format='JPEG', quality=85, optimize=True)
compressed_data = output.getvalue()
# 生成文件名(使用时间戳和哈希值确保唯一性)
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg" filename = f"{timestamp}_{image_hash[:8]}.jpg"
image_path = os.path.join(images_dir, filename) file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件 # 保存文件
with open(image_path, "wb") as f: with open(file_path, "wb") as f:
f.write(compressed_data) f.write(image_bytes)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}") # 保存到数据库
image_doc = {
try: 'hash': image_hash,
# 准备数据库记录 'path': file_path,
image_record = { 'url': url,
'filename': filename, 'description': description,
'path': image_path, 'timestamp': timestamp
'size': len(compressed_data) / 1024,
'timestamp': timestamp,
'width': img.width,
'height': img.height,
'description': '',
'tags': [],
'type': 'image',
'hash': hash_value
} }
self.db.db.images.insert_one(image_doc)
# 保存记录 return file_path
collection.insert_one(image_record)
print("\033[1;32m[成功]\033[0m 保存图片记录到数据库")
except Exception as db_error:
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
# 将压缩后的数据转换为base64
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
return compressed_base64
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}") logger.error(f"保存图像失败: {str(e)}")
import traceback return None
print(traceback.format_exc())
return base64_data
def storage_emoji(image_data: bytes) -> bytes: async def get_image_by_url(self, url: str) -> Optional[str]:
""" """根据URL获取图像路径(带查重)
存储表情包到本地文件夹
Args: Args:
image_data: 图片字节数据 url: 图像URL
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns: Returns:
bytes: 原始图片数据 str: 本地文件路径,不存在返回None
""" """
if not global_config.EMOJI_SAVE:
return image_data
try: try:
# 使用 CRC32 计算哈希值 # 先查找是否已存在
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') existing = self.db.db.images.find_one({'url': url})
if existing:
return existing['path']
# 确保表情包目录存在 # 下载图像
emoji_dir = "data/emoji" async with aiohttp.ClientSession() as session:
os.makedirs(emoji_dir, exist_ok=True) async with session.get(url) as resp:
if resp.status == 200:
image_bytes = await resp.read()
return await self.save_image(image_bytes, url=url)
return None
# 检查是否已存在相同哈希值的文件 except Exception as e:
for filename in os.listdir(emoji_dir): logger.error(f"获取图像失败: {str(e)}")
if hash_value in filename: return None
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
return image_data
# 生成文件名 async def get_base64_by_url(self, url: str) -> Optional[str]:
"""根据URL获取base64(带查重)
Args:
url: 图像URL
Returns:
str: base64字符串,失败返回None
"""
try:
image_path = await self.get_image_by_url(url)
if not image_path:
return None
with open(image_path, 'rb') as f:
image_bytes = f.read()
return base64.b64encode(image_bytes).decode('utf-8')
except Exception as e:
logger.error(f"获取base64失败: {str(e)}")
return None
async def save_base64_image(self, base64_str: str, description: str = None) -> Optional[str]:
"""保存base64图像(带查重)
Args:
base64_str: base64字符串
description: 图像描述
Returns:
str: 保存路径,失败返回None
"""
return await self.save_image(base64_str, description=description, is_base64=True)
def check_url_exists(self, url: str) -> bool:
"""检查URL是否已存在
Args:
url: 图像URL
Returns:
bool: 是否存在
"""
return self.db.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
Args:
image_data: 图像数据(base64或字节)
is_base64: 是否为base64格式
Returns:
bool: 是否存在
"""
try:
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return False
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
return False
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'emoji')
if cached_description:
logger.info(f"缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
# 调用AI获取描述
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{hash_value}.jpg" filename = f"emoji_{timestamp}_{image_hash[:8]}.jpg"
emoji_path = os.path.join(emoji_dir, filename) file_path = os.path.join(self.IMAGE_DIR, filename)
# 直接保存原始文件
with open(emoji_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
return image_data
def storage_image(image_data: bytes) -> bytes:
"""
存储图片到本地文件夹
Args:
image_data: 图片字节数据
group_id: 群组ID仅用于日志
user_id: 用户ID仅用于日志
Returns:
bytes: 原始图片数据
"""
try: try:
# 使用 CRC32 计算哈希值 # 保存文件
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') with open(file_path, "wb") as f:
f.write(image_bytes)
# 确保表情包目录存在 # 保存到数据库
image_dir = "data/image" image_doc = {
os.makedirs(image_dir, exist_ok=True) 'hash': image_hash,
'path': file_path,
# 检查是否已存在相同哈希值的文件 'type': 'emoji',
for filename in os.listdir(image_dir): 'description': description,
if hash_value in filename: 'timestamp': timestamp
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}") }
return image_data self.db.db.images.update_one(
{'hash': image_hash},
# 生成文件名 {'$set': image_doc},
timestamp = int(time.time()) upsert=True
filename = f"{timestamp}_{hash_value}.jpg"
image_path = os.path.join(image_dir, filename)
# 直接保存原始文件
with open(image_path, "wb") as f:
f.write(image_data)
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
return image_data
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
return image_data
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
) )
else: logger.success(f"保存表情包: {file_path}")
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e: except Exception as e:
logger.error(f"压缩图片失败: {str(e)}") logger.error(f"保存表情包文件失败: {str(e)}")
import traceback
logger.error(traceback.format_exc()) # 保存描述到数据库
return base64_data self._save_description_to_db(image_hash, description, 'emoji')
return f"[表情包:{description}]"
except Exception as e:
logger.error(f"获取表情包描述失败: {str(e)}")
return "[表情包]"
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,带查重和保存功能"""
try:
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'image')
if cached_description:
return f"[图片:{cached_description}]"
# 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
if description is None:
logger.warning("AI未能生成图片描述")
return "[图片]"
# 根据配置决定是否保存图片
if global_config.EMOJI_SAVE:
# 生成文件名和路径
timestamp = int(time.time())
filename = f"image_{timestamp}_{image_hash[:8]}.jpg"
file_path = os.path.join(self.IMAGE_DIR, filename)
try:
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'type': 'image',
'description': description,
'timestamp': timestamp
}
self.db.db.images.update_one(
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存图片: {file_path}")
except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}")
# 保存描述到数据库
self._save_description_to_db(image_hash, description, 'image')
return f"[图片:{description}]"
except Exception as e:
logger.error(f"获取图片描述失败: {str(e)}")
return "[图片]"
# 创建全局单例
image_manager = ImageManager()
def image_path_to_base64(image_path: str) -> str: def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码 """将图片路径转换为base64编码

View File

@@ -1,106 +1,108 @@
import asyncio import asyncio
from typing import Dict
from loguru import logger from loguru import logger
from typing import Dict
from loguru import logger
from .config import global_config from .config import global_config
from .message_base import UserInfo, GroupInfo
from .chat_stream import chat_manager,ChatStream
class WillingManager: class WillingManager:
def __init__(self): def __init__(self):
self.group_reply_willing = {} # 存储每个的回复意愿 self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self._decay_task = None self._decay_task = None
self._started = False self._started = False
self.min_reply_willing = 0.01
self.attenuation_coefficient = 0.75
async def _decay_reply_willing(self): async def _decay_reply_willing(self):
"""定期衰减回复意愿""" """定期衰减回复意愿"""
while True: while True:
await asyncio.sleep(5) await asyncio.sleep(5)
for group_id in self.group_reply_willing: for chat_id in self.chat_reply_willing:
self.group_reply_willing[group_id] = max( self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
self.min_reply_willing, for chat_id in self.chat_reply_willing:
self.group_reply_willing[group_id] * self.attenuation_coefficient self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
)
def get_willing(self, group_id: int) -> float: def get_willing(self,chat_stream:ChatStream) -> float:
"""获取指定群组的回复意愿""" """获取指定聊天流的回复意愿"""
return self.group_reply_willing.get(group_id, 0) stream = chat_stream
if stream:
return self.chat_reply_willing.get(stream.stream_id, 0)
return 0
def set_willing(self, group_id: int, willing: float): def set_willing(self, chat_id: str, willing: float):
"""设置指定群组的回复意愿""" """设置指定聊天流的回复意愿"""
self.group_reply_willing[group_id] = willing self.chat_reply_willing[chat_id] = willing
def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿"""
self.chat_reply_willing[chat_id] = willing
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config, async def change_reply_willing_received(self,
user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float: chat_stream:ChatStream,
topic: str = None,
is_mentioned_bot: bool = False,
config = None,
is_emoji: bool = False,
interested_rate: float = 0) -> float:
"""改变指定聊天流的回复意愿并返回回复概率"""
# 获取或创建聊天流
stream = chat_stream
chat_id = stream.stream_id
# 若非目标回复群组则直接return current_willing = self.chat_reply_willing.get(chat_id, 0)
if group_id not in config.talk_allowed_groups:
reply_probability = 0
return reply_probability
current_willing = self.group_reply_willing.get(group_id, 0) # print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0:
logger.debug(f"[{group_id}]的初始回复意愿: {current_willing}") current_willing += 0.9
print(f"被提及, 当前意愿: {current_willing}")
# 根据消息类型被cue/表情包)调控 elif is_mentioned_bot:
if is_mentioned_bot: current_willing += 0.05
current_willing = min( print(f"被重复提及, 当前意愿: {current_willing}")
3.0,
current_willing + 0.9
)
logger.debug(f"被提及, 当前意愿: {current_willing}")
if is_emoji: if is_emoji:
current_willing *= 0.1 current_willing *= 0.1
logger.debug(f"表情包, 当前意愿: {current_willing}") print(f"表情包, 当前意愿: {current_willing}")
# 兴趣放大系数,若兴趣 > 0.4则增加回复概率 print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
interested_rate_amplifier = global_config.response_interested_rate_amplifier interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
logger.debug(f"放大系数_interested_rate: {interested_rate_amplifier}") if interested_rate > 0.4:
interested_rate *= interested_rate_amplifier # print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
current_willing += interested_rate-0.4
current_willing += max( current_willing *= global_config.response_willing_amplifier #放大回复意愿
0.0, # print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
interested_rate - 0.4
)
# 回复意愿系数调控,独立乘区 reply_probability = max((current_willing - 0.45) * 2, 0)
willing_amplifier = max(
global_config.response_willing_amplifier,
self.min_reply_willing
)
current_willing *= willing_amplifier
logger.debug(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
# 回复概率迭代保底0.01回复概率 # 检查群组权限(如果是群聊)
reply_probability = max( if chat_stream.group_info:
(current_willing - 0.45) * 2, if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
self.min_reply_willing reply_probability = reply_probability / global_config.down_frequency_rate
)
# 降低目标低频群组回复概率
down_frequency_rate = max(
1.0,
global_config.down_frequency_rate
)
if group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / down_frequency_rate
reply_probability = min(reply_probability, 1) reply_probability = min(reply_probability, 1)
if reply_probability < 0:
reply_probability = 0
self.group_reply_willing[group_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
logger.debug(f"当前群组{group_id}回复概率:{reply_probability}")
return reply_probability return reply_probability
def change_reply_willing_sent(self, group_id: int): def change_reply_willing_sent(self, chat_stream:ChatStream):
"""开始思考后降低群组的回复意愿""" """开始思考后降低聊天流的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0) stream = chat_stream
self.group_reply_willing[group_id] = max(0, current_willing - 2) if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
def change_reply_willing_after_sent(self, group_id: int): def change_reply_willing_after_sent(self,chat_stream:ChatStream):
"""发送消息后提高群组的回复意愿""" """发送消息后提高聊天流的回复意愿"""
current_willing = self.group_reply_willing.get(group_id, 0) stream = chat_stream
if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
if current_willing < 1: if current_willing < 1:
self.group_reply_willing[group_id] = min(1, current_willing + 0.2) self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
async def ensure_started(self): async def ensure_started(self):
"""确保衰减任务已启动""" """确保衰减任务已启动"""
@@ -109,6 +111,5 @@ class WillingManager:
self._decay_task = asyncio.create_task(self._decay_reply_willing()) self._decay_task = asyncio.create_task(self._decay_reply_willing())
self._started = True self._started = True
# 创建全局实例 # 创建全局实例
willing_manager = WillingManager() willing_manager = WillingManager()

View File

@@ -239,7 +239,7 @@ class Hippocampus:
time_info += f"是从 {earliest_str}{latest_str} 的对话:\n" time_info += f"是从 {earliest_str}{latest_str} 的对话:\n"
for msg in messages: for msg in messages:
input_text += f"{msg['text']}\n" input_text += f"{msg['detailed_plain_text']}\n"
logger.debug(input_text) logger.debug(input_text)

View File

@@ -7,10 +7,11 @@ from typing import Tuple, Union
import aiohttp import aiohttp
from loguru import logger from loguru import logger
from nonebot import get_driver from nonebot import get_driver
import base64
from PIL import Image
import io
from ...common.database import Database from ...common.database import Database
from ..chat.config import global_config from ..chat.config import global_config
from ..chat.utils_image import compress_base64_image_by_scale
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -432,3 +433,78 @@ class LLM_request:
response_handler=embedding_handler response_handler=embedding_handler
) )
return embedding return embedding
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
"""压缩base64格式的图片到指定大小
Args:
base64_data: base64编码的图片数据
target_size: 目标文件大小字节默认0.8MB
Returns:
str: 压缩后的base64图片数据
"""
try:
# 将base64转换为字节数据
image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图
if len(image_data) <= 2*1024*1024:
return base64_data
# 将字节数据转换为图片对象
img = Image.open(io.BytesIO(image_data))
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return base64_data