Merge pull request #181 from tcmofashi/refractor
Refractor: 史上最好的消息流重构和图片管理系统
This commit is contained in:
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -1,3 +0,0 @@
|
||||
{
|
||||
"editor.formatOnSave": true
|
||||
}
|
||||
@@ -16,6 +16,7 @@ from .config import global_config
|
||||
from .emoji_manager import emoji_manager
|
||||
from .relationship_manager import relationship_manager
|
||||
from .willing_manager import willing_manager
|
||||
from .chat_stream import chat_manager
|
||||
from ..memory_system.memory import hippocampus, memory_graph
|
||||
from .bot import ChatBot
|
||||
from .message_sender import message_manager, message_sender
|
||||
@@ -98,6 +99,8 @@ async def _(bot: Bot):
|
||||
|
||||
asyncio.create_task(emoji_manager._periodic_scan(interval_MINS=global_config.EMOJI_REGISTER_INTERVAL))
|
||||
logger.success("-----------开始偷表情包!-----------")
|
||||
asyncio.create_task(chat_manager._initialize())
|
||||
asyncio.create_task(chat_manager._auto_save_task())
|
||||
|
||||
|
||||
@group_msg.handle()
|
||||
|
||||
@@ -1,28 +1,28 @@
|
||||
import re
|
||||
import time
|
||||
from random import random
|
||||
|
||||
from loguru import logger
|
||||
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent
|
||||
|
||||
from ..memory_system.memory import hippocampus
|
||||
from ..moods.moods import MoodManager # 导入情绪管理器
|
||||
from .config import global_config
|
||||
from .cq_code import CQCode # 导入CQCode模块
|
||||
from .cq_code import CQCode,cq_code_tool # 导入CQCode模块
|
||||
from .emoji_manager import emoji_manager # 导入表情包管理器
|
||||
from .llm_generator import ResponseGenerator
|
||||
from .message import (
|
||||
Message,
|
||||
Message_Sending,
|
||||
Message_Thinking, # 导入 Message_Thinking 类
|
||||
MessageSet,
|
||||
from .message import MessageSending, MessageRecv, MessageThinking, MessageSet
|
||||
from .message_cq import (
|
||||
MessageRecvCQ,
|
||||
)
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
from .message_sender import message_manager # 导入新的消息管理器
|
||||
from .relationship_manager import relationship_manager
|
||||
from .storage import MessageStorage
|
||||
from .utils import calculate_typing_time, is_mentioned_bot_in_txt
|
||||
from .utils import calculate_typing_time, is_mentioned_bot_in_message
|
||||
from .utils_image import image_path_to_base64
|
||||
from .willing_manager import willing_manager # 导入意愿管理器
|
||||
|
||||
from .message_base import UserInfo, GroupInfo, Seg
|
||||
|
||||
class ChatBot:
|
||||
def __init__(self):
|
||||
@@ -44,35 +44,61 @@ class ChatBot:
|
||||
async def handle_message(self, event: GroupMessageEvent, bot: Bot) -> None:
|
||||
"""处理收到的群消息"""
|
||||
|
||||
if event.group_id not in global_config.talk_allowed_groups:
|
||||
return
|
||||
self.bot = bot # 更新 bot 实例
|
||||
|
||||
# group_info = await bot.get_group_info(group_id=event.group_id)
|
||||
# sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||
|
||||
# 白名单设定由nontbot侧完成
|
||||
if event.group_id:
|
||||
if event.group_id not in global_config.talk_allowed_groups:
|
||||
return
|
||||
if event.user_id in global_config.ban_user_id:
|
||||
return
|
||||
|
||||
group_info = await bot.get_group_info(group_id=event.group_id)
|
||||
sender_info = await bot.get_group_member_info(group_id=event.group_id, user_id=event.user_id, no_cache=True)
|
||||
|
||||
await relationship_manager.update_relationship(user_id=event.user_id, data=sender_info)
|
||||
await relationship_manager.update_relationship_value(user_id=event.user_id, relationship_value=0.5)
|
||||
|
||||
message = Message(
|
||||
group_id=event.group_id,
|
||||
|
||||
user_info=UserInfo(
|
||||
user_id=event.user_id,
|
||||
message_id=event.message_id,
|
||||
user_cardname=sender_info['card'],
|
||||
raw_message=str(event.original_message),
|
||||
plain_text=event.get_plaintext(),
|
||||
reply_message=event.reply,
|
||||
user_nickname=event.sender.nickname,
|
||||
user_cardname=event.sender.card or None,
|
||||
platform='qq'
|
||||
)
|
||||
await message.initialize()
|
||||
|
||||
group_info=GroupInfo(
|
||||
group_id=event.group_id,
|
||||
group_name=None,
|
||||
platform='qq'
|
||||
)
|
||||
|
||||
message_cq=MessageRecvCQ(
|
||||
message_id=event.message_id,
|
||||
user_info=user_info,
|
||||
raw_message=str(event.original_message),
|
||||
group_info=group_info,
|
||||
reply_message=event.reply,
|
||||
platform='qq'
|
||||
)
|
||||
message_json=message_cq.to_dict()
|
||||
|
||||
# 进入maimbot
|
||||
message=MessageRecv(message_json)
|
||||
|
||||
groupinfo=message.message_info.group_info
|
||||
userinfo=message.message_info.user_info
|
||||
messageinfo=message.message_info
|
||||
|
||||
# 消息过滤,涉及到config有待更新
|
||||
|
||||
chat = await chat_manager.get_or_create_stream(platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo)
|
||||
message.update_chat_stream(chat)
|
||||
await relationship_manager.update_relationship(chat_stream=chat,)
|
||||
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value = 0.5)
|
||||
|
||||
await message.process()
|
||||
# 过滤词
|
||||
for word in global_config.ban_words:
|
||||
if word in message.detailed_plain_text:
|
||||
if word in message.processed_plain_text:
|
||||
logger.info(
|
||||
f"[{message.group_name}]{message.user_nickname}:{message.processed_plain_text}")
|
||||
f"[{groupinfo.group_name}]{userinfo.user_nickname}:{message.processed_plain_text}")
|
||||
logger.info(f"[过滤词识别]消息中含有{word},filtered")
|
||||
return
|
||||
|
||||
@@ -83,8 +109,10 @@ class ChatBot:
|
||||
f"[{message.group_name}]{message.user_nickname}:{message.raw_message}")
|
||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern},filtered")
|
||||
return
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(messageinfo.time))
|
||||
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
|
||||
# topic=await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
||||
topic = ''
|
||||
@@ -93,47 +121,60 @@ class ChatBot:
|
||||
logger.debug(f"对{message.processed_plain_text}"
|
||||
f"的激活度:{interested_rate}")
|
||||
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
||||
|
||||
await self.storage.store_message(message,chat, topic[0] if topic else None)
|
||||
|
||||
await self.storage.store_message(message, topic[0] if topic else None)
|
||||
|
||||
is_mentioned = is_mentioned_bot_in_txt(message.processed_plain_text)
|
||||
reply_probability = willing_manager.change_reply_willing_received(
|
||||
event.group_id,
|
||||
topic[0] if topic else None,
|
||||
is_mentioned,
|
||||
global_config,
|
||||
event.user_id,
|
||||
message.is_emoji,
|
||||
interested_rate
|
||||
is_mentioned = is_mentioned_bot_in_message(message)
|
||||
reply_probability = await willing_manager.change_reply_willing_received(
|
||||
chat_stream=chat,
|
||||
topic=topic[0] if topic else None,
|
||||
is_mentioned_bot=is_mentioned,
|
||||
config=global_config,
|
||||
is_emoji=message.is_emoji,
|
||||
interested_rate=interested_rate
|
||||
)
|
||||
current_willing = willing_manager.get_willing(event.group_id)
|
||||
|
||||
current_willing = willing_manager.get_willing(chat_stream=chat)
|
||||
|
||||
logger.info(
|
||||
f"[{current_time}][{message.group_name}]{message.user_nickname}:"
|
||||
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]")
|
||||
|
||||
response = ""
|
||||
f"[{current_time}][{chat.group_info.group_name}]{chat.user_info.user_nickname}:"
|
||||
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]"
|
||||
)
|
||||
|
||||
response = None
|
||||
|
||||
if random() < reply_probability:
|
||||
bot_user_info=UserInfo(
|
||||
user_id=global_config.BOT_QQ,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
platform=messageinfo.platform
|
||||
)
|
||||
tinking_time_point = round(time.time(), 2)
|
||||
think_id = 'mt' + str(tinking_time_point)
|
||||
thinking_message = Message_Thinking(message=message, message_id=think_id)
|
||||
|
||||
thinking_message = MessageThinking(
|
||||
message_id=think_id,
|
||||
chat_stream=chat,
|
||||
bot_user_info=bot_user_info,
|
||||
reply=message
|
||||
)
|
||||
|
||||
message_manager.add_message(thinking_message)
|
||||
|
||||
willing_manager.change_reply_willing_sent(thinking_message.group_id)
|
||||
|
||||
response, raw_content = await self.gpt.generate_response(message)
|
||||
|
||||
willing_manager.change_reply_willing_sent(chat)
|
||||
|
||||
response,raw_content = await self.gpt.generate_response(message)
|
||||
|
||||
# print(f"response: {response}")
|
||||
if response:
|
||||
container = message_manager.get_container(event.group_id)
|
||||
# print(f"有response: {response}")
|
||||
container = message_manager.get_container(chat.stream_id)
|
||||
thinking_message = None
|
||||
# 找到message,删除
|
||||
# print(f"开始找思考消息")
|
||||
for msg in container.messages:
|
||||
if isinstance(msg, Message_Thinking) and msg.message_id == think_id:
|
||||
if isinstance(msg, MessageThinking) and msg.message_info.message_id == think_id:
|
||||
# print(f"找到思考消息: {msg}")
|
||||
thinking_message = msg
|
||||
container.messages.remove(msg)
|
||||
# print(f"\033[1;32m[思考消息删除]\033[0m 已找到思考消息对象,开始删除")
|
||||
break
|
||||
|
||||
# 如果找不到思考消息,直接返回
|
||||
@@ -143,41 +184,38 @@ class ChatBot:
|
||||
|
||||
# 记录开始思考的时间,避免从思考到回复的时间太久
|
||||
thinking_start_time = thinking_message.thinking_start_time
|
||||
message_set = MessageSet(event.group_id, global_config.BOT_QQ,
|
||||
think_id) # 发送消息的id和产生发送消息的message_thinking是一致的
|
||||
# 计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||
message_set = MessageSet(chat, think_id)
|
||||
#计算打字时间,1是为了模拟打字,2是避免多条回复乱序
|
||||
accu_typing_time = 0
|
||||
|
||||
# print(f"\033[1;32m[开始回复]\033[0m 开始将回复1载入发送容器")
|
||||
|
||||
mark_head = False
|
||||
for msg in response:
|
||||
# print(f"\033[1;32m[回复内容]\033[0m {msg}")
|
||||
# 通过时间改变时间戳
|
||||
typing_time = calculate_typing_time(msg)
|
||||
print(f"typing_time: {typing_time}")
|
||||
accu_typing_time += typing_time
|
||||
timepoint = tinking_time_point + accu_typing_time
|
||||
|
||||
bot_message = Message_Sending(
|
||||
group_id=event.group_id,
|
||||
user_id=global_config.BOT_QQ,
|
||||
message_segment = Seg(type='text', data=msg)
|
||||
print(f"message_segment: {message_segment}")
|
||||
bot_message = MessageSending(
|
||||
message_id=think_id,
|
||||
raw_message=msg,
|
||||
plain_text=msg,
|
||||
processed_plain_text=msg,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
group_name=message.group_name,
|
||||
time=timepoint, # 记录了回复生成的时间
|
||||
thinking_start_time=thinking_start_time, # 记录了思考开始的时间
|
||||
reply_message_id=message.message_id
|
||||
chat_stream=chat,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=message,
|
||||
is_head=not mark_head,
|
||||
is_emoji=False
|
||||
)
|
||||
await bot_message.initialize()
|
||||
print(f"bot_message: {bot_message}")
|
||||
if not mark_head:
|
||||
bot_message.is_head = True
|
||||
mark_head = True
|
||||
print(f"添加消息到message_set: {bot_message}")
|
||||
message_set.add_message(bot_message)
|
||||
|
||||
# message_set 可以直接加入 message_manager
|
||||
# print(f"\033[1;32m[回复]\033[0m 将回复载入发送容器")
|
||||
print(f"添加message_set到message_manager")
|
||||
message_manager.add_message(message_set)
|
||||
|
||||
bot_response_time = tinking_time_point
|
||||
@@ -189,31 +227,25 @@ class ChatBot:
|
||||
if emoji_raw != None:
|
||||
emoji_path, description = emoji_raw
|
||||
|
||||
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
||||
|
||||
emoji_cq = image_path_to_base64(emoji_path)
|
||||
|
||||
if random() < 0.5:
|
||||
bot_response_time = tinking_time_point - 1
|
||||
else:
|
||||
bot_response_time = bot_response_time + 1
|
||||
|
||||
bot_message = Message_Sending(
|
||||
group_id=event.group_id,
|
||||
user_id=global_config.BOT_QQ,
|
||||
message_id=0,
|
||||
raw_message=emoji_cq,
|
||||
plain_text=emoji_cq,
|
||||
processed_plain_text=emoji_cq,
|
||||
detailed_plain_text=description,
|
||||
user_nickname=global_config.BOT_NICKNAME,
|
||||
group_name=message.group_name,
|
||||
time=bot_response_time,
|
||||
is_emoji=True,
|
||||
translate_cq=False,
|
||||
thinking_start_time=thinking_start_time,
|
||||
# reply_message_id=message.message_id
|
||||
|
||||
message_segment = Seg(type='emoji', data=emoji_cq)
|
||||
bot_message = MessageSending(
|
||||
message_id=think_id,
|
||||
chat_stream=chat,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=message,
|
||||
is_head=False,
|
||||
is_emoji=True
|
||||
)
|
||||
await bot_message.initialize()
|
||||
message_manager.add_message(bot_message)
|
||||
|
||||
emotion = await self.gpt._get_emotion_tags(raw_content)
|
||||
logger.debug(f"为 '{response}' 获取到的情感标签为:{emotion}")
|
||||
valuedict = {
|
||||
@@ -225,12 +257,13 @@ class ChatBot:
|
||||
'fearful': -0.7,
|
||||
'neutral': 0.1
|
||||
}
|
||||
await relationship_manager.update_relationship_value(message.user_id,
|
||||
relationship_value=valuedict[emotion[0]])
|
||||
await relationship_manager.update_relationship_value(chat_stream=chat, relationship_value=valuedict[emotion[0]])
|
||||
# 使用情绪管理器更新情绪
|
||||
self.mood_manager.update_mood_from_emotion(emotion[0], global_config.mood_intensity_factor)
|
||||
|
||||
# willing_manager.change_reply_willing_after_sent(event.group_id)
|
||||
|
||||
# willing_manager.change_reply_willing_after_sent(
|
||||
# chat_stream=chat
|
||||
# )
|
||||
|
||||
|
||||
# 创建全局ChatBot实例
|
||||
|
||||
226
src/plugins/chat/chat_stream.py
Normal file
226
src/plugins/chat/chat_stream.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from ...common.database import Database
|
||||
from .message_base import GroupInfo, UserInfo
|
||||
|
||||
|
||||
class ChatStream:
|
||||
"""聊天流对象,存储一个完整的聊天上下文"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: dict = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
self.user_info = user_info
|
||||
self.group_info = group_info
|
||||
self.create_time = (
|
||||
data.get("create_time", int(time.time())) if data else int(time.time())
|
||||
)
|
||||
self.last_active_time = (
|
||||
data.get("last_active_time", self.create_time) if data else self.create_time
|
||||
)
|
||||
self.saved = False
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典格式"""
|
||||
result = {
|
||||
"stream_id": self.stream_id,
|
||||
"platform": self.platform,
|
||||
"user_info": self.user_info.to_dict() if self.user_info else None,
|
||||
"group_info": self.group_info.to_dict() if self.group_info else None,
|
||||
"create_time": self.create_time,
|
||||
"last_active_time": self.last_active_time,
|
||||
}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "ChatStream":
|
||||
"""从字典创建实例"""
|
||||
user_info = (
|
||||
UserInfo(**data.get("user_info", {})) if data.get("user_info") else None
|
||||
)
|
||||
group_info = (
|
||||
GroupInfo(**data.get("group_info", {})) if data.get("group_info") else None
|
||||
)
|
||||
|
||||
return cls(
|
||||
stream_id=data["stream_id"],
|
||||
platform=data["platform"],
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def update_active_time(self):
|
||||
"""更新最后活跃时间"""
|
||||
self.last_active_time = int(time.time())
|
||||
self.saved = False
|
||||
|
||||
|
||||
class ChatManager:
|
||||
"""聊天管理器,管理所有聊天流"""
|
||||
|
||||
_instance = None
|
||||
_initialized = False
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.db = Database.get_instance()
|
||||
self._ensure_collection()
|
||||
self._initialized = True
|
||||
# 在事件循环中启动初始化
|
||||
# asyncio.create_task(self._initialize())
|
||||
# # 启动自动保存任务
|
||||
# asyncio.create_task(self._auto_save_task())
|
||||
|
||||
async def _initialize(self):
|
||||
"""异步初始化"""
|
||||
try:
|
||||
await self.load_all_streams()
|
||||
logger.success(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
|
||||
async def _auto_save_task(self):
|
||||
"""定期自动保存所有聊天流"""
|
||||
while True:
|
||||
await asyncio.sleep(300) # 每5分钟保存一次
|
||||
try:
|
||||
await self._save_all_streams()
|
||||
logger.info("聊天流自动保存完成")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
|
||||
def _ensure_collection(self):
|
||||
"""确保数据库集合存在并创建索引"""
|
||||
if "chat_streams" not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection("chat_streams")
|
||||
# 创建索引
|
||||
self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
|
||||
self.db.db.chat_streams.create_index(
|
||||
[("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
|
||||
)
|
||||
|
||||
def _generate_stream_id(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if group_info:
|
||||
# 组合关键信息
|
||||
components = [
|
||||
platform,
|
||||
str(group_info.group_id)
|
||||
]
|
||||
else:
|
||||
components = [
|
||||
platform,
|
||||
str(user_info.user_id),
|
||||
"private"
|
||||
]
|
||||
|
||||
# 使用MD5生成唯一ID
|
||||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
user_info: 用户信息
|
||||
group_info: 群组信息(可选)
|
||||
|
||||
Returns:
|
||||
ChatStream: 聊天流对象
|
||||
"""
|
||||
# 生成stream_id
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
|
||||
# 检查内存中是否存在
|
||||
if stream_id in self.streams:
|
||||
stream = self.streams[stream_id]
|
||||
# 更新用户信息和群组信息
|
||||
stream.update_active_time()
|
||||
stream=copy.deepcopy(stream)
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
return stream
|
||||
|
||||
# 检查数据库中是否存在
|
||||
data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
|
||||
if data:
|
||||
stream = ChatStream.from_dict(data)
|
||||
# 更新用户信息和群组信息
|
||||
stream.user_info = user_info
|
||||
if group_info:
|
||||
stream.group_info = group_info
|
||||
stream.update_active_time()
|
||||
else:
|
||||
# 创建新的聊天流
|
||||
stream = ChatStream(
|
||||
stream_id=stream_id,
|
||||
platform=platform,
|
||||
user_info=user_info,
|
||||
group_info=group_info,
|
||||
)
|
||||
|
||||
# 保存到内存和数据库
|
||||
self.streams[stream_id] = stream
|
||||
await self._save_stream(stream)
|
||||
return copy.deepcopy(stream)
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
"""通过stream_id获取聊天流"""
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
def get_stream_by_info(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> Optional[ChatStream]:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
async def _save_stream(self, stream: ChatStream):
|
||||
"""保存聊天流到数据库"""
|
||||
if not stream.saved:
|
||||
self.db.db.chat_streams.update_one(
|
||||
{"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
|
||||
)
|
||||
stream.saved = True
|
||||
|
||||
async def _save_all_streams(self):
|
||||
"""保存所有聊天流"""
|
||||
for stream in self.streams.values():
|
||||
await self._save_stream(stream)
|
||||
|
||||
async def load_all_streams(self):
|
||||
"""从数据库加载所有聊天流"""
|
||||
all_streams = self.db.db.chat_streams.find({})
|
||||
for data in all_streams:
|
||||
stream = ChatStream.from_dict(data)
|
||||
self.streams[stream.stream_id] = stream
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
chat_manager = ChatManager()
|
||||
@@ -2,22 +2,23 @@ import base64
|
||||
import html
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from loguru import logger
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
# 解析各种CQ码
|
||||
# 包含CQ码类
|
||||
import urllib3
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
from urllib3.util import create_urllib3_context
|
||||
|
||||
from ..models.utils_model import LLM_request
|
||||
from .config import global_config
|
||||
from .mapper import emojimapper
|
||||
from .utils_image import image_path_to_base64, storage_emoji, storage_image
|
||||
from .utils_user import get_user_nickname
|
||||
from .message_base import Seg
|
||||
from .utils_user import get_user_nickname,get_groupname
|
||||
from .message_base import GroupInfo, UserInfo
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -35,65 +36,80 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
|
||||
|
||||
def init_poolmanager(self, connections, maxsize, block=False):
|
||||
self.poolmanager = urllib3.poolmanager.PoolManager(
|
||||
num_pools=connections, maxsize=maxsize,
|
||||
block=block, ssl_context=self.ssl_context)
|
||||
num_pools=connections,
|
||||
maxsize=maxsize,
|
||||
block=block,
|
||||
ssl_context=self.ssl_context,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CQCode:
|
||||
"""
|
||||
CQ码数据类,用于存储和处理CQ码
|
||||
|
||||
|
||||
属性:
|
||||
type: CQ码类型(如'image', 'at', 'face'等)
|
||||
params: CQ码的参数字典
|
||||
raw_code: 原始CQ码字符串
|
||||
translated_plain_text: 经过处理(如AI翻译)后的文本表示
|
||||
translated_segments: 经过处理后的Seg对象列表
|
||||
"""
|
||||
|
||||
type: str
|
||||
params: Dict[str, str]
|
||||
# raw_code: str
|
||||
group_id: int
|
||||
user_id: int
|
||||
group_name: str = ""
|
||||
user_nickname: str = ""
|
||||
translated_plain_text: Optional[str] = None
|
||||
group_info: Optional[GroupInfo] = None
|
||||
user_info: Optional[UserInfo] = None
|
||||
translated_segments: Optional[Union[Seg, List[Seg]]] = None
|
||||
reply_message: Dict = None # 存储回复消息
|
||||
image_base64: Optional[str] = None
|
||||
_llm: Optional[LLM_request] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""初始化LLM实例"""
|
||||
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
||||
pass
|
||||
|
||||
async def translate(self):
|
||||
"""根据CQ码类型进行相应的翻译处理"""
|
||||
if self.type == 'text':
|
||||
self.translated_plain_text = self.params.get('text', '')
|
||||
elif self.type == 'image':
|
||||
if self.params.get('sub_type') == '0':
|
||||
self.translated_plain_text = await self.translate_image()
|
||||
def translate(self):
|
||||
"""根据CQ码类型进行相应的翻译处理,转换为Seg对象"""
|
||||
if self.type == "text":
|
||||
self.translated_segments = Seg(
|
||||
type="text", data=self.params.get("text", "")
|
||||
)
|
||||
elif self.type == "image":
|
||||
base64_data = self.translate_image()
|
||||
if base64_data:
|
||||
if self.params.get("sub_type") == "0":
|
||||
self.translated_segments = Seg(type="image", data=base64_data)
|
||||
else:
|
||||
self.translated_segments = Seg(type="emoji", data=base64_data)
|
||||
else:
|
||||
self.translated_plain_text = await self.translate_emoji()
|
||||
elif self.type == 'at':
|
||||
user_nickname = get_user_nickname(self.params.get('qq', ''))
|
||||
if user_nickname:
|
||||
self.translated_plain_text = f"[@{user_nickname}]"
|
||||
self.translated_segments = Seg(type="text", data="[图片]")
|
||||
elif self.type == "at":
|
||||
user_nickname = get_user_nickname(self.params.get("qq", ""))
|
||||
self.translated_segments = Seg(
|
||||
type="text", data=f"[@{user_nickname or '某人'}]"
|
||||
)
|
||||
elif self.type == "reply":
|
||||
reply_segments = self.translate_reply()
|
||||
if reply_segments:
|
||||
self.translated_segments = Seg(type="seglist", data=reply_segments)
|
||||
else:
|
||||
self.translated_plain_text = "@某人"
|
||||
elif self.type == 'reply':
|
||||
self.translated_plain_text = await self.translate_reply()
|
||||
elif self.type == 'face':
|
||||
face_id = self.params.get('id', '')
|
||||
# self.translated_plain_text = f"[表情{face_id}]"
|
||||
self.translated_plain_text = f"[{emojimapper.get(int(face_id), '表情')}]"
|
||||
elif self.type == 'forward':
|
||||
self.translated_plain_text = await self.translate_forward()
|
||||
self.translated_segments = Seg(type="text", data="[回复某人消息]")
|
||||
elif self.type == "face":
|
||||
face_id = self.params.get("id", "")
|
||||
self.translated_segments = Seg(
|
||||
type="text", data=f"[{emojimapper.get(int(face_id), '表情')}]"
|
||||
)
|
||||
elif self.type == "forward":
|
||||
forward_segments = self.translate_forward()
|
||||
if forward_segments:
|
||||
self.translated_segments = Seg(type="seglist", data=forward_segments)
|
||||
else:
|
||||
self.translated_segments = Seg(type="text", data="[转发消息]")
|
||||
else:
|
||||
self.translated_plain_text = f"[{self.type}]"
|
||||
self.translated_segments = Seg(type="text", data=f"[{self.type}]")
|
||||
|
||||
def get_img(self):
|
||||
'''
|
||||
"""
|
||||
headers = {
|
||||
'User-Agent': 'QQ/8.9.68.11565 CFNetwork/1220.1 Darwin/20.3.0',
|
||||
'Accept': 'image/*;q=0.8',
|
||||
@@ -102,18 +118,18 @@ class CQCode:
|
||||
'Cache-Control': 'no-cache',
|
||||
'Pragma': 'no-cache'
|
||||
}
|
||||
'''
|
||||
"""
|
||||
# 腾讯专用请求头配置
|
||||
headers = {
|
||||
'User-Agent': 'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36',
|
||||
'Accept': 'text/html, application/xhtml xml, */*',
|
||||
'Accept-Encoding': 'gbk, GB2312',
|
||||
'Accept-Language': 'zh-cn',
|
||||
'Content-Type': 'application/x-www-form-urlencoded',
|
||||
'Cache-Control': 'no-cache'
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/50.0.2661.87 Safari/537.36",
|
||||
"Accept": "text/html, application/xhtml xml, */*",
|
||||
"Accept-Encoding": "gbk, GB2312",
|
||||
"Accept-Language": "zh-cn",
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
"Cache-Control": "no-cache",
|
||||
}
|
||||
url = html.unescape(self.params['url'])
|
||||
if not url.startswith(('http://', 'https://')):
|
||||
url = html.unescape(self.params["url"])
|
||||
if not url.startswith(("http://", "https://")):
|
||||
return None
|
||||
|
||||
# 创建专用会话
|
||||
@@ -129,30 +145,30 @@ class CQCode:
|
||||
headers=headers,
|
||||
timeout=15,
|
||||
allow_redirects=True,
|
||||
stream=True # 流式传输避免大内存问题
|
||||
stream=True, # 流式传输避免大内存问题
|
||||
)
|
||||
|
||||
# 腾讯服务器特殊状态码处理
|
||||
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
|
||||
if response.status_code == 400 and "multimedia.nt.qq.com.cn" in url:
|
||||
return None
|
||||
|
||||
if response.status_code != 200:
|
||||
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
|
||||
|
||||
# 验证内容类型
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
if not content_type.startswith('image/'):
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
if not content_type.startswith("image/"):
|
||||
raise ValueError(f"非图片内容类型: {content_type}")
|
||||
|
||||
# 转换为Base64
|
||||
image_base64 = base64.b64encode(response.content).decode('utf-8')
|
||||
image_base64 = base64.b64encode(response.content).decode("utf-8")
|
||||
self.image_base64 = image_base64
|
||||
return image_base64
|
||||
|
||||
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
|
||||
if retry == max_retries - 1:
|
||||
logger.error(f"最终请求失败: {str(e)}")
|
||||
time.sleep(1.5 ** retry) # 指数退避
|
||||
time.sleep(1.5**retry) # 指数退避
|
||||
|
||||
except Exception:
|
||||
logger.exception("[未知错误]")
|
||||
@@ -160,211 +176,181 @@ class CQCode:
|
||||
|
||||
return None
|
||||
|
||||
async def translate_emoji(self) -> str:
|
||||
"""处理表情包类型的CQ码"""
|
||||
if 'url' not in self.params:
|
||||
return '[表情包]'
|
||||
base64_str = self.get_img()
|
||||
if base64_str:
|
||||
# 将 base64 字符串转换为字节类型
|
||||
image_bytes = base64.b64decode(base64_str)
|
||||
storage_emoji(image_bytes)
|
||||
return await self.get_emoji_description(base64_str)
|
||||
else:
|
||||
return '[表情包]'
|
||||
def translate_image(self) -> Optional[str]:
|
||||
"""处理图片类型的CQ码,返回base64字符串"""
|
||||
if "url" not in self.params:
|
||||
return None
|
||||
return self.get_img()
|
||||
|
||||
async def translate_image(self) -> str:
|
||||
"""处理图片类型的CQ码,区分普通图片和表情包"""
|
||||
# 没有url,直接返回默认文本
|
||||
if 'url' not in self.params:
|
||||
return '[图片]'
|
||||
base64_str = self.get_img()
|
||||
if base64_str:
|
||||
image_bytes = base64.b64decode(base64_str)
|
||||
storage_image(image_bytes)
|
||||
return await self.get_image_description(base64_str)
|
||||
else:
|
||||
return '[图片]'
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""调用AI接口获取表情包描述"""
|
||||
def translate_forward(self) -> Optional[List[Seg]]:
|
||||
"""处理转发消息,返回Seg列表"""
|
||||
try:
|
||||
prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
||||
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
return f"[表情包:{description}]"
|
||||
except Exception as e:
|
||||
logger.exception(f"AI接口调用失败: {str(e)}")
|
||||
return "[表情包]"
|
||||
if "content" not in self.params:
|
||||
return None
|
||||
|
||||
async def get_image_description(self, image_base64: str) -> str:
|
||||
"""调用AI接口获取普通图片描述"""
|
||||
try:
|
||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
return f"[图片:{description}]"
|
||||
except Exception as e:
|
||||
logger.exception(f"AI接口调用失败: {str(e)}")
|
||||
return "[图片]"
|
||||
|
||||
async def translate_forward(self) -> str:
|
||||
"""处理转发消息"""
|
||||
try:
|
||||
if 'content' not in self.params:
|
||||
return '[转发消息]'
|
||||
|
||||
# 解析content内容(需要先反转义)
|
||||
content = self.unescape(self.params['content'])
|
||||
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
|
||||
# 将字符串形式的列表转换为Python对象
|
||||
content = self.unescape(self.params["content"])
|
||||
import ast
|
||||
|
||||
try:
|
||||
messages = ast.literal_eval(content)
|
||||
except ValueError as e:
|
||||
logger.error(f"解析转发消息内容失败: {str(e)}")
|
||||
return '[转发消息]'
|
||||
return None
|
||||
|
||||
# 处理每条消息
|
||||
formatted_messages = []
|
||||
formatted_segments = []
|
||||
for msg in messages:
|
||||
sender = msg.get('sender', {})
|
||||
nickname = sender.get('card') or sender.get('nickname', '未知用户')
|
||||
|
||||
# 获取消息内容并使用Message类处理
|
||||
raw_message = msg.get('raw_message', '')
|
||||
message_array = msg.get('message', [])
|
||||
sender = msg.get("sender", {})
|
||||
nickname = sender.get("card") or sender.get("nickname", "未知用户")
|
||||
raw_message = msg.get("raw_message", "")
|
||||
message_array = msg.get("message", [])
|
||||
|
||||
if message_array and isinstance(message_array, list):
|
||||
# 检查是否包含嵌套的转发消息
|
||||
for message_part in message_array:
|
||||
if message_part.get('type') == 'forward':
|
||||
content = '[转发消息]'
|
||||
if message_part.get("type") == "forward":
|
||||
content_seg = Seg(type="text", data="[转发消息]")
|
||||
break
|
||||
else:
|
||||
# 处理普通消息
|
||||
if raw_message:
|
||||
from .message import Message
|
||||
message_obj = Message(
|
||||
user_id=msg.get('user_id', 0),
|
||||
message_id=msg.get('message_id', 0),
|
||||
raw_message=raw_message,
|
||||
plain_text=raw_message,
|
||||
group_id=msg.get('group_id', 0)
|
||||
)
|
||||
await message_obj.initialize()
|
||||
content = message_obj.processed_plain_text
|
||||
else:
|
||||
content = '[空消息]'
|
||||
if raw_message:
|
||||
from .message_cq import MessageRecvCQ
|
||||
user_info=UserInfo(
|
||||
platform='qq',
|
||||
user_id=msg.get("user_id", 0),
|
||||
user_nickname=nickname,
|
||||
)
|
||||
group_info=GroupInfo(
|
||||
platform='qq',
|
||||
group_id=msg.get("group_id", 0),
|
||||
group_name=get_groupname(msg.get("group_id", 0))
|
||||
)
|
||||
|
||||
message_obj = MessageRecvCQ(
|
||||
message_id=msg.get("message_id", 0),
|
||||
user_info=user_info,
|
||||
raw_message=raw_message,
|
||||
plain_text=raw_message,
|
||||
group_info=group_info,
|
||||
)
|
||||
content_seg = Seg(
|
||||
type="seglist", data=message_obj.message_segment )
|
||||
else:
|
||||
content_seg = Seg(type="text", data="[空消息]")
|
||||
else:
|
||||
# 处理普通消息
|
||||
if raw_message:
|
||||
from .message import Message
|
||||
message_obj = Message(
|
||||
user_id=msg.get('user_id', 0),
|
||||
message_id=msg.get('message_id', 0),
|
||||
from .message_cq import MessageRecvCQ
|
||||
|
||||
user_info=UserInfo(
|
||||
platform='qq',
|
||||
user_id=msg.get("user_id", 0),
|
||||
user_nickname=nickname,
|
||||
)
|
||||
group_info=GroupInfo(
|
||||
platform='qq',
|
||||
group_id=msg.get("group_id", 0),
|
||||
group_name=get_groupname(msg.get("group_id", 0))
|
||||
)
|
||||
message_obj = MessageRecvCQ(
|
||||
message_id=msg.get("message_id", 0),
|
||||
user_info=user_info,
|
||||
raw_message=raw_message,
|
||||
plain_text=raw_message,
|
||||
group_id=msg.get('group_id', 0)
|
||||
group_info=group_info,
|
||||
)
|
||||
content_seg = Seg(
|
||||
type="seglist", data=message_obj.message_segment
|
||||
)
|
||||
await message_obj.initialize()
|
||||
content = message_obj.processed_plain_text
|
||||
else:
|
||||
content = '[空消息]'
|
||||
content_seg = Seg(type="text", data="[空消息]")
|
||||
|
||||
formatted_msg = f"{nickname}: {content}"
|
||||
formatted_messages.append(formatted_msg)
|
||||
formatted_segments.append(Seg(type="text", data=f"{nickname}: "))
|
||||
formatted_segments.append(content_seg)
|
||||
formatted_segments.append(Seg(type="text", data="\n"))
|
||||
|
||||
# 合并所有消息
|
||||
combined_messages = '\n'.join(formatted_messages)
|
||||
logger.debug(f"合并后的转发消息: {combined_messages}")
|
||||
return f"[转发消息:\n{combined_messages}]"
|
||||
return formatted_segments
|
||||
|
||||
except Exception:
|
||||
logger.exception("处理转发消息失败")
|
||||
return '[转发消息]'
|
||||
except Exception as e:
|
||||
logger.error(f"处理转发消息失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def translate_reply(self) -> str:
|
||||
"""处理回复类型的CQ码"""
|
||||
def translate_reply(self) -> Optional[List[Seg]]:
|
||||
"""处理回复类型的CQ码,返回Seg列表"""
|
||||
from .message_cq import MessageRecvCQ
|
||||
|
||||
# 创建Message对象
|
||||
from .message import Message
|
||||
if self.reply_message == None:
|
||||
# print(f"\033[1;31m[错误]\033[0m 回复消息为空")
|
||||
return '[回复某人消息]'
|
||||
if self.reply_message is None:
|
||||
return None
|
||||
|
||||
if self.reply_message.sender.user_id:
|
||||
message_obj = Message(
|
||||
user_id=self.reply_message.sender.user_id,
|
||||
|
||||
message_obj = MessageRecvCQ(
|
||||
user_info=UserInfo(user_id=self.reply_message.sender.user_id,user_nickname=self.reply_message.sender.get("nickname",None)),
|
||||
message_id=self.reply_message.message_id,
|
||||
raw_message=str(self.reply_message.message),
|
||||
group_id=self.group_id
|
||||
group_info=GroupInfo(group_id=self.reply_message.group_id),
|
||||
)
|
||||
await message_obj.initialize()
|
||||
if message_obj.user_id == global_config.BOT_QQ:
|
||||
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
|
||||
else:
|
||||
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
|
||||
|
||||
segments = []
|
||||
if message_obj.message_info.user_info.user_id == global_config.BOT_QQ:
|
||||
segments.append(
|
||||
Seg(
|
||||
type="text", data=f"[回复 {global_config.BOT_NICKNAME} 的消息: "
|
||||
)
|
||||
)
|
||||
else:
|
||||
segments.append(
|
||||
Seg(
|
||||
type="text",
|
||||
data=f"[回复 {self.reply_message.sender.nickname} 的消息: ",
|
||||
)
|
||||
)
|
||||
|
||||
segments.append(Seg(type="seglist", data=message_obj.message_segment))
|
||||
segments.append(Seg(type="text", data="]"))
|
||||
return segments
|
||||
else:
|
||||
logger.error("回复消息的sender.user_id为空")
|
||||
return '[回复某人消息]'
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def unescape(text: str) -> str:
|
||||
"""反转义CQ码中的特殊字符"""
|
||||
return text.replace(',', ',') \
|
||||
.replace('[', '[') \
|
||||
.replace(']', ']') \
|
||||
.replace('&', '&')
|
||||
|
||||
@staticmethod
|
||||
def create_emoji_cq(file_path: str) -> str:
|
||||
"""
|
||||
创建表情包CQ码
|
||||
Args:
|
||||
file_path: 本地表情包文件路径
|
||||
Returns:
|
||||
表情包CQ码字符串
|
||||
"""
|
||||
base64_content = image_path_to_base64(file_path)
|
||||
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=base64://{base64_content},sub_type=1]"
|
||||
|
||||
return (
|
||||
text.replace(",", ",")
|
||||
.replace("[", "[")
|
||||
.replace("]", "]")
|
||||
.replace("&", "&")
|
||||
)
|
||||
|
||||
class CQCode_tool:
|
||||
@staticmethod
|
||||
async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
|
||||
def cq_from_dict_to_class(cq_code: Dict,msg ,reply: Optional[Dict] = None) -> CQCode:
|
||||
"""
|
||||
将CQ码字典转换为CQCode对象
|
||||
|
||||
|
||||
Args:
|
||||
cq_code: CQ码字典
|
||||
msg: MessageCQ对象
|
||||
reply: 回复消息的字典(可选)
|
||||
|
||||
|
||||
Returns:
|
||||
CQCode对象
|
||||
"""
|
||||
# 处理字典形式的CQ码
|
||||
# 从cq_code字典中获取type字段的值,如果不存在则默认为'text'
|
||||
cq_type = cq_code.get('type', 'text')
|
||||
cq_type = cq_code.get("type", "text")
|
||||
params = {}
|
||||
if cq_type == 'text':
|
||||
params['text'] = cq_code.get('data', {}).get('text', '')
|
||||
if cq_type == "text":
|
||||
params["text"] = cq_code.get("data", {}).get("text", "")
|
||||
else:
|
||||
params = cq_code.get('data', {})
|
||||
params = cq_code.get("data", {})
|
||||
|
||||
instance = CQCode(
|
||||
type=cq_type,
|
||||
params=params,
|
||||
group_id=0,
|
||||
user_id=0,
|
||||
group_info=msg.message_info.group_info,
|
||||
user_info=msg.message_info.user_info,
|
||||
reply_message=reply
|
||||
)
|
||||
|
||||
# 进行翻译处理
|
||||
await instance.translate()
|
||||
instance.translate()
|
||||
return instance
|
||||
|
||||
@staticmethod
|
||||
@@ -378,5 +364,64 @@ class CQCode_tool:
|
||||
"""
|
||||
return f"[CQ:reply,id={message_id}]"
|
||||
|
||||
@staticmethod
|
||||
def create_emoji_cq(file_path: str) -> str:
|
||||
"""
|
||||
创建表情包CQ码
|
||||
Args:
|
||||
file_path: 本地表情包文件路径
|
||||
Returns:
|
||||
表情包CQ码字符串
|
||||
"""
|
||||
# 确保使用绝对路径
|
||||
abs_path = os.path.abspath(file_path)
|
||||
# 转义特殊字符
|
||||
escaped_path = (
|
||||
abs_path.replace("&", "&")
|
||||
.replace("[", "[")
|
||||
.replace("]", "]")
|
||||
.replace(",", ",")
|
||||
)
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
|
||||
|
||||
@staticmethod
|
||||
def create_emoji_cq_base64(base64_data: str) -> str:
|
||||
"""
|
||||
创建表情包CQ码
|
||||
Args:
|
||||
base64_data: base64编码的表情包数据
|
||||
Returns:
|
||||
表情包CQ码字符串
|
||||
"""
|
||||
# 转义base64数据
|
||||
escaped_base64 = (
|
||||
base64_data.replace("&", "&")
|
||||
.replace("[", "[")
|
||||
.replace("]", "]")
|
||||
.replace(",", ",")
|
||||
)
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=1]"
|
||||
|
||||
@staticmethod
|
||||
def create_image_cq_base64(base64_data: str) -> str:
|
||||
"""
|
||||
创建表情包CQ码
|
||||
Args:
|
||||
base64_data: base64编码的表情包数据
|
||||
Returns:
|
||||
表情包CQ码字符串
|
||||
"""
|
||||
# 转义base64数据
|
||||
escaped_base64 = (
|
||||
base64_data.replace("&", "&")
|
||||
.replace("[", "[")
|
||||
.replace("]", "]")
|
||||
.replace(",", ",")
|
||||
)
|
||||
# 生成CQ码,设置sub_type=1表示这是表情包
|
||||
return f"[CQ:image,file=base64://{escaped_base64},sub_type=0]"
|
||||
|
||||
|
||||
cq_code_tool = CQCode_tool()
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
@@ -11,11 +13,12 @@ from nonebot import get_driver
|
||||
from ...common.database import Database
|
||||
from ..chat.config import global_config
|
||||
from ..chat.utils import get_embedding
|
||||
from ..chat.utils_image import image_path_to_base64
|
||||
from ..chat.utils_image import ImageManager, image_path_to_base64
|
||||
from ..models.utils_model import LLM_request
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
image_manager = ImageManager()
|
||||
|
||||
|
||||
class EmojiManager:
|
||||
@@ -76,7 +79,6 @@ class EmojiManager:
|
||||
if 'emoji' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('emoji')
|
||||
self.db.db.emoji.create_index([('embedding', '2dsphere')])
|
||||
self.db.db.emoji.create_index([('tags', 1)])
|
||||
self.db.db.emoji.create_index([('filename', 1)], unique=True)
|
||||
|
||||
def record_usage(self, emoji_id: str):
|
||||
@@ -87,10 +89,10 @@ class EmojiManager:
|
||||
{'_id': emoji_id},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("记录表情使用失败")
|
||||
|
||||
async def get_emoji_for_text(self, text: str) -> Optional[str]:
|
||||
except Exception as e:
|
||||
logger.error(f"记录表情使用失败: {str(e)}")
|
||||
|
||||
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]:
|
||||
"""根据文本内容获取相关表情包
|
||||
Args:
|
||||
text: 输入文本
|
||||
@@ -144,15 +146,15 @@ class EmojiManager:
|
||||
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 获取前3个最相似的表情包
|
||||
top_3_emojis = emoji_similarities[:3]
|
||||
|
||||
if not top_3_emojis:
|
||||
top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
|
||||
|
||||
if not top_10_emojis:
|
||||
logger.warning("未找到匹配的表情包")
|
||||
return None
|
||||
|
||||
# 从前3个中随机选择一个
|
||||
selected_emoji, similarity = random.choice(top_3_emojis)
|
||||
|
||||
selected_emoji, similarity = random.choice(top_10_emojis)
|
||||
|
||||
if selected_emoji and 'path' in selected_emoji:
|
||||
# 更新使用次数
|
||||
self.db.db.emoji.update_one(
|
||||
@@ -174,15 +176,15 @@ class EmojiManager:
|
||||
logger.error(f"获取表情包失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def _get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包的标签"""
|
||||
async def _get_emoji_discription(self, image_base64: str) -> str:
|
||||
"""获取表情包的标签,使用image_manager的描述生成功能"""
|
||||
try:
|
||||
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感'
|
||||
|
||||
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
|
||||
logger.debug(f"输出描述: {content}")
|
||||
return content
|
||||
|
||||
# 使用image_manager获取描述,去掉前后的方括号和"表情包:"前缀
|
||||
description = await image_manager.get_emoji_description(image_base64)
|
||||
# 去掉[表情包:xxx]的格式,只保留描述内容
|
||||
description = description.strip('[]').replace('表情包:', '')
|
||||
return description
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取标签失败: {str(e)}")
|
||||
return None
|
||||
@@ -223,29 +225,66 @@ class EmojiManager:
|
||||
|
||||
for filename in files_to_process:
|
||||
image_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
# 检查是否已经注册过
|
||||
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
||||
if existing_emoji:
|
||||
continue
|
||||
|
||||
# 压缩图片并获取base64编码
|
||||
|
||||
# 获取图片的base64编码和哈希值
|
||||
image_base64 = image_path_to_base64(image_path)
|
||||
if image_base64 is None:
|
||||
os.remove(image_path)
|
||||
continue
|
||||
|
||||
# 获取表情包的描述
|
||||
description = await self._get_emoji_description(image_base64)
|
||||
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 检查是否已经注册过
|
||||
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
||||
description = None
|
||||
|
||||
if existing_emoji:
|
||||
# 即使表情包已存在,也检查是否需要同步到images集合
|
||||
description = existing_emoji.get('discription')
|
||||
# 检查是否在images集合中存在
|
||||
existing_image = image_manager.db.db.images.find_one({'hash': image_hash})
|
||||
if not existing_image:
|
||||
# 同步到images集合
|
||||
image_doc = {
|
||||
'hash': image_hash,
|
||||
'path': image_path,
|
||||
'type': 'emoji',
|
||||
'description': description,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
image_manager.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
)
|
||||
# 保存描述到image_descriptions集合
|
||||
image_manager._save_description_to_db(image_hash, description, 'emoji')
|
||||
logger.success(f"同步已存在的表情包到images集合: {filename}")
|
||||
continue
|
||||
|
||||
# 检查是否在images集合中已有描述
|
||||
existing_description = image_manager._get_description_from_db(image_hash, 'emoji')
|
||||
|
||||
if existing_description:
|
||||
description = existing_description
|
||||
else:
|
||||
# 获取表情包的描述
|
||||
description = await self._get_emoji_discription(image_base64)
|
||||
|
||||
if global_config.EMOJI_CHECK:
|
||||
check = await self._check_emoji(image_base64)
|
||||
if '是' not in check:
|
||||
os.remove(image_path)
|
||||
logger.info(f"描述: {description}")
|
||||
logger.info(f"描述: {description}")
|
||||
logger.info(f"其不满足过滤规则,被剔除 {check}")
|
||||
continue
|
||||
logger.info(f"check通过 {check}")
|
||||
|
||||
if description is not None:
|
||||
embedding = await get_embedding(description)
|
||||
|
||||
if description is not None:
|
||||
embedding = await get_embedding(description)
|
||||
# 准备数据库记录
|
||||
@@ -253,14 +292,32 @@ class EmojiManager:
|
||||
'filename': filename,
|
||||
'path': image_path,
|
||||
'embedding': embedding,
|
||||
'description': description,
|
||||
'discription': description,
|
||||
'hash': image_hash,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
|
||||
# 保存到数据库
|
||||
|
||||
# 保存到emoji数据库
|
||||
self.db.db['emoji'].insert_one(emoji_record)
|
||||
logger.success(f"注册新表情包: {filename}")
|
||||
logger.info(f"描述: {description}")
|
||||
|
||||
# 保存到images数据库
|
||||
image_doc = {
|
||||
'hash': image_hash,
|
||||
'path': image_path,
|
||||
'type': 'emoji',
|
||||
'description': description,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
image_manager.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
)
|
||||
# 保存描述到image_descriptions集合
|
||||
image_manager._save_description_to_db(image_hash, description, 'emoji')
|
||||
logger.success(f"同步保存到images集合: {filename}")
|
||||
else:
|
||||
logger.warning(f"跳过表情包: {filename}")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from loguru import logger
|
||||
from ...common.database import Database
|
||||
from ..models.utils_model import LLM_request
|
||||
from .config import global_config
|
||||
from .message import Message
|
||||
from .message import MessageRecv, MessageThinking, MessageSending,Message
|
||||
from .prompt_builder import prompt_builder
|
||||
from .relationship_manager import relationship_manager
|
||||
from .utils import process_llm_response
|
||||
@@ -19,58 +19,89 @@ config = driver.config
|
||||
|
||||
class ResponseGenerator:
|
||||
def __init__(self):
|
||||
self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True)
|
||||
self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
|
||||
self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
|
||||
self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000)
|
||||
self.model_r1 = LLM_request(
|
||||
model=global_config.llm_reasoning,
|
||||
temperature=0.7,
|
||||
max_tokens=1000,
|
||||
stream=True,
|
||||
)
|
||||
self.model_v3 = LLM_request(
|
||||
model=global_config.llm_normal, temperature=0.7, max_tokens=1000
|
||||
)
|
||||
self.model_r1_distill = LLM_request(
|
||||
model=global_config.llm_reasoning_minor, temperature=0.7, max_tokens=1000
|
||||
)
|
||||
self.model_v25 = LLM_request(
|
||||
model=global_config.llm_normal_minor, temperature=0.7, max_tokens=1000
|
||||
)
|
||||
self.db = Database.get_instance()
|
||||
self.current_model_type = 'r1' # 默认使用 R1
|
||||
self.current_model_type = "r1" # 默认使用 R1
|
||||
|
||||
async def generate_response(self, message: Message) -> Optional[Union[str, List[str]]]:
|
||||
async def generate_response(
|
||||
self, message: MessageThinking
|
||||
) -> Optional[Union[str, List[str]]]:
|
||||
"""根据当前模型类型选择对应的生成函数"""
|
||||
# 从global_config中获取模型概率值并选择模型
|
||||
rand = random.random()
|
||||
if rand < global_config.MODEL_R1_PROBABILITY:
|
||||
self.current_model_type = 'r1'
|
||||
self.current_model_type = "r1"
|
||||
current_model = self.model_r1
|
||||
elif rand < global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY:
|
||||
self.current_model_type = 'v3'
|
||||
elif (
|
||||
rand
|
||||
< global_config.MODEL_R1_PROBABILITY + global_config.MODEL_V3_PROBABILITY
|
||||
):
|
||||
self.current_model_type = "v3"
|
||||
current_model = self.model_v3
|
||||
else:
|
||||
self.current_model_type = 'r1_distill'
|
||||
self.current_model_type = "r1_distill"
|
||||
current_model = self.model_r1_distill
|
||||
|
||||
logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
|
||||
|
||||
model_response = await self._generate_response_with_model(message, current_model)
|
||||
raw_content=model_response
|
||||
|
||||
model_response = await self._generate_response_with_model(
|
||||
message, current_model
|
||||
)
|
||||
raw_content = model_response
|
||||
|
||||
# print(f"raw_content: {raw_content}")
|
||||
# print(f"model_response: {model_response}")
|
||||
|
||||
if model_response:
|
||||
logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
|
||||
model_response = await self._process_response(model_response)
|
||||
if model_response:
|
||||
return model_response, raw_content
|
||||
return None, raw_content
|
||||
|
||||
return model_response ,raw_content
|
||||
return None,raw_content
|
||||
|
||||
async def _generate_response_with_model(self, message: Message, model: LLM_request) -> Optional[str]:
|
||||
async def _generate_response_with_model(
|
||||
self, message: MessageThinking, model: LLM_request
|
||||
) -> Optional[str]:
|
||||
"""使用指定的模型生成回复"""
|
||||
sender_name = message.user_nickname or f"用户{message.user_id}"
|
||||
if message.user_cardname:
|
||||
sender_name=f"[({message.user_id}){message.user_nickname}]{message.user_cardname}"
|
||||
|
||||
sender_name = (
|
||||
message.chat_stream.user_info.user_nickname
|
||||
or f"用户{message.chat_stream.user_info.user_id}"
|
||||
)
|
||||
if message.chat_stream.user_info.user_cardname:
|
||||
sender_name = f"[({message.chat_stream.user_info.user_id}){message.chat_stream.user_info.user_nickname}]{message.chat_stream.user_info.user_cardname}"
|
||||
|
||||
# 获取关系值
|
||||
relationship_value = relationship_manager.get_relationship(message.user_id).relationship_value if relationship_manager.get_relationship(message.user_id) else 0.0
|
||||
relationship_value = (
|
||||
relationship_manager.get_relationship(
|
||||
message.chat_stream
|
||||
).relationship_value
|
||||
if relationship_manager.get_relationship(message.chat_stream)
|
||||
else 0.0
|
||||
)
|
||||
if relationship_value != 0.0:
|
||||
# print(f"\033[1;32m[关系管理]\033[0m 回复中_当前关系值: {relationship_value}")
|
||||
pass
|
||||
|
||||
|
||||
# 构建prompt
|
||||
prompt, prompt_check = await prompt_builder._build_prompt(
|
||||
message_txt=message.processed_plain_text,
|
||||
sender_name=sender_name,
|
||||
relationship_value=relationship_value,
|
||||
group_id=message.group_id
|
||||
stream_id=message.chat_stream.stream_id,
|
||||
)
|
||||
|
||||
# 读空气模块 简化逻辑,先停用
|
||||
@@ -96,7 +127,7 @@ class ResponseGenerator:
|
||||
except Exception:
|
||||
logger.exception("生成回复时出错")
|
||||
return None
|
||||
|
||||
|
||||
# 保存到数据库
|
||||
self._save_to_db(
|
||||
message=message,
|
||||
@@ -108,54 +139,73 @@ class ResponseGenerator:
|
||||
reasoning_content=reasoning_content,
|
||||
# reasoning_content_check=reasoning_content_check if global_config.enable_kuuki_read else ""
|
||||
)
|
||||
|
||||
|
||||
return content
|
||||
|
||||
# def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
||||
# content: str, content_check: str, reasoning_content: str, reasoning_content_check: str):
|
||||
def _save_to_db(self, message: Message, sender_name: str, prompt: str, prompt_check: str,
|
||||
content: str, reasoning_content: str,):
|
||||
def _save_to_db(
|
||||
self,
|
||||
message: MessageRecv,
|
||||
sender_name: str,
|
||||
prompt: str,
|
||||
prompt_check: str,
|
||||
content: str,
|
||||
reasoning_content: str,
|
||||
):
|
||||
"""保存对话记录到数据库"""
|
||||
self.db.db.reasoning_logs.insert_one({
|
||||
'time': time.time(),
|
||||
'group_id': message.group_id,
|
||||
'user': sender_name,
|
||||
'message': message.processed_plain_text,
|
||||
'model': self.current_model_type,
|
||||
# 'reasoning_check': reasoning_content_check,
|
||||
# 'response_check': content_check,
|
||||
'reasoning': reasoning_content,
|
||||
'response': content,
|
||||
'prompt': prompt,
|
||||
'prompt_check': prompt_check
|
||||
})
|
||||
self.db.db.reasoning_logs.insert_one(
|
||||
{
|
||||
"time": time.time(),
|
||||
"chat_id": message.chat_stream.stream_id,
|
||||
"user": sender_name,
|
||||
"message": message.processed_plain_text,
|
||||
"model": self.current_model_type,
|
||||
# 'reasoning_check': reasoning_content_check,
|
||||
# 'response_check': content_check,
|
||||
"reasoning": reasoning_content,
|
||||
"response": content,
|
||||
"prompt": prompt,
|
||||
"prompt_check": prompt_check,
|
||||
}
|
||||
)
|
||||
|
||||
async def _get_emotion_tags(self, content: str) -> List[str]:
|
||||
"""提取情感标签"""
|
||||
try:
|
||||
prompt = f'''请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
|
||||
prompt = f"""请从以下内容中,从"happy,angry,sad,surprised,disgusted,fearful,neutral"中选出最匹配的1个情感标签并输出
|
||||
只输出标签就好,不要输出其他内容:
|
||||
内容:{content}
|
||||
输出:
|
||||
'''
|
||||
"""
|
||||
content, _ = await self.model_v25.generate_response(prompt)
|
||||
content=content.strip()
|
||||
if content in ['happy','angry','sad','surprised','disgusted','fearful','neutral']:
|
||||
content = content.strip()
|
||||
if content in [
|
||||
"happy",
|
||||
"angry",
|
||||
"sad",
|
||||
"surprised",
|
||||
"disgusted",
|
||||
"fearful",
|
||||
"neutral",
|
||||
]:
|
||||
return [content]
|
||||
else:
|
||||
return ["neutral"]
|
||||
|
||||
except Exception:
|
||||
logger.exception("获取情感标签时出错")
|
||||
|
||||
except Exception as e:
|
||||
print(f"获取情感标签时出错: {e}")
|
||||
return ["neutral"]
|
||||
|
||||
|
||||
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
|
||||
"""处理响应内容,返回处理后的内容和情感标签"""
|
||||
if not content:
|
||||
return None, []
|
||||
|
||||
|
||||
processed_response = process_llm_response(content)
|
||||
|
||||
# print(f"得到了处理后的llm返回{processed_response}")
|
||||
|
||||
return processed_response
|
||||
|
||||
|
||||
|
||||
@@ -1,231 +1,386 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, ForwardRef, List, Optional
|
||||
from typing import Dict, ForwardRef, List, Optional, Union
|
||||
|
||||
import urllib3
|
||||
from loguru import logger
|
||||
|
||||
from .cq_code import CQCode, cq_code_tool
|
||||
from .utils_cq import parse_cq_code
|
||||
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
||||
|
||||
Message = ForwardRef('Message') # 添加这行
|
||||
from .utils_image import image_manager
|
||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||
from .chat_stream import ChatStream, chat_manager
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
#这个类是消息数据类,用于存储和管理消息数据。
|
||||
#它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
|
||||
#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
||||
|
||||
@dataclass
|
||||
class MessageRecv(MessageBase):
|
||||
"""接收消息类,用于处理从MessageCQ序列化的消息"""
|
||||
|
||||
def __init__(self, message_dict: Dict):
|
||||
"""从MessageCQ的字典初始化
|
||||
|
||||
Args:
|
||||
message_dict: MessageCQ序列化后的字典
|
||||
"""
|
||||
message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {}))
|
||||
message_segment = Seg.from_dict(message_dict.get('message_segment', {}))
|
||||
raw_message = message_dict.get('raw_message')
|
||||
|
||||
super().__init__(
|
||||
message_info=message_info,
|
||||
message_segment=message_segment,
|
||||
raw_message=raw_message
|
||||
)
|
||||
|
||||
# 处理消息内容
|
||||
self.processed_plain_text = "" # 初始化为空字符串
|
||||
self.detailed_plain_text = "" # 初始化为空字符串
|
||||
self.is_emoji=False
|
||||
def update_chat_stream(self,chat_stream:ChatStream):
|
||||
self.chat_stream=chat_stream
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本
|
||||
|
||||
这个方法必须在创建实例后显式调用,因为它包含异步操作。
|
||||
"""
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
self.detailed_plain_text = self._generate_detailed_text()
|
||||
|
||||
async def _process_message_segments(self, segment: Seg) -> str:
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
if segment.type == 'seglist':
|
||||
# 处理消息段列表
|
||||
segments_text = []
|
||||
for seg in segment.data:
|
||||
processed = await self._process_message_segments(seg)
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return ' '.join(segments_text)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
return await self._process_single_segment(segment)
|
||||
|
||||
async def _process_single_segment(self, seg: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
seg: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if seg.type == 'text':
|
||||
return seg.data
|
||||
elif seg.type == 'image':
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return '[图片]'
|
||||
elif seg.type == 'emoji':
|
||||
self.is_emoji=True
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return '[表情]'
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
return f"[处理失败的{seg.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
|
||||
user_info = self.message_info.user_info
|
||||
name = (
|
||||
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
|
||||
if user_info.user_cardname!=''
|
||||
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
||||
)
|
||||
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
@dataclass
|
||||
class Message(MessageBase):
|
||||
chat_stream: ChatStream=None
|
||||
reply: Optional['Message'] = None
|
||||
detailed_plain_text: str = ""
|
||||
processed_plain_text: str = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
time: int,
|
||||
chat_stream: ChatStream,
|
||||
user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
reply: Optional['MessageRecv'] = None,
|
||||
detailed_plain_text: str = "",
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
# 构造基础消息信息
|
||||
message_info = BaseMessageInfo(
|
||||
platform=chat_stream.platform,
|
||||
message_id=message_id,
|
||||
time=time,
|
||||
group_info=chat_stream.group_info,
|
||||
user_info=user_info
|
||||
)
|
||||
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_info=message_info,
|
||||
message_segment=message_segment,
|
||||
raw_message=None
|
||||
)
|
||||
|
||||
self.chat_stream = chat_stream
|
||||
# 文本处理相关属性
|
||||
self.processed_plain_text = detailed_plain_text
|
||||
self.detailed_plain_text = processed_plain_text
|
||||
|
||||
# 回复消息
|
||||
self.reply = reply
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
"""消息数据类"""
|
||||
message_id: int = None
|
||||
time: float = None
|
||||
|
||||
group_id: int = None
|
||||
group_name: str = None # 群名称
|
||||
|
||||
user_id: int = None
|
||||
user_nickname: str = None # 用户昵称
|
||||
user_cardname: str = None # 用户群昵称
|
||||
|
||||
raw_message: str = None # 原始消息,包含未解析的cq码
|
||||
plain_text: str = None # 纯文本
|
||||
|
||||
reply_message: Dict = None # 存储 回复的 源消息
|
||||
|
||||
# 延迟初始化字段
|
||||
_initialized: bool = False
|
||||
message_segments: List[Dict] = None # 存储解析后的消息片段
|
||||
processed_plain_text: str = None # 用于存储处理后的plain_text
|
||||
detailed_plain_text: str = None # 用于存储详细可读文本
|
||||
|
||||
# 状态标志
|
||||
is_emoji: bool = False
|
||||
has_emoji: bool = False
|
||||
translate_cq: bool = True
|
||||
|
||||
async def initialize(self):
|
||||
"""显式异步初始化方法(必须调用)"""
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# 异步获取补充信息
|
||||
self.group_name = self.group_name or get_groupname(self.group_id)
|
||||
self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
|
||||
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
|
||||
|
||||
# 消息解析
|
||||
if self.raw_message:
|
||||
if not isinstance(self,Message_Sending):
|
||||
self.message_segments = await self.parse_message_segments(self.raw_message)
|
||||
self.processed_plain_text = ' '.join(
|
||||
seg.translated_plain_text
|
||||
for seg in self.message_segments
|
||||
)
|
||||
|
||||
# 构建详细文本
|
||||
if self.time is None:
|
||||
self.time = int(time.time())
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
|
||||
name = (
|
||||
f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
|
||||
if self.user_cardname
|
||||
else f"{self.user_nickname or f'用户{self.user_id}'}"
|
||||
)
|
||||
if isinstance(self,Message_Sending) and self.is_emoji:
|
||||
self.detailed_plain_text = f"[{time_str}] {name}: {self.detailed_plain_text}\n"
|
||||
else:
|
||||
self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
self._initialized = True
|
||||
class MessageProcessBase(Message):
|
||||
"""消息处理基类,用于处理中和发送中的消息"""
|
||||
|
||||
async def parse_message_segments(self, message: str) -> List[CQCode]:
|
||||
"""
|
||||
将消息解析为片段列表,包括纯文本和CQ码
|
||||
返回的列表中每个元素都是字典,包含:
|
||||
- cq_code_list:分割出的聊天对象,包括文本和CQ码
|
||||
- trans_list:翻译后的对象列表
|
||||
"""
|
||||
# print(f"\033[1;34m[调试信息]\033[0m 正在处理消息: {message}")
|
||||
cq_code_dict_list = []
|
||||
trans_list = []
|
||||
|
||||
start = 0
|
||||
while True:
|
||||
# 查找下一个CQ码的开始位置
|
||||
cq_start = message.find('[CQ:', start)
|
||||
#如果没有cq码,直接返回文本内容
|
||||
if cq_start == -1:
|
||||
# 如果没有找到更多CQ码,添加剩余文本
|
||||
if start < len(message):
|
||||
text = message[start:].strip()
|
||||
if text: # 只添加非空文本
|
||||
cq_code_dict_list.append(parse_cq_code(text))
|
||||
break
|
||||
# 添加CQ码前的文本
|
||||
if cq_start > start:
|
||||
text = message[start:cq_start].strip()
|
||||
if text: # 只添加非空文本
|
||||
cq_code_dict_list.append(parse_cq_code(text))
|
||||
# 查找CQ码的结束位置
|
||||
cq_end = message.find(']', cq_start)
|
||||
if cq_end == -1:
|
||||
# CQ码未闭合,作为普通文本处理
|
||||
text = message[cq_start:].strip()
|
||||
if text:
|
||||
cq_code_dict_list.append(parse_cq_code(text))
|
||||
break
|
||||
cq_code = message[cq_start:cq_end + 1]
|
||||
|
||||
#将cq_code解析成字典
|
||||
cq_code_dict_list.append(parse_cq_code(cq_code))
|
||||
# 更新start位置到当前CQ码之后
|
||||
start = cq_end + 1
|
||||
|
||||
# print(f"\033[1;34m[调试信息]\033[0m 提取的消息对象:列表: {cq_code_dict_list}")
|
||||
|
||||
#判定是否是表情包消息,以及是否含有表情包
|
||||
if len(cq_code_dict_list) == 1 and cq_code_dict_list[0]['type'] == 'image':
|
||||
self.is_emoji = True
|
||||
self.has_emoji_emoji = True
|
||||
else:
|
||||
for segment in cq_code_dict_list:
|
||||
if segment['type'] == 'image' and segment['data'].get('sub_type') == '1':
|
||||
self.has_emoji_emoji = True
|
||||
break
|
||||
|
||||
|
||||
#翻译作为字典的CQ码
|
||||
for _code_item in cq_code_dict_list:
|
||||
message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
|
||||
trans_list.append(message_obj)
|
||||
return trans_list
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: ChatStream,
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
reply: Optional['MessageRecv'] = None
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
time=int(time.time()),
|
||||
chat_stream=chat_stream,
|
||||
user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=reply
|
||||
)
|
||||
|
||||
class Message_Thinking:
|
||||
"""消息思考类"""
|
||||
def __init__(self, message: Message,message_id: str):
|
||||
# 复制原始消息的基本属性
|
||||
self.group_id = message.group_id
|
||||
self.user_id = message.user_id
|
||||
self.user_nickname = message.user_nickname
|
||||
self.user_cardname = message.user_cardname
|
||||
self.group_name = message.group_name
|
||||
|
||||
self.message_id = message_id
|
||||
|
||||
# 思考状态相关属性
|
||||
# 处理状态相关属性
|
||||
self.thinking_start_time = int(time.time())
|
||||
self.thinking_time = 0
|
||||
self.interupt=False
|
||||
|
||||
def update_thinking_time(self):
|
||||
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message_Sending(Message):
|
||||
"""发送中的消息类"""
|
||||
thinking_start_time: float = None # 思考开始时间
|
||||
thinking_time: float = None # 思考时间
|
||||
|
||||
reply_message_id: int = None # 存储 回复的 源消息ID
|
||||
|
||||
is_head: bool = False # 是否是头部消息
|
||||
|
||||
def update_thinking_time(self):
|
||||
self.thinking_time = round(time.time(), 2) - self.thinking_start_time
|
||||
def update_thinking_time(self) -> float:
|
||||
"""更新思考时间"""
|
||||
self.thinking_time = round(time.time() - self.thinking_start_time, 2)
|
||||
return self.thinking_time
|
||||
|
||||
async def _process_message_segments(self, segment: Seg) -> str:
|
||||
"""递归处理消息段,转换为文字描述
|
||||
|
||||
Args:
|
||||
segment: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
if segment.type == 'seglist':
|
||||
# 处理消息段列表
|
||||
segments_text = []
|
||||
for seg in segment.data:
|
||||
processed = await self._process_message_segments(seg)
|
||||
if processed:
|
||||
segments_text.append(processed)
|
||||
return ' '.join(segments_text)
|
||||
else:
|
||||
# 处理单个消息段
|
||||
return await self._process_single_segment(segment)
|
||||
|
||||
|
||||
async def _process_single_segment(self, seg: Seg) -> str:
|
||||
"""处理单个消息段
|
||||
|
||||
Args:
|
||||
seg: 要处理的消息段
|
||||
|
||||
Returns:
|
||||
str: 处理后的文本
|
||||
"""
|
||||
try:
|
||||
if seg.type == 'text':
|
||||
return seg.data
|
||||
elif seg.type == 'image':
|
||||
# 如果是base64图片数据
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_image_description(seg.data)
|
||||
return '[图片]'
|
||||
elif seg.type == 'emoji':
|
||||
if isinstance(seg.data, str):
|
||||
return await image_manager.get_emoji_description(seg.data)
|
||||
return '[表情]'
|
||||
elif seg.type == 'at':
|
||||
return f"[@{seg.data}]"
|
||||
elif seg.type == 'reply':
|
||||
if self.reply and hasattr(self.reply, 'processed_plain_text'):
|
||||
return f"[回复:{self.reply.processed_plain_text}]"
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
return f"[处理失败的{seg.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
"""生成详细文本,包含时间和用户信息"""
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
|
||||
user_info = self.message_info.user_info
|
||||
name = (
|
||||
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
|
||||
if user_info.user_cardname != ''
|
||||
else f"{user_info.user_nickname}(ta的id:{user_info.user_id})"
|
||||
)
|
||||
return f"[{time_str}] {name}: {self.processed_plain_text}\n"
|
||||
|
||||
@dataclass
|
||||
class MessageThinking(MessageProcessBase):
|
||||
"""思考状态的消息类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: ChatStream,
|
||||
bot_user_info: UserInfo,
|
||||
reply: Optional['MessageRecv'] = None
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=None, # 思考状态不需要消息段
|
||||
reply=reply
|
||||
)
|
||||
|
||||
# 思考状态特有属性
|
||||
self.interrupt = False
|
||||
|
||||
@dataclass
|
||||
class MessageSending(MessageProcessBase):
|
||||
"""发送状态的消息类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: str,
|
||||
chat_stream: ChatStream,
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Seg,
|
||||
reply: Optional['MessageRecv'] = None,
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
message_id=message_id,
|
||||
chat_stream=chat_stream,
|
||||
bot_user_info=bot_user_info,
|
||||
message_segment=message_segment,
|
||||
reply=reply
|
||||
)
|
||||
|
||||
# 发送状态特有属性
|
||||
self.reply_to_message_id = reply.message_info.message_id if reply else None
|
||||
self.is_head = is_head
|
||||
self.is_emoji = is_emoji
|
||||
|
||||
def set_reply(self, reply: Optional['MessageRecv']) -> None:
|
||||
"""设置回复消息"""
|
||||
if reply:
|
||||
self.reply = reply
|
||||
self.reply_to_message_id = self.reply.message_info.message_id
|
||||
self.message_segment = Seg(type='seglist', data=[
|
||||
Seg(type='reply', data=reply.message_info.message_id),
|
||||
self.message_segment
|
||||
])
|
||||
|
||||
async def process(self) -> None:
|
||||
"""处理消息内容,生成纯文本和详细文本"""
|
||||
if self.message_segment:
|
||||
self.processed_plain_text = await self._process_message_segments(self.message_segment)
|
||||
self.detailed_plain_text = self._generate_detailed_text()
|
||||
|
||||
@classmethod
|
||||
def from_thinking(
|
||||
cls,
|
||||
thinking: MessageThinking,
|
||||
message_segment: Seg,
|
||||
is_head: bool = False,
|
||||
is_emoji: bool = False
|
||||
) -> 'MessageSending':
|
||||
"""从思考状态消息创建发送状态消息"""
|
||||
return cls(
|
||||
message_id=thinking.message_info.message_id,
|
||||
chat_stream=thinking.chat_stream,
|
||||
message_segment=message_segment,
|
||||
bot_user_info=thinking.message_info.user_info,
|
||||
reply=thinking.reply,
|
||||
is_head=is_head,
|
||||
is_emoji=is_emoji
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
ret= super().to_dict()
|
||||
ret['message_info']['user_info']=self.chat_stream.user_info.to_dict()
|
||||
return ret
|
||||
|
||||
@dataclass
|
||||
class MessageSet:
|
||||
"""消息集合类,可以存储多个发送消息"""
|
||||
def __init__(self, group_id: int, user_id: int, message_id: str):
|
||||
self.group_id = group_id
|
||||
self.user_id = user_id
|
||||
def __init__(self, chat_stream: ChatStream, message_id: str):
|
||||
self.chat_stream = chat_stream
|
||||
self.message_id = message_id
|
||||
self.messages: List[Message_Sending] = [] # 修改类型标注
|
||||
self.messages: List[MessageSending] = []
|
||||
self.time = round(time.time(), 2)
|
||||
|
||||
def add_message(self, message: Message_Sending) -> None:
|
||||
"""添加消息到集合,只接受Message_Sending类型"""
|
||||
if not isinstance(message, Message_Sending):
|
||||
raise TypeError("MessageSet只能添加Message_Sending类型的消息")
|
||||
def add_message(self, message: MessageSending) -> None:
|
||||
"""添加消息到集合"""
|
||||
if not isinstance(message, MessageSending):
|
||||
raise TypeError("MessageSet只能添加MessageSending类型的消息")
|
||||
self.messages.append(message)
|
||||
# 按时间排序
|
||||
self.messages.sort(key=lambda x: x.time)
|
||||
self.messages.sort(key=lambda x: x.message_info.time)
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[Message_Sending]:
|
||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||
"""通过索引获取消息"""
|
||||
if 0 <= index < len(self.messages):
|
||||
return self.messages[index]
|
||||
return None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[Message_Sending]:
|
||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
# 使用二分查找找到最接近的消息
|
||||
left, right = 0, len(self.messages) - 1
|
||||
while left < right:
|
||||
mid = (left + right) // 2
|
||||
if self.messages[mid].time < target_time:
|
||||
if self.messages[mid].message_info.time < target_time:
|
||||
left = mid + 1
|
||||
else:
|
||||
right = mid
|
||||
|
||||
return self.messages[left]
|
||||
|
||||
|
||||
def clear_messages(self) -> None:
|
||||
"""清空所有消息"""
|
||||
self.messages.clear()
|
||||
|
||||
def remove_message(self, message: Message_Sending) -> bool:
|
||||
def remove_message(self, message: MessageSending) -> bool:
|
||||
"""移除指定消息"""
|
||||
if message in self.messages:
|
||||
self.messages.remove(message)
|
||||
|
||||
186
src/plugins/chat/message_base.py
Normal file
186
src/plugins/chat/message_base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import List, Optional, Union, Any, Dict
|
||||
|
||||
@dataclass
|
||||
class Seg:
|
||||
"""消息片段类,用于表示消息的不同部分
|
||||
|
||||
Attributes:
|
||||
type: 片段类型,可以是 'text'、'image'、'seglist' 等
|
||||
data: 片段的具体内容
|
||||
- 对于 text 类型,data 是字符串
|
||||
- 对于 image 类型,data 是 base64 字符串
|
||||
- 对于 seglist 类型,data 是 Seg 列表
|
||||
translated_data: 经过翻译处理的数据(可选)
|
||||
"""
|
||||
type: str
|
||||
data: Union[str, List['Seg']]
|
||||
|
||||
|
||||
# def __init__(self, type: str, data: Union[str, List['Seg']],):
|
||||
# """初始化实例,确保字典和属性同步"""
|
||||
# # 先初始化字典
|
||||
# self.type = type
|
||||
# self.data = data
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'Seg':
|
||||
"""从字典创建Seg实例"""
|
||||
type=data.get('type')
|
||||
data=data.get('data')
|
||||
if type == 'seglist':
|
||||
data = [Seg.from_dict(seg) for seg in data]
|
||||
return cls(
|
||||
type=type,
|
||||
data=data
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
result = {'type': self.type}
|
||||
if self.type == 'seglist':
|
||||
result['data'] = [seg.to_dict() for seg in self.data]
|
||||
else:
|
||||
result['data'] = self.data
|
||||
return result
|
||||
|
||||
@dataclass
|
||||
class GroupInfo:
|
||||
"""群组信息类"""
|
||||
platform: Optional[str] = None
|
||||
group_id: Optional[int] = None
|
||||
group_name: Optional[str] = None # 群名称
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'GroupInfo':
|
||||
"""从字典创建GroupInfo实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
GroupInfo: 新的实例
|
||||
"""
|
||||
return cls(
|
||||
platform=data.get('platform'),
|
||||
group_id=data.get('group_id'),
|
||||
group_name=data.get('group_name',None)
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class UserInfo:
|
||||
"""用户信息类"""
|
||||
platform: Optional[str] = None
|
||||
user_id: Optional[int] = None
|
||||
user_nickname: Optional[str] = None # 用户昵称
|
||||
user_cardname: Optional[str] = None # 用户群昵称
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
return {k: v for k, v in asdict(self).items() if v is not None}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'UserInfo':
|
||||
"""从字典创建UserInfo实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
UserInfo: 新的实例
|
||||
"""
|
||||
return cls(
|
||||
platform=data.get('platform'),
|
||||
user_id=data.get('user_id'),
|
||||
user_nickname=data.get('user_nickname',None),
|
||||
user_cardname=data.get('user_cardname',None)
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class BaseMessageInfo:
|
||||
"""消息信息类"""
|
||||
platform: Optional[str] = None
|
||||
message_id: Union[str,int,None] = None
|
||||
time: Optional[int] = None
|
||||
group_info: Optional[GroupInfo] = None
|
||||
user_info: Optional[UserInfo] = None
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式"""
|
||||
result = {}
|
||||
for field, value in asdict(self).items():
|
||||
if value is not None:
|
||||
if isinstance(value, (GroupInfo, UserInfo)):
|
||||
result[field] = value.to_dict()
|
||||
else:
|
||||
result[field] = value
|
||||
return result
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'BaseMessageInfo':
|
||||
"""从字典创建BaseMessageInfo实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
BaseMessageInfo: 新的实例
|
||||
"""
|
||||
group_info = GroupInfo(**data.get('group_info', {}))
|
||||
user_info = UserInfo(**data.get('user_info', {}))
|
||||
return cls(
|
||||
platform=data.get('platform'),
|
||||
message_id=data.get('message_id'),
|
||||
time=data.get('time'),
|
||||
group_info=group_info,
|
||||
user_info=user_info
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class MessageBase:
|
||||
"""消息类"""
|
||||
message_info: BaseMessageInfo
|
||||
message_segment: Seg
|
||||
raw_message: Optional[str] = None # 原始消息,包含未解析的cq码
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式
|
||||
|
||||
Returns:
|
||||
Dict: 包含所有非None字段的字典,其中:
|
||||
- message_info: 转换为字典格式
|
||||
- message_segment: 转换为字典格式
|
||||
- raw_message: 如果存在则包含
|
||||
"""
|
||||
result = {
|
||||
'message_info': self.message_info.to_dict(),
|
||||
'message_segment': self.message_segment.to_dict()
|
||||
}
|
||||
if self.raw_message is not None:
|
||||
result['raw_message'] = self.raw_message
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'MessageBase':
|
||||
"""从字典创建MessageBase实例
|
||||
|
||||
Args:
|
||||
data: 包含必要字段的字典
|
||||
|
||||
Returns:
|
||||
MessageBase: 新的实例
|
||||
"""
|
||||
message_info = BaseMessageInfo(**data.get('message_info', {}))
|
||||
message_segment = Seg(**data.get('message_segment', {}))
|
||||
raw_message = data.get('raw_message',None)
|
||||
return cls(
|
||||
message_info=message_info,
|
||||
message_segment=message_segment,
|
||||
raw_message=raw_message
|
||||
)
|
||||
|
||||
|
||||
|
||||
169
src/plugins/chat/message_cq.py
Normal file
169
src/plugins/chat/message_cq.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, ForwardRef, List, Optional, Union
|
||||
|
||||
import urllib3
|
||||
|
||||
from .cq_code import CQCode, cq_code_tool
|
||||
from .utils_cq import parse_cq_code
|
||||
from .utils_user import get_groupname, get_user_cardname, get_user_nickname
|
||||
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
|
||||
# 禁用SSL警告
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
#这个类是消息数据类,用于存储和管理消息数据。
|
||||
#它定义了消息的属性,包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
|
||||
#它还定义了两个辅助属性:keywords用于提取消息的关键词,is_plain_text用于判断消息是否为纯文本。
|
||||
|
||||
@dataclass
|
||||
class MessageCQ(MessageBase):
|
||||
"""QQ消息基类,继承自MessageBase
|
||||
|
||||
最小必要参数:
|
||||
- message_id: 消息ID
|
||||
- user_id: 发送者/接收者ID
|
||||
- platform: 平台标识(默认为"qq")
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
message_id: int,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
platform: str = "qq"
|
||||
):
|
||||
# 构造基础消息信息
|
||||
message_info = BaseMessageInfo(
|
||||
platform=platform,
|
||||
message_id=message_id,
|
||||
time=int(time.time()),
|
||||
group_info=group_info,
|
||||
user_info=user_info
|
||||
)
|
||||
# 调用父类初始化,message_segment 由子类设置
|
||||
super().__init__(
|
||||
message_info=message_info,
|
||||
message_segment=None,
|
||||
raw_message=None
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class MessageRecvCQ(MessageCQ):
|
||||
"""QQ接收消息类,用于解析raw_message到Seg对象"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message_id: int,
|
||||
user_info: UserInfo,
|
||||
raw_message: str,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
platform: str = "qq",
|
||||
reply_message: Optional[Dict] = None,
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(message_id, user_info, group_info, platform)
|
||||
|
||||
if group_info.group_name is None:
|
||||
group_info.group_name = get_groupname(group_info.group_id)
|
||||
|
||||
# 解析消息段
|
||||
self.message_segment = self._parse_message(raw_message, reply_message)
|
||||
self.raw_message = raw_message
|
||||
|
||||
def _parse_message(self, message: str, reply_message: Optional[Dict] = None) -> Seg:
|
||||
"""解析消息内容为Seg对象"""
|
||||
cq_code_dict_list = []
|
||||
segments = []
|
||||
|
||||
start = 0
|
||||
while True:
|
||||
cq_start = message.find('[CQ:', start)
|
||||
if cq_start == -1:
|
||||
if start < len(message):
|
||||
text = message[start:].strip()
|
||||
if text:
|
||||
cq_code_dict_list.append(parse_cq_code(text))
|
||||
break
|
||||
|
||||
if cq_start > start:
|
||||
text = message[start:cq_start].strip()
|
||||
if text:
|
||||
cq_code_dict_list.append(parse_cq_code(text))
|
||||
|
||||
cq_end = message.find(']', cq_start)
|
||||
if cq_end == -1:
|
||||
text = message[cq_start:].strip()
|
||||
if text:
|
||||
cq_code_dict_list.append(parse_cq_code(text))
|
||||
break
|
||||
|
||||
cq_code = message[cq_start:cq_end + 1]
|
||||
cq_code_dict_list.append(parse_cq_code(cq_code))
|
||||
start = cq_end + 1
|
||||
|
||||
# 转换CQ码为Seg对象
|
||||
for code_item in cq_code_dict_list:
|
||||
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message)
|
||||
if message_obj.translated_segments:
|
||||
segments.append(message_obj.translated_segments)
|
||||
|
||||
# 如果只有一个segment,直接返回
|
||||
if len(segments) == 1:
|
||||
return segments[0]
|
||||
|
||||
# 否则返回seglist类型的Seg
|
||||
return Seg(type='seglist', data=segments)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""转换为字典格式,包含所有必要信息"""
|
||||
base_dict = super().to_dict()
|
||||
return base_dict
|
||||
|
||||
@dataclass
|
||||
class MessageSendCQ(MessageCQ):
|
||||
"""QQ发送消息类,用于将Seg对象转换为raw_message"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Dict
|
||||
):
|
||||
# 调用父类初始化
|
||||
message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
|
||||
message_segment = Seg.from_dict(data.get('message_segment', {}))
|
||||
super().__init__(
|
||||
message_info.message_id,
|
||||
message_info.user_info,
|
||||
message_info.group_info if message_info.group_info else None,
|
||||
message_info.platform
|
||||
)
|
||||
|
||||
self.message_segment = message_segment
|
||||
self.raw_message = self._generate_raw_message()
|
||||
|
||||
def _generate_raw_message(self, ) -> str:
|
||||
"""将Seg对象转换为raw_message"""
|
||||
segments = []
|
||||
|
||||
# 处理消息段
|
||||
if self.message_segment.type == 'seglist':
|
||||
for seg in self.message_segment.data:
|
||||
segments.append(self._seg_to_cq_code(seg))
|
||||
else:
|
||||
segments.append(self._seg_to_cq_code(self.message_segment))
|
||||
|
||||
return ''.join(segments)
|
||||
|
||||
def _seg_to_cq_code(self, seg: Seg) -> str:
|
||||
"""将单个Seg对象转换为CQ码字符串"""
|
||||
if seg.type == 'text':
|
||||
return str(seg.data)
|
||||
elif seg.type == 'image':
|
||||
return cq_code_tool.create_image_cq_base64(seg.data)
|
||||
elif seg.type == 'emoji':
|
||||
return cq_code_tool.create_emoji_cq_base64(seg.data)
|
||||
elif seg.type == 'at':
|
||||
return f"[CQ:at,qq={seg.data}]"
|
||||
elif seg.type == 'reply':
|
||||
return cq_code_tool.create_reply_cq(int(seg.data))
|
||||
else:
|
||||
return f"[{seg.data}]"
|
||||
|
||||
@@ -6,10 +6,11 @@ from loguru import logger
|
||||
from nonebot.adapters.onebot.v11 import Bot
|
||||
|
||||
from .cq_code import cq_code_tool
|
||||
from .message import Message, Message_Sending, Message_Thinking, MessageSet
|
||||
from .message_cq import MessageSendCQ
|
||||
from .message import MessageSending, MessageThinking, MessageRecv,MessageSet
|
||||
from .storage import MessageStorage
|
||||
from .utils import calculate_typing_time
|
||||
from .config import global_config
|
||||
from .chat_stream import chat_manager
|
||||
|
||||
|
||||
class Message_Sender:
|
||||
@@ -24,64 +25,57 @@ class Message_Sender:
|
||||
"""设置当前bot实例"""
|
||||
self._current_bot = bot
|
||||
|
||||
async def send_group_message(
|
||||
async def send_message(
|
||||
self,
|
||||
group_id: int,
|
||||
send_text: str,
|
||||
auto_escape: bool = False,
|
||||
reply_message_id: int = None,
|
||||
at_user_id: int = None
|
||||
message: MessageSending,
|
||||
) -> None:
|
||||
|
||||
if not self._current_bot:
|
||||
raise RuntimeError("Bot未设置,请先调用set_bot方法设置bot实例")
|
||||
|
||||
message = send_text
|
||||
|
||||
# 如果需要回复
|
||||
if reply_message_id:
|
||||
reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
|
||||
message = reply_cq + message
|
||||
|
||||
# 如果需要at
|
||||
# if at_user_id:
|
||||
# at_cq = cq_code_tool.create_at_cq(at_user_id)
|
||||
# message = at_cq + " " + message
|
||||
|
||||
typing_time = calculate_typing_time(message)
|
||||
if typing_time > 10:
|
||||
typing_time = 10
|
||||
await asyncio.sleep(typing_time)
|
||||
|
||||
# 发送消息
|
||||
try:
|
||||
await self._current_bot.send_group_msg(
|
||||
group_id=group_id,
|
||||
message=message,
|
||||
auto_escape=auto_escape
|
||||
"""发送消息"""
|
||||
if isinstance(message, MessageSending):
|
||||
message_json = message.to_dict()
|
||||
message_send=MessageSendCQ(
|
||||
data=message_json
|
||||
)
|
||||
logger.debug(f"发送消息{message}成功")
|
||||
except Exception:
|
||||
logger.exception(f"发送消息{message}失败")
|
||||
|
||||
if message_send.message_info.group_info:
|
||||
try:
|
||||
await self._current_bot.send_group_msg(
|
||||
group_id=message.message_info.group_info.group_id,
|
||||
message=message_send.raw_message,
|
||||
auto_escape=False
|
||||
)
|
||||
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
|
||||
except Exception as e:
|
||||
logger.error(f"[调试] 发生错误 {e}")
|
||||
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
|
||||
else:
|
||||
try:
|
||||
await self._current_bot.send_private_msg(
|
||||
user_id=message.message_info.user_info.user_id,
|
||||
message=message_send.raw_message,
|
||||
auto_escape=False
|
||||
)
|
||||
logger.success(f"[调试] 发送消息{message.processed_plain_text}成功")
|
||||
except Exception as e:
|
||||
logger.error(f"发生错误 {e}")
|
||||
logger.error(f"[调试] 发送消息{message.processed_plain_text}失败")
|
||||
|
||||
|
||||
class MessageContainer:
|
||||
"""单个群的发送/思考消息容器"""
|
||||
|
||||
def __init__(self, group_id: int, max_size: int = 100):
|
||||
self.group_id = group_id
|
||||
"""单个聊天流的发送/思考消息容器"""
|
||||
def __init__(self, chat_id: str, max_size: int = 100):
|
||||
self.chat_id = chat_id
|
||||
self.max_size = max_size
|
||||
self.messages = []
|
||||
self.last_send_time = 0
|
||||
self.thinking_timeout = 20 # 思考超时时间(秒)
|
||||
|
||||
def get_timeout_messages(self) -> List[Message_Sending]:
|
||||
|
||||
def get_timeout_messages(self) -> List[MessageSending]:
|
||||
"""获取所有超时的Message_Sending对象(思考时间超过30秒),按thinking_start_time排序"""
|
||||
current_time = time.time()
|
||||
timeout_messages = []
|
||||
|
||||
for msg in self.messages:
|
||||
if isinstance(msg, Message_Sending):
|
||||
if isinstance(msg, MessageSending):
|
||||
if current_time - msg.thinking_start_time > self.thinking_timeout:
|
||||
timeout_messages.append(msg)
|
||||
|
||||
@@ -89,8 +83,8 @@ class MessageContainer:
|
||||
timeout_messages.sort(key=lambda x: x.thinking_start_time)
|
||||
|
||||
return timeout_messages
|
||||
|
||||
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
|
||||
|
||||
def get_earliest_message(self) -> Optional[Union[MessageThinking, MessageSending]]:
|
||||
"""获取thinking_start_time最早的消息对象"""
|
||||
if not self.messages:
|
||||
return None
|
||||
@@ -102,17 +96,16 @@ class MessageContainer:
|
||||
earliest_time = msg_time
|
||||
earliest_message = msg
|
||||
return earliest_message
|
||||
|
||||
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
|
||||
|
||||
def add_message(self, message: Union[MessageThinking, MessageSending]) -> None:
|
||||
"""添加消息到队列"""
|
||||
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
|
||||
if isinstance(message, MessageSet):
|
||||
for single_message in message.messages:
|
||||
self.messages.append(single_message)
|
||||
else:
|
||||
self.messages.append(message)
|
||||
|
||||
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
|
||||
|
||||
def remove_message(self, message: Union[MessageThinking, MessageSending]) -> bool:
|
||||
"""移除消息,如果消息存在则返回True,否则返回False"""
|
||||
try:
|
||||
if message in self.messages:
|
||||
@@ -126,42 +119,40 @@ class MessageContainer:
|
||||
def has_messages(self) -> bool:
|
||||
"""检查是否有待发送的消息"""
|
||||
return bool(self.messages)
|
||||
|
||||
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
|
||||
|
||||
def get_all_messages(self) -> List[Union[MessageSending, MessageThinking]]:
|
||||
"""获取所有消息"""
|
||||
return list(self.messages)
|
||||
|
||||
|
||||
class MessageManager:
|
||||
"""管理所有群的消息容器"""
|
||||
|
||||
"""管理所有聊天流的消息容器"""
|
||||
def __init__(self):
|
||||
self.containers: Dict[int, MessageContainer] = {}
|
||||
self.containers: Dict[str, MessageContainer] = {} # chat_id -> MessageContainer
|
||||
self.storage = MessageStorage()
|
||||
self._running = True
|
||||
|
||||
def get_container(self, group_id: int) -> MessageContainer:
|
||||
"""获取或创建群的消息容器"""
|
||||
if group_id not in self.containers:
|
||||
self.containers[group_id] = MessageContainer(group_id)
|
||||
return self.containers[group_id]
|
||||
|
||||
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
|
||||
container = self.get_container(message.group_id)
|
||||
|
||||
def get_container(self, chat_id: str) -> MessageContainer:
|
||||
"""获取或创建聊天流的消息容器"""
|
||||
if chat_id not in self.containers:
|
||||
self.containers[chat_id] = MessageContainer(chat_id)
|
||||
return self.containers[chat_id]
|
||||
|
||||
def add_message(self, message: Union[MessageThinking, MessageSending, MessageSet]) -> None:
|
||||
chat_stream = message.chat_stream
|
||||
if not chat_stream:
|
||||
raise ValueError("无法找到对应的聊天流")
|
||||
container = self.get_container(chat_stream.stream_id)
|
||||
container.add_message(message)
|
||||
|
||||
async def process_group_messages(self, group_id: int):
|
||||
"""处理群消息"""
|
||||
# if int(time.time() / 3) == time.time() / 3:
|
||||
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
|
||||
container = self.get_container(group_id)
|
||||
|
||||
async def process_chat_messages(self, chat_id: str):
|
||||
"""处理聊天流消息"""
|
||||
container = self.get_container(chat_id)
|
||||
if container.has_messages():
|
||||
# 最早的对象,可能是思考消息,也可能是发送消息
|
||||
message_earliest = container.get_earliest_message() # 一个message_thinking or message_sending
|
||||
|
||||
# 如果是思考消息
|
||||
if isinstance(message_earliest, Message_Thinking):
|
||||
# 优先等待这条消息
|
||||
# print(f"处理有message的容器chat_id: {chat_id}")
|
||||
message_earliest = container.get_earliest_message()
|
||||
|
||||
if isinstance(message_earliest, MessageThinking):
|
||||
message_earliest.update_thinking_time()
|
||||
thinking_time = message_earliest.thinking_time
|
||||
print(f"消息正在思考中,已思考{int(thinking_time)}秒\r", end='', flush=True)
|
||||
@@ -170,47 +161,38 @@ class MessageManager:
|
||||
if thinking_time > global_config.thinking_timeout:
|
||||
logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
|
||||
container.remove_message(message_earliest)
|
||||
else: # 如果不是message_thinking就只能是message_sending
|
||||
logger.debug(f"消息'{message_earliest.processed_plain_text}'正在发送中")
|
||||
# 直接发,等什么呢
|
||||
else:
|
||||
|
||||
if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
|
||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
|
||||
auto_escape=False,
|
||||
reply_message_id=message_earliest.reply_message_id)
|
||||
await message_sender.send_message(message_earliest.set_reply())
|
||||
else:
|
||||
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
|
||||
auto_escape=False)
|
||||
# 移除消息
|
||||
if message_earliest.is_emoji:
|
||||
message_earliest.processed_plain_text = "[表情包]"
|
||||
await self.storage.store_message(message_earliest, None)
|
||||
|
||||
await message_sender.send_message(message_earliest)
|
||||
await message_earliest.process()
|
||||
|
||||
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中")
|
||||
|
||||
await self.storage.store_message(message_earliest, message_earliest.chat_stream,None)
|
||||
|
||||
container.remove_message(message_earliest)
|
||||
|
||||
# 获取并处理超时消息
|
||||
message_timeout = container.get_timeout_messages() # 也许是一堆message_sending
|
||||
|
||||
message_timeout = container.get_timeout_messages()
|
||||
if message_timeout:
|
||||
logger.warning(f"发现{len(message_timeout)}条超时消息")
|
||||
for msg in message_timeout:
|
||||
if msg == message_earliest:
|
||||
continue # 跳过已经处理过的消息
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
# 发送
|
||||
if msg.is_head and msg.update_thinking_time() > 30:
|
||||
await message_sender.send_group_message(group_id, msg.processed_plain_text,
|
||||
auto_escape=False,
|
||||
reply_message_id=msg.reply_message_id)
|
||||
await message_sender.send_message(msg.set_reply())
|
||||
else:
|
||||
await message_sender.send_group_message(group_id, msg.processed_plain_text,
|
||||
auto_escape=False)
|
||||
|
||||
# 如果是表情包,则替换为"[表情包]"
|
||||
if msg.is_emoji:
|
||||
msg.processed_plain_text = "[表情包]"
|
||||
await self.storage.store_message(msg, None)
|
||||
|
||||
# 安全地移除消息
|
||||
await message_sender.send_message(msg)
|
||||
|
||||
# if msg.is_emoji:
|
||||
# msg.processed_plain_text = "[表情包]"
|
||||
await msg.process()
|
||||
await self.storage.store_message(msg,msg.chat_stream, None)
|
||||
|
||||
if not container.remove_message(msg):
|
||||
logger.warning("尝试删除不存在的消息")
|
||||
except Exception:
|
||||
@@ -222,9 +204,9 @@ class MessageManager:
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
tasks = []
|
||||
for group_id in self.containers.keys():
|
||||
tasks.append(self.process_group_messages(group_id))
|
||||
|
||||
for chat_id in self.containers.keys():
|
||||
tasks.append(self.process_chat_messages(chat_id))
|
||||
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from ..moods.moods import MoodManager
|
||||
from ..schedule.schedule_generator import bot_schedule
|
||||
from .config import global_config
|
||||
from .utils import get_embedding, get_recent_group_detailed_plain_text
|
||||
from .chat_stream import ChatStream, chat_manager
|
||||
|
||||
|
||||
class PromptBuilder:
|
||||
@@ -17,11 +18,13 @@ class PromptBuilder:
|
||||
self.activate_messages = ''
|
||||
self.db = Database.get_instance()
|
||||
|
||||
async def _build_prompt(self,
|
||||
message_txt: str,
|
||||
sender_name: str = "某人",
|
||||
relationship_value: float = 0.0,
|
||||
group_id: Optional[int] = None) -> tuple[str, str]:
|
||||
|
||||
|
||||
async def _build_prompt(self,
|
||||
message_txt: str,
|
||||
sender_name: str = "某人",
|
||||
relationship_value: float = 0.0,
|
||||
stream_id: Optional[int] = None) -> tuple[str, str]:
|
||||
"""构建prompt
|
||||
|
||||
Args:
|
||||
@@ -70,14 +73,20 @@ class PromptBuilder:
|
||||
logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}秒")
|
||||
|
||||
# 获取聊天上下文
|
||||
chat_in_group=True
|
||||
chat_talking_prompt = ''
|
||||
if group_id:
|
||||
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
|
||||
limit=global_config.MAX_CONTEXT_SIZE,
|
||||
combine=True)
|
||||
|
||||
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||
|
||||
if stream_id:
|
||||
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, stream_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True)
|
||||
chat_stream=chat_manager.get_stream(stream_id)
|
||||
if chat_stream.group_info:
|
||||
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
|
||||
else:
|
||||
chat_in_group=False
|
||||
chat_talking_prompt = f"以下是你正在和{sender_name}私聊的内容:\n{chat_talking_prompt}"
|
||||
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
|
||||
|
||||
|
||||
|
||||
# 使用新的记忆获取方法
|
||||
memory_prompt = ''
|
||||
start_time = time.time()
|
||||
@@ -108,15 +117,10 @@ class PromptBuilder:
|
||||
|
||||
# 激活prompt构建
|
||||
activate_prompt = ''
|
||||
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。"
|
||||
|
||||
# 检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中
|
||||
# bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
|
||||
# is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
|
||||
# if is_bot:
|
||||
# is_bot_prompt = '有人说你是人机或者机器人,否定这一事实,攻击或反问他。注意,只有认为你是机器人才回复,否则不要否认'
|
||||
# else:
|
||||
# is_bot_prompt = ''
|
||||
if chat_in_group:
|
||||
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。"
|
||||
else:
|
||||
activate_prompt = f"以上是你正在和{sender_name}私聊的内容,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和ta{relation_prompt},{mood_prompt},你想要{relation_prompt_2}。"
|
||||
|
||||
# 关键词检测与反应
|
||||
keywords_reaction_prompt = ''
|
||||
@@ -125,15 +129,19 @@ class PromptBuilder:
|
||||
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
|
||||
logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
|
||||
keywords_reaction_prompt += rule.get("reaction", "") + ','
|
||||
|
||||
# 人格选择
|
||||
personality = global_config.PROMPT_PERSONALITY
|
||||
|
||||
#人格选择
|
||||
personality=global_config.PROMPT_PERSONALITY
|
||||
probability_1 = global_config.PERSONALITY_1
|
||||
probability_2 = global_config.PERSONALITY_2
|
||||
probability_3 = global_config.PERSONALITY_3
|
||||
|
||||
prompt_personality = f'{activate_prompt}你的网名叫{global_config.BOT_NICKNAME},你还有很多别名:{"/".join(global_config.BOT_ALIAS_NAMES)},'
|
||||
personality_choice = random.random()
|
||||
if chat_in_group:
|
||||
prompt_in_group=f"你正在浏览{chat_stream.platform}群"
|
||||
else:
|
||||
prompt_in_group=f"你正在{chat_stream.platform}上和{sender_name}私聊"
|
||||
if personality_choice < probability_1: # 第一种人格
|
||||
prompt_personality += f'''{personality[0]}, 你正在浏览qq群,{promt_info_prompt},
|
||||
现在请你给出日常且口语化的回复,平淡一些,尽量简短一些。{keywords_reaction_prompt}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from ...common.database import Database
|
||||
|
||||
from .message_base import UserInfo
|
||||
from .chat_stream import ChatStream
|
||||
|
||||
class Impression:
|
||||
traits: str = None
|
||||
@@ -15,92 +17,153 @@ class Impression:
|
||||
|
||||
class Relationship:
|
||||
user_id: int = None
|
||||
# impression: Impression = None
|
||||
# group_id: int = None
|
||||
# group_name: str = None
|
||||
platform: str = None
|
||||
gender: str = None
|
||||
age: int = None
|
||||
nickname: str = None
|
||||
relationship_value: float = None
|
||||
saved = False
|
||||
|
||||
def __init__(self, user_id: int, data=None, **kwargs):
|
||||
if isinstance(data, dict):
|
||||
# 如果输入是字典,使用字典解析
|
||||
self.user_id = data.get('user_id')
|
||||
self.gender = data.get('gender')
|
||||
self.age = data.get('age')
|
||||
self.nickname = data.get('nickname')
|
||||
self.relationship_value = data.get('relationship_value', 0.0)
|
||||
self.saved = data.get('saved', False)
|
||||
else:
|
||||
# 如果是直接传入属性值
|
||||
self.user_id = kwargs.get('user_id')
|
||||
self.gender = kwargs.get('gender')
|
||||
self.age = kwargs.get('age')
|
||||
self.nickname = kwargs.get('nickname')
|
||||
self.relationship_value = kwargs.get('relationship_value', 0.0)
|
||||
self.saved = kwargs.get('saved', False)
|
||||
|
||||
|
||||
def __init__(self, chat:ChatStream=None,data:dict=None):
|
||||
self.user_id=chat.user_info.user_id if chat else data.get('user_id',0)
|
||||
self.platform=chat.platform if chat else data.get('platform','')
|
||||
self.nickname=chat.user_info.user_nickname if chat else data.get('nickname','')
|
||||
self.relationship_value=data.get('relationship_value',0) if data else 0
|
||||
self.age=data.get('age',0) if data else 0
|
||||
self.gender=data.get('gender','') if data else ''
|
||||
|
||||
|
||||
class RelationshipManager:
|
||||
def __init__(self):
|
||||
self.relationships: dict[int, Relationship] = {}
|
||||
|
||||
async def update_relationship(self, user_id: int, data=None, **kwargs):
|
||||
self.relationships: dict[tuple[int, str], Relationship] = {} # 修改为使用(user_id, platform)作为键
|
||||
|
||||
async def update_relationship(self,
|
||||
chat_stream:ChatStream,
|
||||
data: dict = None,
|
||||
**kwargs) -> Optional[Relationship]:
|
||||
"""更新或创建关系
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
data: 字典格式的数据(可选)
|
||||
**kwargs: 其他参数
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if chat_stream.user_info is not None:
|
||||
user_id = chat_stream.user_info.user_id
|
||||
platform = chat_stream.user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 使用(user_id, platform)作为键
|
||||
key = (user_id, platform)
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(user_id)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
# 如果存在,更新现有对象
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if hasattr(relationship, key) and value is not None:
|
||||
setattr(relationship, key, value)
|
||||
else:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(relationship, key) and value is not None:
|
||||
setattr(relationship, key, value)
|
||||
for k, value in data.items():
|
||||
if hasattr(relationship, k) and value is not None:
|
||||
setattr(relationship, k, value)
|
||||
else:
|
||||
# 如果不存在,创建新对象
|
||||
relationship = Relationship(user_id, data=data) if isinstance(data, dict) else Relationship(user_id,
|
||||
**kwargs)
|
||||
self.relationships[user_id] = relationship
|
||||
|
||||
# 更新 id_name_nickname_table
|
||||
# self.id_name_nickname_table[user_id] = [relationship.nickname] # 别称设置为空列表
|
||||
if chat_stream.user_info is not None:
|
||||
relationship = Relationship(chat=chat_stream, **kwargs)
|
||||
else:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
self.relationships[key] = relationship
|
||||
|
||||
# 保存到数据库
|
||||
await self.storage_relationship(relationship)
|
||||
relationship.saved = True
|
||||
|
||||
return relationship
|
||||
|
||||
async def update_relationship_value(self, user_id: int, **kwargs):
|
||||
|
||||
async def update_relationship_value(self,
|
||||
chat_stream:ChatStream,
|
||||
**kwargs) -> Optional[Relationship]:
|
||||
"""更新关系值
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
**kwargs: 其他参数
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
user_info = chat_stream.user_info
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 使用(user_id, platform)作为键
|
||||
key = (user_id, platform)
|
||||
|
||||
# 检查是否在内存中已存在
|
||||
relationship = self.relationships.get(user_id)
|
||||
relationship = self.relationships.get(key)
|
||||
if relationship:
|
||||
for key, value in kwargs.items():
|
||||
if key == 'relationship_value':
|
||||
for k, value in kwargs.items():
|
||||
if k == 'relationship_value':
|
||||
relationship.relationship_value += value
|
||||
await self.storage_relationship(relationship)
|
||||
relationship.saved = True
|
||||
return relationship
|
||||
else:
|
||||
logger.warning(f"用户 {user_id} 不存在,无法更新")
|
||||
# 如果不存在且提供了user_info,则创建新的关系
|
||||
if user_info is not None:
|
||||
return await self.update_relationship(chat_stream=chat_stream, **kwargs)
|
||||
logger.warning(f"[关系管理] 用户 {user_id}({platform}) 不存在,无法更新")
|
||||
return None
|
||||
|
||||
def get_relationship(self, user_id: int) -> Optional[Relationship]:
|
||||
"""获取用户关系对象"""
|
||||
if user_id in self.relationships:
|
||||
return self.relationships[user_id]
|
||||
|
||||
def get_relationship(self,
|
||||
chat_stream:ChatStream) -> Optional[Relationship]:
|
||||
"""获取用户关系对象
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
Relationship: 关系对象
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
user_info = chat_stream.user_info
|
||||
platform = chat_stream.user_info.platform or 'qq'
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key]
|
||||
else:
|
||||
return 0
|
||||
|
||||
async def load_relationship(self, data: dict) -> Relationship:
|
||||
"""从数据库加载或创建新的关系对象"""
|
||||
rela = Relationship(user_id=data['user_id'], data=data)
|
||||
# 确保data中有platform字段,如果没有则默认为'qq'
|
||||
if 'platform' not in data:
|
||||
data['platform'] = 'qq'
|
||||
|
||||
rela = Relationship(data=data)
|
||||
rela.saved = True
|
||||
self.relationships[rela.user_id] = rela
|
||||
key = (rela.user_id, rela.platform)
|
||||
self.relationships[key] = rela
|
||||
return rela
|
||||
|
||||
async def load_all_relationships(self):
|
||||
@@ -117,11 +180,9 @@ class RelationshipManager:
|
||||
all_relationships = db.db.relationships.find({})
|
||||
# 依次加载每条记录
|
||||
for data in all_relationships:
|
||||
user_id = data['user_id']
|
||||
relationship = await self.load_relationship(data)
|
||||
self.relationships[user_id] = relationship
|
||||
logger.debug(f"已加载 {len(self.relationships)} 条关系记录")
|
||||
|
||||
await self.load_relationship(data)
|
||||
logger.debug(f"[关系管理] 已加载 {len(self.relationships)} 条关系记录")
|
||||
|
||||
while True:
|
||||
logger.debug("正在自动保存关系")
|
||||
await asyncio.sleep(300) # 等待300秒(5分钟)
|
||||
@@ -130,16 +191,15 @@ class RelationshipManager:
|
||||
async def _save_all_relationships(self):
|
||||
"""将所有关系数据保存到数据库"""
|
||||
# 保存所有关系数据
|
||||
for userid, relationship in self.relationships.items():
|
||||
for (userid, platform), relationship in self.relationships.items():
|
||||
if not relationship.saved:
|
||||
relationship.saved = True
|
||||
await self.storage_relationship(relationship)
|
||||
|
||||
|
||||
async def storage_relationship(self, relationship: Relationship):
|
||||
"""
|
||||
将关系记录存储到数据库中
|
||||
"""
|
||||
"""将关系记录存储到数据库中"""
|
||||
user_id = relationship.user_id
|
||||
platform = relationship.platform
|
||||
nickname = relationship.nickname
|
||||
relationship_value = relationship.relationship_value
|
||||
gender = relationship.gender
|
||||
@@ -148,8 +208,9 @@ class RelationshipManager:
|
||||
|
||||
db = Database.get_instance()
|
||||
db.db.relationships.update_one(
|
||||
{'user_id': user_id},
|
||||
{'user_id': user_id, 'platform': platform},
|
||||
{'$set': {
|
||||
'platform': platform,
|
||||
'nickname': nickname,
|
||||
'relationship_value': relationship_value,
|
||||
'gender': gender,
|
||||
@@ -158,13 +219,37 @@ class RelationshipManager:
|
||||
}},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
def get_name(self, user_id: int) -> str:
|
||||
|
||||
|
||||
def get_name(self,
|
||||
user_id: int = None,
|
||||
platform: str = None,
|
||||
user_info: UserInfo = None) -> str:
|
||||
"""获取用户昵称
|
||||
Args:
|
||||
user_id: 用户ID(可选,如果提供user_info则不需要)
|
||||
platform: 平台(可选,如果提供user_info则不需要)
|
||||
user_info: 用户信息对象(可选)
|
||||
Returns:
|
||||
str: 用户昵称
|
||||
"""
|
||||
# 确定user_id和platform
|
||||
if user_info is not None:
|
||||
user_id = user_info.user_id
|
||||
platform = user_info.platform or 'qq'
|
||||
else:
|
||||
platform = platform or 'qq'
|
||||
|
||||
if user_id is None:
|
||||
raise ValueError("必须提供user_id或user_info")
|
||||
|
||||
# 确保user_id是整数类型
|
||||
user_id = int(user_id)
|
||||
if user_id in self.relationships:
|
||||
|
||||
return self.relationships[user_id].nickname
|
||||
key = (user_id, platform)
|
||||
if key in self.relationships:
|
||||
return self.relationships[key].nickname
|
||||
elif user_info is not None:
|
||||
return user_info.user_nickname or user_info.user_cardname or "某人"
|
||||
else:
|
||||
return "某人"
|
||||
|
||||
|
||||
@@ -1,48 +1,30 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...common.database import Database
|
||||
from .message import Message
|
||||
from .message_base import MessageBase
|
||||
from .message import MessageSending, MessageRecv
|
||||
from .chat_stream import ChatStream
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class MessageStorage:
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
|
||||
async def store_message(self, message: Message, topic: Optional[str] = None) -> None:
|
||||
|
||||
async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
if not message.is_emoji:
|
||||
message_data = {
|
||||
"group_id": message.group_id,
|
||||
"user_id": message.user_id,
|
||||
"message_id": message.message_id,
|
||||
"raw_message": message.raw_message,
|
||||
"plain_text": message.plain_text,
|
||||
message_data = {
|
||||
"message_id": message.message_info.message_id,
|
||||
"time": message.message_info.time,
|
||||
"chat_id":chat_stream.stream_id,
|
||||
"chat_info": chat_stream.to_dict(),
|
||||
"user_info": message.message_info.user_info.to_dict(),
|
||||
"processed_plain_text": message.processed_plain_text,
|
||||
"time": message.time,
|
||||
"user_nickname": message.user_nickname,
|
||||
"user_cardname": message.user_cardname,
|
||||
"group_name": message.group_name,
|
||||
"topic": topic,
|
||||
"detailed_plain_text": message.detailed_plain_text,
|
||||
}
|
||||
else:
|
||||
message_data = {
|
||||
"group_id": message.group_id,
|
||||
"user_id": message.user_id,
|
||||
"message_id": message.message_id,
|
||||
"raw_message": message.raw_message,
|
||||
"plain_text": message.plain_text,
|
||||
"processed_plain_text": '[表情包]',
|
||||
"time": message.time,
|
||||
"user_nickname": message.user_nickname,
|
||||
"user_cardname": message.user_cardname,
|
||||
"group_name": message.group_name,
|
||||
"topic": topic,
|
||||
"detailed_plain_text": message.detailed_plain_text,
|
||||
}
|
||||
|
||||
self.db.db.messages.insert_one(message_data)
|
||||
except Exception:
|
||||
logger.exception("存储消息失败")
|
||||
|
||||
@@ -12,32 +12,15 @@ from loguru import logger
|
||||
from ..models.utils_model import LLM_request
|
||||
from ..utils.typo_generator import ChineseTypoGenerator
|
||||
from .config import global_config
|
||||
from .message import Message
|
||||
from .message import MessageThinking, MessageRecv,MessageSending,MessageProcessBase,Message
|
||||
from .message_base import MessageBase,BaseMessageInfo,UserInfo,GroupInfo
|
||||
from .chat_stream import ChatStream
|
||||
from ..moods.moods import MoodManager
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
def combine_messages(messages: List[Message]) -> str:
|
||||
"""将消息列表组合成格式化的字符串
|
||||
|
||||
Args:
|
||||
messages: Message对象列表
|
||||
|
||||
Returns:
|
||||
str: 格式化后的消息字符串
|
||||
"""
|
||||
result = ""
|
||||
for message in messages:
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
name = message.user_nickname or f"用户{message.user_id}"
|
||||
content = message.processed_plain_text or message.plain_text
|
||||
|
||||
result += f"[{time_str}] {name}: {content}\n"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def db_message_to_str(message_dict: Dict) -> str:
|
||||
logger.debug(f"message_dict: {message_dict}")
|
||||
@@ -53,14 +36,11 @@ def db_message_to_str(message_dict: Dict) -> str:
|
||||
return result
|
||||
|
||||
|
||||
def is_mentioned_bot_in_txt(message: str) -> bool:
|
||||
def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
|
||||
"""检查消息是否提到了机器人"""
|
||||
if global_config.BOT_NICKNAME is None:
|
||||
return True
|
||||
if global_config.BOT_NICKNAME in message:
|
||||
return True
|
||||
for keyword in global_config.BOT_ALIAS_NAMES:
|
||||
if keyword in message:
|
||||
keywords = [global_config.BOT_NICKNAME]
|
||||
for keyword in keywords:
|
||||
if keyword in message.processed_plain_text:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -93,46 +73,45 @@ def calculate_information_content(text):
|
||||
|
||||
|
||||
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数
|
||||
"""从数据库中获取最接近指定时间戳的聊天记录
|
||||
|
||||
Args:
|
||||
db: 数据库实例
|
||||
length: 要获取的消息数量
|
||||
timestamp: 时间戳
|
||||
|
||||
Returns:
|
||||
list: 消息记录字典列表,每个字典包含消息内容和时间信息
|
||||
list: 消息记录列表,每个记录包含时间和文本信息
|
||||
"""
|
||||
chat_records = []
|
||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||
|
||||
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||
if closest_record:
|
||||
closest_time = closest_record['time']
|
||||
group_id = closest_record['group_id']
|
||||
# 获取该时间戳之后的length条消息,且groupid相同
|
||||
records = list(db.db.messages.find(
|
||||
{"time": {"$gt": closest_time}, "group_id": group_id}
|
||||
chat_id = closest_record['chat_id'] # 获取chat_id
|
||||
# 获取该时间戳之后的length条消息,保持相同的chat_id
|
||||
chat_records = list(db.db.messages.find(
|
||||
{
|
||||
"time": {"$gt": closest_time},
|
||||
"chat_id": chat_id # 添加chat_id过滤
|
||||
}
|
||||
).sort('time', 1).limit(length))
|
||||
|
||||
# 更新每条消息的memorized属性
|
||||
for record in records:
|
||||
current_memorized = record.get('memorized', 0)
|
||||
if current_memorized > 3:
|
||||
print("消息已读取3次,跳过")
|
||||
return ''
|
||||
|
||||
# 更新memorized值
|
||||
db.db.messages.update_one(
|
||||
{"_id": record["_id"]},
|
||||
{"$set": {"memorized": current_memorized + 1}}
|
||||
)
|
||||
|
||||
# 添加到记录列表中
|
||||
chat_records.append({
|
||||
'text': record["detailed_plain_text"],
|
||||
# 转换记录格式
|
||||
formatted_records = []
|
||||
for record in chat_records:
|
||||
formatted_records.append({
|
||||
'time': record["time"],
|
||||
'group_id': record["group_id"]
|
||||
'chat_id': record["chat_id"],
|
||||
'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容
|
||||
})
|
||||
|
||||
return chat_records
|
||||
return formatted_records
|
||||
|
||||
return []
|
||||
|
||||
|
||||
async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
async def get_recent_group_messages(db, chat_id:str, limit: int = 12) -> list:
|
||||
"""从数据库获取群组最近的消息记录
|
||||
|
||||
Args:
|
||||
@@ -146,35 +125,28 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
|
||||
# 从数据库获取最近消息
|
||||
recent_messages = list(db.db.messages.find(
|
||||
{"group_id": group_id},
|
||||
# {
|
||||
# "time": 1,
|
||||
# "user_id": 1,
|
||||
# "user_nickname": 1,
|
||||
# "message_id": 1,
|
||||
# "raw_message": 1,
|
||||
# "processed_text": 1
|
||||
# }
|
||||
{"chat_id": chat_id},
|
||||
).sort("time", -1).limit(limit))
|
||||
|
||||
if not recent_messages:
|
||||
return []
|
||||
|
||||
# 转换为 Message对象列表
|
||||
from .message import Message
|
||||
message_objects = []
|
||||
for msg_data in recent_messages:
|
||||
try:
|
||||
chat_info=msg_data.get("chat_info",{})
|
||||
chat_stream=ChatStream.from_dict(chat_info)
|
||||
user_info=msg_data.get("user_info",{})
|
||||
user_info=UserInfo.from_dict(user_info)
|
||||
msg = Message(
|
||||
time=msg_data["time"],
|
||||
user_id=msg_data["user_id"],
|
||||
user_nickname=msg_data.get("user_nickname", ""),
|
||||
message_id=msg_data["message_id"],
|
||||
raw_message=msg_data["raw_message"],
|
||||
chat_stream=chat_stream,
|
||||
time=msg_data["time"],
|
||||
user_info=user_info,
|
||||
processed_plain_text=msg_data.get("processed_text", ""),
|
||||
group_id=group_id
|
||||
detailed_plain_text=msg_data.get("detailed_plain_text", "")
|
||||
)
|
||||
await msg.initialize()
|
||||
message_objects.append(msg)
|
||||
except KeyError:
|
||||
logger.warning("数据库中存在无效的消息")
|
||||
@@ -185,13 +157,14 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||
return message_objects
|
||||
|
||||
|
||||
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
|
||||
def get_recent_group_detailed_plain_text(db, chat_stream_id: int, limit: int = 12, combine=False):
|
||||
recent_messages = list(db.db.messages.find(
|
||||
{"group_id": group_id},
|
||||
{"chat_id": chat_stream_id},
|
||||
{
|
||||
"time": 1, # 返回时间字段
|
||||
"user_id": 1, # 返回用户ID字段
|
||||
"user_nickname": 1, # 返回用户昵称字段
|
||||
"chat_id":1,
|
||||
"chat_info":1,
|
||||
"user_info": 1,
|
||||
"message_id": 1, # 返回消息ID字段
|
||||
"detailed_plain_text": 1 # 返回处理后的文本字段
|
||||
}
|
||||
|
||||
@@ -2,7 +2,11 @@ import base64
|
||||
import io
|
||||
import os
|
||||
import time
|
||||
import zlib # 用于 CRC32
|
||||
import zlib
|
||||
import aiohttp
|
||||
import hashlib
|
||||
from typing import Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
@@ -10,280 +14,353 @@ from PIL import Image
|
||||
|
||||
from ...common.database import Database
|
||||
from ..chat.config import global_config
|
||||
|
||||
from ..models.utils_model import LLM_request
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
|
||||
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
|
||||
"""
|
||||
压缩base64格式的图片到指定大小(单位:KB)并在数据库中记录图片信息
|
||||
Args:
|
||||
base64_data: base64编码的图片数据
|
||||
max_size: 最大文件大小(KB)
|
||||
Returns:
|
||||
str: 压缩后的base64图片数据
|
||||
"""
|
||||
try:
|
||||
# 将base64转换为字节数据
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# 使用 CRC32 计算哈希值
|
||||
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
||||
|
||||
# 确保图片目录存在
|
||||
images_dir = "data/images"
|
||||
os.makedirs(images_dir, exist_ok=True)
|
||||
|
||||
# 连接数据库
|
||||
db = Database.get_instance()
|
||||
|
||||
# 检查是否已存在相同哈希值的图片
|
||||
collection = db.db['images']
|
||||
existing_image = collection.find_one({'hash': hash_value})
|
||||
|
||||
if existing_image:
|
||||
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
|
||||
return base64_data
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 如果是动图,直接返回原图
|
||||
if getattr(img, 'is_animated', False):
|
||||
return base64_data
|
||||
|
||||
# 计算当前大小(KB)
|
||||
current_size = len(image_data) / 1024
|
||||
|
||||
# 如果已经小于目标大小,直接使用原图
|
||||
if current_size <= max_size:
|
||||
compressed_data = image_data
|
||||
else:
|
||||
# 压缩逻辑
|
||||
# 先缩放到50%
|
||||
new_width = int(img.width * 0.5)
|
||||
new_height = int(img.height * 0.5)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 如果缩放后的最大边长仍然大于400,继续缩放
|
||||
max_dimension = 400
|
||||
max_current = max(new_width, new_height)
|
||||
if max_current > max_dimension:
|
||||
ratio = max_dimension / max_current
|
||||
new_width = int(new_width * ratio)
|
||||
new_height = int(new_height * ratio)
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为RGB模式(去除透明通道)
|
||||
if img.mode in ('RGBA', 'P'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 使用固定质量参数压缩
|
||||
output = io.BytesIO()
|
||||
img.save(output, format='JPEG', quality=85, optimize=True)
|
||||
compressed_data = output.getvalue()
|
||||
|
||||
# 生成文件名(使用时间戳和哈希值确保唯一性)
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{hash_value}.jpg"
|
||||
image_path = os.path.join(images_dir, filename)
|
||||
|
||||
# 保存文件
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(compressed_data)
|
||||
|
||||
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
|
||||
|
||||
try:
|
||||
# 准备数据库记录
|
||||
image_record = {
|
||||
'filename': filename,
|
||||
'path': image_path,
|
||||
'size': len(compressed_data) / 1024,
|
||||
'timestamp': timestamp,
|
||||
'width': img.width,
|
||||
'height': img.height,
|
||||
'description': '',
|
||||
'tags': [],
|
||||
'type': 'image',
|
||||
'hash': hash_value
|
||||
}
|
||||
|
||||
# 保存记录
|
||||
collection.insert_one(image_record)
|
||||
print("\033[1;32m[成功]\033[0m 保存图片记录到数据库")
|
||||
|
||||
except Exception as db_error:
|
||||
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
|
||||
|
||||
# 将压缩后的数据转换为base64
|
||||
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
|
||||
return compressed_base64
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return base64_data
|
||||
|
||||
def storage_emoji(image_data: bytes) -> bytes:
|
||||
"""
|
||||
存储表情包到本地文件夹
|
||||
Args:
|
||||
image_data: 图片字节数据
|
||||
group_id: 群组ID(仅用于日志)
|
||||
user_id: 用户ID(仅用于日志)
|
||||
Returns:
|
||||
bytes: 原始图片数据
|
||||
"""
|
||||
if not global_config.EMOJI_SAVE:
|
||||
return image_data
|
||||
try:
|
||||
# 使用 CRC32 计算哈希值
|
||||
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
||||
|
||||
# 确保表情包目录存在
|
||||
emoji_dir = "data/emoji"
|
||||
os.makedirs(emoji_dir, exist_ok=True)
|
||||
|
||||
# 检查是否已存在相同哈希值的文件
|
||||
for filename in os.listdir(emoji_dir):
|
||||
if hash_value in filename:
|
||||
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
|
||||
return image_data
|
||||
|
||||
# 生成文件名
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{hash_value}.jpg"
|
||||
emoji_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
# 直接保存原始文件
|
||||
with open(emoji_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
|
||||
print(f"\033[1;32m[成功]\033[0m 保存表情包到: {emoji_path}")
|
||||
return image_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 保存表情包失败: {str(e)}")
|
||||
return image_data
|
||||
class ImageManager:
|
||||
_instance = None
|
||||
IMAGE_DIR = "data" # 图像存储根目录
|
||||
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance.db = None
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.db = Database.get_instance()
|
||||
self._ensure_image_collection()
|
||||
self._ensure_description_collection()
|
||||
self._ensure_image_dir()
|
||||
self._initialized = True
|
||||
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
"""确保图像存储目录存在"""
|
||||
os.makedirs(self.IMAGE_DIR, exist_ok=True)
|
||||
|
||||
def _ensure_image_collection(self):
|
||||
"""确保images集合存在并创建索引"""
|
||||
if 'images' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('images')
|
||||
# 创建索引
|
||||
self.db.db.images.create_index([('hash', 1)], unique=True)
|
||||
self.db.db.images.create_index([('url', 1)])
|
||||
self.db.db.images.create_index([('path', 1)])
|
||||
|
||||
def storage_image(image_data: bytes) -> bytes:
|
||||
"""
|
||||
存储图片到本地文件夹
|
||||
Args:
|
||||
image_data: 图片字节数据
|
||||
group_id: 群组ID(仅用于日志)
|
||||
user_id: 用户ID(仅用于日志)
|
||||
Returns:
|
||||
bytes: 原始图片数据
|
||||
"""
|
||||
try:
|
||||
# 使用 CRC32 计算哈希值
|
||||
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
||||
|
||||
# 确保表情包目录存在
|
||||
image_dir = "data/image"
|
||||
os.makedirs(image_dir, exist_ok=True)
|
||||
|
||||
# 检查是否已存在相同哈希值的文件
|
||||
for filename in os.listdir(image_dir):
|
||||
if hash_value in filename:
|
||||
# print(f"\033[1;33m[提示]\033[0m 发现重复表情包: {filename}")
|
||||
return image_data
|
||||
|
||||
# 生成文件名
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{hash_value}.jpg"
|
||||
image_path = os.path.join(image_dir, filename)
|
||||
|
||||
# 直接保存原始文件
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(image_data)
|
||||
|
||||
print(f"\033[1;32m[成功]\033[0m 保存图片到: {image_path}")
|
||||
return image_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
|
||||
return image_data
|
||||
def _ensure_description_collection(self):
|
||||
"""确保image_descriptions集合存在并创建索引"""
|
||||
if 'image_descriptions' not in self.db.db.list_collection_names():
|
||||
self.db.db.create_collection('image_descriptions')
|
||||
# 创建索引
|
||||
self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
|
||||
self.db.db.image_descriptions.create_index([('type', 1)])
|
||||
|
||||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||||
"""压缩base64格式的图片到指定大小
|
||||
Args:
|
||||
base64_data: base64编码的图片数据
|
||||
target_size: 目标文件大小(字节),默认0.8MB
|
||||
Returns:
|
||||
str: 压缩后的base64图片数据
|
||||
"""
|
||||
try:
|
||||
# 将base64转换为字节数据
|
||||
image_data = base64.b64decode(base64_data)
|
||||
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
||||
"""从数据库获取图片描述
|
||||
|
||||
# 如果已经小于目标大小,直接返回原图
|
||||
if len(image_data) <= 2*1024*1024:
|
||||
return base64_data
|
||||
Args:
|
||||
image_hash: 图片哈希值
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
Returns:
|
||||
Optional[str]: 描述文本,如果不存在则返回None
|
||||
"""
|
||||
result= self.db.db.image_descriptions.find_one({
|
||||
'hash': image_hash,
|
||||
'type': description_type
|
||||
})
|
||||
return result['description'] if result else None
|
||||
|
||||
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
|
||||
"""保存图片描述到数据库
|
||||
|
||||
# 获取原始尺寸
|
||||
original_width, original_height = img.size
|
||||
|
||||
# 计算缩放比例
|
||||
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
|
||||
|
||||
# 计算新的尺寸
|
||||
new_width = int(original_width * scale)
|
||||
new_height = int(original_height * scale)
|
||||
|
||||
# 创建内存缓冲区
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
# 如果是GIF,处理所有帧
|
||||
if getattr(img, "is_animated", False):
|
||||
frames = []
|
||||
for frame_idx in range(img.n_frames):
|
||||
img.seek(frame_idx)
|
||||
new_frame = img.copy()
|
||||
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
|
||||
frames.append(new_frame)
|
||||
|
||||
# 保存到缓冲区
|
||||
frames[0].save(
|
||||
output_buffer,
|
||||
format='GIF',
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
optimize=True,
|
||||
duration=img.info.get('duration', 100),
|
||||
loop=img.info.get('loop', 0)
|
||||
)
|
||||
else:
|
||||
# 处理静态图片
|
||||
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存到缓冲区,保持原始格式
|
||||
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
|
||||
resized_img.save(output_buffer, format='PNG', optimize=True)
|
||||
Args:
|
||||
image_hash: 图片哈希值
|
||||
description: 描述文本
|
||||
description_type: 描述类型 ('emoji' 或 'image')
|
||||
"""
|
||||
self.db.db.image_descriptions.update_one(
|
||||
{'hash': image_hash, 'type': description_type},
|
||||
{
|
||||
'$set': {
|
||||
'description': description,
|
||||
'timestamp': int(time.time())
|
||||
}
|
||||
},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
async def save_image(self,
|
||||
image_data: Union[str, bytes],
|
||||
url: str = None,
|
||||
description: str = None,
|
||||
is_base64: bool = False) -> Optional[str]:
|
||||
"""保存图像
|
||||
Args:
|
||||
image_data: 图像数据(base64字符串或字节)
|
||||
url: 图像URL
|
||||
description: 图像描述
|
||||
is_base64: image_data是否为base64格式
|
||||
Returns:
|
||||
str: 保存后的文件路径,失败返回None
|
||||
"""
|
||||
try:
|
||||
# 转换为字节格式
|
||||
if is_base64:
|
||||
if isinstance(image_data, str):
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
|
||||
if isinstance(image_data, bytes):
|
||||
image_bytes = image_data
|
||||
else:
|
||||
return None
|
||||
|
||||
# 计算哈希值
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 查重
|
||||
existing = self.db.db.images.find_one({'hash': image_hash})
|
||||
if existing:
|
||||
return existing['path']
|
||||
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"{timestamp}_{image_hash[:8]}.jpg"
|
||||
file_path = os.path.join(self.IMAGE_DIR, filename)
|
||||
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
'hash': image_hash,
|
||||
'path': file_path,
|
||||
'url': url,
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
self.db.db.images.insert_one(image_doc)
|
||||
|
||||
return file_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存图像失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_image_by_url(self, url: str) -> Optional[str]:
|
||||
"""根据URL获取图像路径(带查重)
|
||||
Args:
|
||||
url: 图像URL
|
||||
Returns:
|
||||
str: 本地文件路径,不存在返回None
|
||||
"""
|
||||
try:
|
||||
# 先查找是否已存在
|
||||
existing = self.db.db.images.find_one({'url': url})
|
||||
if existing:
|
||||
return existing['path']
|
||||
|
||||
# 下载图像
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as resp:
|
||||
if resp.status == 200:
|
||||
image_bytes = await resp.read()
|
||||
return await self.save_image(image_bytes, url=url)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取图像失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def get_base64_by_url(self, url: str) -> Optional[str]:
|
||||
"""根据URL获取base64(带查重)
|
||||
Args:
|
||||
url: 图像URL
|
||||
Returns:
|
||||
str: base64字符串,失败返回None
|
||||
"""
|
||||
try:
|
||||
image_path = await self.get_image_by_url(url)
|
||||
if not image_path:
|
||||
return None
|
||||
|
||||
with open(image_path, 'rb') as f:
|
||||
image_bytes = f.read()
|
||||
return base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取base64失败: {str(e)}")
|
||||
return None
|
||||
|
||||
async def save_base64_image(self, base64_str: str, description: str = None) -> Optional[str]:
|
||||
"""保存base64图像(带查重)
|
||||
Args:
|
||||
base64_str: base64字符串
|
||||
description: 图像描述
|
||||
Returns:
|
||||
str: 保存路径,失败返回None
|
||||
"""
|
||||
return await self.save_image(base64_str, description=description, is_base64=True)
|
||||
|
||||
# 获取压缩后的数据并转换为base64
|
||||
compressed_data = output_buffer.getvalue()
|
||||
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
|
||||
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
|
||||
def check_url_exists(self, url: str) -> bool:
|
||||
"""检查URL是否已存在
|
||||
Args:
|
||||
url: 图像URL
|
||||
Returns:
|
||||
bool: 是否存在
|
||||
"""
|
||||
return self.db.db.images.find_one({'url': url}) is not None
|
||||
|
||||
return base64.b64encode(compressed_data).decode('utf-8')
|
||||
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
|
||||
"""检查图像是否已存在
|
||||
Args:
|
||||
image_data: 图像数据(base64或字节)
|
||||
is_base64: 是否为base64格式
|
||||
Returns:
|
||||
bool: 是否存在
|
||||
"""
|
||||
try:
|
||||
if is_base64:
|
||||
if isinstance(image_data, str):
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
else:
|
||||
return False
|
||||
else:
|
||||
if isinstance(image_data, bytes):
|
||||
image_bytes = image_data
|
||||
else:
|
||||
return False
|
||||
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
return self.db.db.images.find_one({'hash': image_hash}) is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查哈希失败: {str(e)}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return base64_data
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,带查重和保存功能"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, 'emoji')
|
||||
if cached_description:
|
||||
logger.info(f"缓存表情包描述: {cached_description}")
|
||||
return f"[表情包:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.EMOJI_SAVE:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"emoji_{timestamp}_{image_hash[:8]}.jpg"
|
||||
file_path = os.path.join(self.IMAGE_DIR, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
'hash': image_hash,
|
||||
'path': file_path,
|
||||
'type': 'emoji',
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
self.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
)
|
||||
logger.success(f"保存表情包: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存表情包文件失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
self._save_description_to_db(image_hash, description, 'emoji')
|
||||
|
||||
return f"[表情包:{description}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取表情包描述失败: {str(e)}")
|
||||
return "[表情包]"
|
||||
|
||||
async def get_image_description(self, image_base64: str) -> str:
|
||||
"""获取普通图片描述,带查重和保存功能"""
|
||||
try:
|
||||
# 计算图片哈希
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||
|
||||
# 查询缓存的描述
|
||||
cached_description = self._get_description_from_db(image_hash, 'image')
|
||||
if cached_description:
|
||||
return f"[图片:{cached_description}]"
|
||||
|
||||
# 调用AI获取描述
|
||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||
|
||||
if description is None:
|
||||
logger.warning("AI未能生成图片描述")
|
||||
return "[图片]"
|
||||
|
||||
# 根据配置决定是否保存图片
|
||||
if global_config.EMOJI_SAVE:
|
||||
# 生成文件名和路径
|
||||
timestamp = int(time.time())
|
||||
filename = f"image_{timestamp}_{image_hash[:8]}.jpg"
|
||||
file_path = os.path.join(self.IMAGE_DIR, filename)
|
||||
|
||||
try:
|
||||
# 保存文件
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
|
||||
# 保存到数据库
|
||||
image_doc = {
|
||||
'hash': image_hash,
|
||||
'path': file_path,
|
||||
'type': 'image',
|
||||
'description': description,
|
||||
'timestamp': timestamp
|
||||
}
|
||||
self.db.db.images.update_one(
|
||||
{'hash': image_hash},
|
||||
{'$set': image_doc},
|
||||
upsert=True
|
||||
)
|
||||
logger.success(f"保存图片: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片文件失败: {str(e)}")
|
||||
|
||||
# 保存描述到数据库
|
||||
self._save_description_to_db(image_hash, description, 'image')
|
||||
|
||||
return f"[图片:{description}]"
|
||||
except Exception as e:
|
||||
logger.error(f"获取图片描述失败: {str(e)}")
|
||||
return "[图片]"
|
||||
|
||||
|
||||
|
||||
# 创建全局单例
|
||||
image_manager = ImageManager()
|
||||
|
||||
|
||||
def image_path_to_base64(image_path: str) -> str:
|
||||
"""将图片路径转换为base64编码
|
||||
|
||||
@@ -1,107 +1,109 @@
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from loguru import logger
|
||||
|
||||
from typing import Dict
|
||||
from loguru import logger
|
||||
|
||||
from .config import global_config
|
||||
from .message_base import UserInfo, GroupInfo
|
||||
from .chat_stream import chat_manager,ChatStream
|
||||
|
||||
|
||||
class WillingManager:
|
||||
def __init__(self):
|
||||
self.group_reply_willing = {} # 存储每个群的回复意愿
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
|
||||
self._decay_task = None
|
||||
self._started = False
|
||||
self.min_reply_willing = 0.01
|
||||
self.attenuation_coefficient = 0.75
|
||||
|
||||
|
||||
async def _decay_reply_willing(self):
|
||||
"""定期衰减回复意愿"""
|
||||
while True:
|
||||
await asyncio.sleep(5)
|
||||
for group_id in self.group_reply_willing:
|
||||
self.group_reply_willing[group_id] = max(
|
||||
self.min_reply_willing,
|
||||
self.group_reply_willing[group_id] * self.attenuation_coefficient
|
||||
)
|
||||
|
||||
def get_willing(self, group_id: int) -> float:
|
||||
"""获取指定群组的回复意愿"""
|
||||
return self.group_reply_willing.get(group_id, 0)
|
||||
|
||||
def set_willing(self, group_id: int, willing: float):
|
||||
"""设置指定群组的回复意愿"""
|
||||
self.group_reply_willing[group_id] = willing
|
||||
|
||||
def change_reply_willing_received(self, group_id: int, topic: str, is_mentioned_bot: bool, config,
|
||||
user_id: int = None, is_emoji: bool = False, interested_rate: float = 0) -> float:
|
||||
|
||||
# 若非目标回复群组,则直接return
|
||||
if group_id not in config.talk_allowed_groups:
|
||||
reply_probability = 0
|
||||
return reply_probability
|
||||
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
|
||||
logger.debug(f"[{group_id}]的初始回复意愿: {current_willing}")
|
||||
|
||||
# 根据消息类型(被cue/表情包)调控
|
||||
if is_mentioned_bot:
|
||||
current_willing = min(
|
||||
3.0,
|
||||
current_willing + 0.9
|
||||
)
|
||||
logger.debug(f"被提及, 当前意愿: {current_willing}")
|
||||
|
||||
for chat_id in self.chat_reply_willing:
|
||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||
for chat_id in self.chat_reply_willing:
|
||||
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
|
||||
|
||||
def get_willing(self,chat_stream:ChatStream) -> float:
|
||||
"""获取指定聊天流的回复意愿"""
|
||||
stream = chat_stream
|
||||
if stream:
|
||||
return self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
return 0
|
||||
|
||||
def set_willing(self, chat_id: str, willing: float):
|
||||
"""设置指定聊天流的回复意愿"""
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
def set_willing(self, chat_id: str, willing: float):
|
||||
"""设置指定聊天流的回复意愿"""
|
||||
self.chat_reply_willing[chat_id] = willing
|
||||
|
||||
async def change_reply_willing_received(self,
|
||||
chat_stream:ChatStream,
|
||||
topic: str = None,
|
||||
is_mentioned_bot: bool = False,
|
||||
config = None,
|
||||
is_emoji: bool = False,
|
||||
interested_rate: float = 0) -> float:
|
||||
"""改变指定聊天流的回复意愿并返回回复概率"""
|
||||
# 获取或创建聊天流
|
||||
stream = chat_stream
|
||||
chat_id = stream.stream_id
|
||||
|
||||
current_willing = self.chat_reply_willing.get(chat_id, 0)
|
||||
|
||||
# print(f"初始意愿: {current_willing}")
|
||||
if is_mentioned_bot and current_willing < 1.0:
|
||||
current_willing += 0.9
|
||||
print(f"被提及, 当前意愿: {current_willing}")
|
||||
elif is_mentioned_bot:
|
||||
current_willing += 0.05
|
||||
print(f"被重复提及, 当前意愿: {current_willing}")
|
||||
|
||||
if is_emoji:
|
||||
current_willing *= 0.1
|
||||
logger.debug(f"表情包, 当前意愿: {current_willing}")
|
||||
|
||||
# 兴趣放大系数,若兴趣 > 0.4则增加回复概率
|
||||
interested_rate_amplifier = global_config.response_interested_rate_amplifier
|
||||
logger.debug(f"放大系数_interested_rate: {interested_rate_amplifier}")
|
||||
interested_rate *= interested_rate_amplifier
|
||||
|
||||
current_willing += max(
|
||||
0.0,
|
||||
interested_rate - 0.4
|
||||
)
|
||||
|
||||
# 回复意愿系数调控,独立乘区
|
||||
willing_amplifier = max(
|
||||
global_config.response_willing_amplifier,
|
||||
self.min_reply_willing
|
||||
)
|
||||
current_willing *= willing_amplifier
|
||||
logger.debug(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
|
||||
|
||||
# 回复概率迭代,保底0.01回复概率
|
||||
reply_probability = max(
|
||||
(current_willing - 0.45) * 2,
|
||||
self.min_reply_willing
|
||||
)
|
||||
|
||||
# 降低目标低频群组回复概率
|
||||
down_frequency_rate = max(
|
||||
1.0,
|
||||
global_config.down_frequency_rate
|
||||
)
|
||||
if group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / down_frequency_rate
|
||||
print(f"表情包, 当前意愿: {current_willing}")
|
||||
|
||||
print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
|
||||
interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
|
||||
if interested_rate > 0.4:
|
||||
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
||||
current_willing += interested_rate-0.4
|
||||
|
||||
current_willing *= global_config.response_willing_amplifier #放大回复意愿
|
||||
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
|
||||
|
||||
reply_probability = max((current_willing - 0.45) * 2, 0)
|
||||
|
||||
# 检查群组权限(如果是群聊)
|
||||
if chat_stream.group_info:
|
||||
if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / global_config.down_frequency_rate
|
||||
|
||||
reply_probability = min(reply_probability, 1)
|
||||
|
||||
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
||||
logger.debug(f"当前群组{group_id}回复概率:{reply_probability}")
|
||||
if reply_probability < 0:
|
||||
reply_probability = 0
|
||||
|
||||
self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
|
||||
return reply_probability
|
||||
|
||||
def change_reply_willing_sent(self, group_id: int):
|
||||
"""开始思考后降低群组的回复意愿"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
self.group_reply_willing[group_id] = max(0, current_willing - 2)
|
||||
|
||||
def change_reply_willing_after_sent(self, group_id: int):
|
||||
"""发送消息后提高群组的回复意愿"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
if current_willing < 1:
|
||||
self.group_reply_willing[group_id] = min(1, current_willing + 0.2)
|
||||
|
||||
|
||||
def change_reply_willing_sent(self, chat_stream:ChatStream):
|
||||
"""开始思考后降低聊天流的回复意愿"""
|
||||
stream = chat_stream
|
||||
if stream:
|
||||
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
|
||||
|
||||
def change_reply_willing_after_sent(self,chat_stream:ChatStream):
|
||||
"""发送消息后提高聊天流的回复意愿"""
|
||||
stream = chat_stream
|
||||
if stream:
|
||||
current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
|
||||
if current_willing < 1:
|
||||
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
|
||||
|
||||
async def ensure_started(self):
|
||||
"""确保衰减任务已启动"""
|
||||
if not self._started:
|
||||
@@ -109,6 +111,5 @@ class WillingManager:
|
||||
self._decay_task = asyncio.create_task(self._decay_reply_willing())
|
||||
self._started = True
|
||||
|
||||
|
||||
# 创建全局实例
|
||||
willing_manager = WillingManager()
|
||||
willing_manager = WillingManager()
|
||||
@@ -239,7 +239,7 @@ class Hippocampus:
|
||||
time_info += f"是从 {earliest_str} 到 {latest_str} 的对话:\n"
|
||||
|
||||
for msg in messages:
|
||||
input_text += f"{msg['text']}\n"
|
||||
input_text += f"{msg['detailed_plain_text']}\n"
|
||||
|
||||
logger.debug(input_text)
|
||||
|
||||
|
||||
@@ -7,10 +7,11 @@ from typing import Tuple, Union
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from nonebot import get_driver
|
||||
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
from ...common.database import Database
|
||||
from ..chat.config import global_config
|
||||
from ..chat.utils_image import compress_base64_image_by_scale
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -432,3 +433,78 @@ class LLM_request:
|
||||
response_handler=embedding_handler
|
||||
)
|
||||
return embedding
|
||||
|
||||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||||
"""压缩base64格式的图片到指定大小
|
||||
Args:
|
||||
base64_data: base64编码的图片数据
|
||||
target_size: 目标文件大小(字节),默认0.8MB
|
||||
Returns:
|
||||
str: 压缩后的base64图片数据
|
||||
"""
|
||||
try:
|
||||
# 将base64转换为字节数据
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# 如果已经小于目标大小,直接返回原图
|
||||
if len(image_data) <= 2*1024*1024:
|
||||
return base64_data
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 获取原始尺寸
|
||||
original_width, original_height = img.size
|
||||
|
||||
# 计算缩放比例
|
||||
scale = min(1.0, (target_size / len(image_data)) ** 0.5)
|
||||
|
||||
# 计算新的尺寸
|
||||
new_width = int(original_width * scale)
|
||||
new_height = int(original_height * scale)
|
||||
|
||||
# 创建内存缓冲区
|
||||
output_buffer = io.BytesIO()
|
||||
|
||||
# 如果是GIF,处理所有帧
|
||||
if getattr(img, "is_animated", False):
|
||||
frames = []
|
||||
for frame_idx in range(img.n_frames):
|
||||
img.seek(frame_idx)
|
||||
new_frame = img.copy()
|
||||
new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
|
||||
frames.append(new_frame)
|
||||
|
||||
# 保存到缓冲区
|
||||
frames[0].save(
|
||||
output_buffer,
|
||||
format='GIF',
|
||||
save_all=True,
|
||||
append_images=frames[1:],
|
||||
optimize=True,
|
||||
duration=img.info.get('duration', 100),
|
||||
loop=img.info.get('loop', 0)
|
||||
)
|
||||
else:
|
||||
# 处理静态图片
|
||||
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 保存到缓冲区,保持原始格式
|
||||
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
|
||||
resized_img.save(output_buffer, format='PNG', optimize=True)
|
||||
else:
|
||||
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
|
||||
|
||||
# 获取压缩后的数据并转换为base64
|
||||
compressed_data = output_buffer.getvalue()
|
||||
logger.success(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
|
||||
logger.info(f"压缩前大小: {len(image_data)/1024:.1f}KB, 压缩后大小: {len(compressed_data)/1024:.1f}KB")
|
||||
|
||||
return base64.b64encode(compressed_data).decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return base64_data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user