Merge branch 'SengokuCola:debug' into debug

This commit is contained in:
Cookie987
2025-03-13 13:17:46 +08:00
committed by GitHub
18 changed files with 399 additions and 458 deletions

8
.github/workflows/ruff.yml vendored Normal file
View File

@@ -0,0 +1,8 @@
name: Ruff
on: [ push, pull_request ]
jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/ruff-action@v3

10
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,10 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.9.10
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format

88
bot.py
View File

@@ -17,19 +17,6 @@ env_mask = {key: os.getenv(key) for key in os.environ}
uvicorn_server = None uvicorn_server = None
# 配置日志
log_path = os.path.join(os.getcwd(), "logs")
if not os.path.exists(log_path):
os.makedirs(log_path)
# 添加文件日志启用rotation和retention
logger.add(
os.path.join(log_path, "maimbot_{time:YYYY-MM-DD}.log"),
rotation="00:00", # 每天0点创建新文件
retention="30 days", # 保留30天的日志
level="INFO",
encoding="utf-8"
)
def easter_egg(): def easter_egg():
# 彩蛋 # 彩蛋
@@ -76,7 +63,7 @@ def init_env():
# 首先加载基础环境变量.env # 首先加载基础环境变量.env
if os.path.exists(".env"): if os.path.exists(".env"):
load_dotenv(".env",override=True) load_dotenv(".env", override=True)
logger.success("成功加载基础环境变量配置") logger.success("成功加载基础环境变量配置")
@@ -90,10 +77,7 @@ def load_env():
logger.success("加载开发环境变量配置") logger.success("加载开发环境变量配置")
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量 load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
fn_map = { fn_map = {"prod": prod, "dev": dev}
"prod": prod,
"dev": dev
}
env = os.getenv("ENVIRONMENT") env = os.getenv("ENVIRONMENT")
logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}") logger.info(f"[load_env] 当前的 ENVIRONMENT 变量值:{env}")
@@ -109,28 +93,45 @@ def load_env():
logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在")
def load_logger():
logger.remove() # 移除默认配置
if os.getenv("ENVIRONMENT") == "dev":
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "DEBUG"), # 根据环境设置日志级别默认为DEBUG
)
else:
logger.add(
sys.stderr,
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> <fg #777777>|</> <level>{level: <7}</level> <fg "
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>",
colorize=True,
level=os.getenv("LOG_LEVEL", "INFO"), # 根据环境设置日志级别默认为INFO
filter=lambda record: "nonebot" not in record["name"]
)
def load_logger():
logger.remove()
# 配置日志基础路径
log_path = os.path.join(os.getcwd(), "logs")
if not os.path.exists(log_path):
os.makedirs(log_path)
current_env = os.getenv("ENVIRONMENT", "dev")
# 公共配置参数
log_level = os.getenv("LOG_LEVEL", "INFO" if current_env == "prod" else "DEBUG")
log_filter = lambda record: (
("nonebot" not in record["name"] or record["level"].no >= logger.level("ERROR").no)
if current_env == "prod"
else True
)
log_format = (
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> "
"<fg #777777>|</> <level>{level: <7}</level> "
"<fg #777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> "
"<fg #777777>-</> <level>{message}</level>"
)
# 日志文件储存至/logs
logger.add(
os.path.join(log_path, "maimbot_{time:YYYY-MM-DD}.log"),
rotation="00:00",
retention="30 days",
format=log_format,
colorize=False,
level=log_level,
filter=log_filter,
encoding="utf-8",
)
# 终端输出
logger.add(sys.stderr, format=log_format, colorize=True, level=log_level, filter=log_filter)
def scan_provider(env_config: dict): def scan_provider(env_config: dict):
@@ -160,10 +161,7 @@ def scan_provider(env_config: dict):
# 检查每个 provider 是否同时存在 url 和 key # 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items(): for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None: if config["url"] is None or config["key"] is None:
logger.error( logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
f"provider 内容:{config}\n"
f"env_config 内容:{env_config}"
)
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
@@ -192,7 +190,7 @@ async def uvicorn_main():
reload=os.getenv("ENVIRONMENT") == "dev", reload=os.getenv("ENVIRONMENT") == "dev",
timeout_graceful_shutdown=5, timeout_graceful_shutdown=5,
log_config=None, log_config=None,
access_log=False access_log=False,
) )
server = uvicorn.Server(config) server = uvicorn.Server(config)
uvicorn_server = server uvicorn_server = server
@@ -202,7 +200,7 @@ async def uvicorn_main():
def raw_main(): def raw_main():
# 利用 TZ 环境变量设定程序工作的时区 # 利用 TZ 环境变量设定程序工作的时区
# 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用 # 仅保证行为一致,不依赖 localtime(),实际对生产环境几乎没有作用
if platform.system().lower() != 'windows': if platform.system().lower() != "windows":
time.tzset() time.tzset()
easter_egg() easter_egg()

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

BIN
docs/avatars/default.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 36 KiB

1
docs/avatars/run.bat Normal file
View File

@@ -0,0 +1 @@
gource gource.log --user-image-dir docs/avatars/ --default-user-image docs/avatars/default.png

View File

@@ -22,9 +22,7 @@ def __create_database_instance():
if username and password: if username and password:
# 如果有用户名和密码,使用认证连接 # 如果有用户名和密码,使用认证连接
return MongoClient( return MongoClient(host, port, username=username, password=password, authSource=auth_source)
host, port, username=username, password=password, authSource=auth_source
)
# 否则使用无认证连接 # 否则使用无认证连接
return MongoClient(host, port) return MongoClient(host, port)

View File

@@ -7,7 +7,7 @@ from datetime import datetime
from typing import Dict, List from typing import Dict, List
from loguru import logger from loguru import logger
from typing import Optional from typing import Optional
from ..common.database import db
import customtkinter as ctk import customtkinter as ctk
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -16,6 +16,8 @@ from dotenv import load_dotenv
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
# 获取项目根目录 # 获取项目根目录
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..')) root_dir = os.path.abspath(os.path.join(current_dir, '..', '..'))
sys.path.insert(0, root_dir)
from src.common.database import db
# 加载环境变量 # 加载环境变量
if os.path.exists(os.path.join(root_dir, '.env.dev')): if os.path.exists(os.path.join(root_dir, '.env.dev')):

View File

@@ -3,8 +3,9 @@ import time
import os import os
from loguru import logger from loguru import logger
from nonebot import get_driver, on_message, require from nonebot import get_driver, on_message, on_notice, require
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment,MessageEvent from nonebot.rule import to_me
from nonebot.adapters.onebot.v11 import Bot, GroupMessageEvent, Message, MessageSegment, MessageEvent, NoticeEvent
from nonebot.typing import T_State from nonebot.typing import T_State
from ..moods.moods import MoodManager # 导入情绪管理器 from ..moods.moods import MoodManager # 导入情绪管理器
@@ -39,6 +40,8 @@ logger.debug(f"正在唤醒{global_config.BOT_NICKNAME}......")
chat_bot = ChatBot() chat_bot = ChatBot()
# 注册消息处理器 # 注册消息处理器
msg_in = on_message(priority=5) msg_in = on_message(priority=5)
# 注册和bot相关的通知处理器
notice_matcher = on_notice(priority=1)
# 创建定时任务 # 创建定时任务
scheduler = require("nonebot_plugin_apscheduler").scheduler scheduler = require("nonebot_plugin_apscheduler").scheduler
@@ -95,19 +98,24 @@ async def _(bot: Bot, event: MessageEvent, state: T_State):
await chat_bot.handle_message(event, bot) await chat_bot.handle_message(event, bot)
@notice_matcher.handle()
async def _(bot: Bot, event: NoticeEvent, state: T_State):
logger.debug(f"收到通知:{event}")
await chat_bot.handle_notice(event, bot)
# 添加build_memory定时任务 # 添加build_memory定时任务
@scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory") @scheduler.scheduled_job("interval", seconds=global_config.build_memory_interval, id="build_memory")
async def build_memory_task(): async def build_memory_task():
"""每build_memory_interval秒执行一次记忆构建""" """每build_memory_interval秒执行一次记忆构建"""
logger.debug( logger.debug("[记忆构建]------------------------------------开始构建记忆--------------------------------------")
"[记忆构建]"
"------------------------------------开始构建记忆--------------------------------------")
start_time = time.time() start_time = time.time()
await hippocampus.operation_build_memory(chat_size=20) await hippocampus.operation_build_memory(chat_size=20)
end_time = time.time() end_time = time.time()
logger.success( logger.success(
f"[记忆构建]--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} " f"[记忆构建]--------------------------记忆构建完成:耗时: {end_time - start_time:.2f} "
"秒-------------------------------------------") "秒-------------------------------------------"
)
@scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory") @scheduler.scheduled_job("interval", seconds=global_config.forget_memory_interval, id="forget_memory")

View File

@@ -7,6 +7,8 @@ from nonebot.adapters.onebot.v11 import (
GroupMessageEvent, GroupMessageEvent,
MessageEvent, MessageEvent,
PrivateMessageEvent, PrivateMessageEvent,
NoticeEvent,
PokeNotifyEvent,
) )
from ..memory_system.memory import hippocampus from ..memory_system.memory import hippocampus
@@ -25,6 +27,7 @@ from .relationship_manager import relationship_manager
from .storage import MessageStorage from .storage import MessageStorage
from .utils import calculate_typing_time, is_mentioned_bot_in_message from .utils import calculate_typing_time, is_mentioned_bot_in_message
from .utils_image import image_path_to_base64 from .utils_image import image_path_to_base64
from .utils_user import get_user_nickname, get_user_cardname, get_groupname
from .willing_manager import willing_manager # 导入意愿管理器 from .willing_manager import willing_manager # 导入意愿管理器
from .message_base import UserInfo, GroupInfo, Seg from .message_base import UserInfo, GroupInfo, Seg
@@ -46,6 +49,69 @@ class ChatBot:
if not self._started: if not self._started:
self._started = True self._started = True
async def handle_notice(self, event: NoticeEvent, bot: Bot) -> None:
"""处理收到的通知"""
# 戳一戳通知
if isinstance(event, PokeNotifyEvent):
# 用户屏蔽,不区分私聊/群聊
if event.user_id in global_config.ban_user_id:
return
reply_poke_probability = 1 # 回复戳一戳的概率
if random() < reply_poke_probability:
user_info = UserInfo(
user_id=event.user_id,
user_nickname=get_user_nickname(event.user_id) or None,
user_cardname=get_user_cardname(event.user_id) or None,
platform="qq",
)
group_info = GroupInfo(group_id=event.group_id, group_name=None, platform="qq")
message_cq = MessageRecvCQ(
message_id=None,
user_info=user_info,
raw_message=str("[戳了戳]你"),
group_info=group_info,
reply_message=None,
platform="qq",
)
message_json = message_cq.to_dict()
# 进入maimbot
message = MessageRecv(message_json)
groupinfo = message.message_info.group_info
userinfo = message.message_info.user_info
messageinfo = message.message_info
chat = await chat_manager.get_or_create_stream(
platform=messageinfo.platform, user_info=userinfo, group_info=groupinfo
)
message.update_chat_stream(chat)
await message.process()
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=messageinfo.platform,
)
response, raw_content = await self.gpt.generate_response(message)
if response:
for msg in response:
message_segment = Seg(type="text", data=msg)
bot_message = MessageSending(
message_id=None,
chat_stream=chat,
bot_user_info=bot_user_info,
sender_info=userinfo,
message_segment=message_segment,
reply=None,
is_head=False,
is_emoji=False,
)
message_manager.add_message(bot_message)
async def handle_message(self, event: MessageEvent, bot: Bot) -> None: async def handle_message(self, event: MessageEvent, bot: Bot) -> None:
"""处理收到的消息""" """处理收到的消息"""
@@ -55,6 +121,9 @@ class ChatBot:
if event.user_id in global_config.ban_user_id: if event.user_id in global_config.ban_user_id:
return return
if event.reply and hasattr(event.reply, 'sender') and hasattr(event.reply.sender, 'user_id') and event.reply.sender.user_id in global_config.ban_user_id:
logger.debug(f"跳过处理回复来自被ban用户 {event.reply.sender.user_id} 的消息")
return
# 处理私聊消息 # 处理私聊消息
if isinstance(event, PrivateMessageEvent): if isinstance(event, PrivateMessageEvent):
if not global_config.enable_friend_chat: # 私聊过滤 if not global_config.enable_friend_chat: # 私聊过滤
@@ -126,7 +195,7 @@ class ChatBot:
for word in global_config.ban_words: for word in global_config.ban_words:
if word in message.processed_plain_text: if word in message.processed_plain_text:
logger.info( logger.info(
f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{userinfo.user_nickname}:{message.processed_plain_text}" f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{message.processed_plain_text}"
) )
logger.info(f"[过滤词识别]消息中含有{word}filtered") logger.info(f"[过滤词识别]消息中含有{word}filtered")
return return
@@ -135,7 +204,7 @@ class ChatBot:
for pattern in global_config.ban_msgs_regex: for pattern in global_config.ban_msgs_regex:
if re.search(pattern, message.raw_message): if re.search(pattern, message.raw_message):
logger.info( logger.info(
f"[{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{userinfo.user_nickname}:{message.raw_message}" f"[{chat.group_info.group_name if chat.group_info else '私聊'}]{userinfo.user_nickname}:{message.raw_message}"
) )
logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered") logger.info(f"[正则表达式过滤]消息匹配到{pattern}filtered")
return return
@@ -163,7 +232,7 @@ class ChatBot:
current_willing = willing_manager.get_willing(chat_stream=chat) current_willing = willing_manager.get_willing(chat_stream=chat)
logger.info( logger.info(
f"[{current_time}][{chat.group_info.group_name if chat.group_info.group_id else '私聊'}]{chat.user_info.user_nickname}:" f"[{current_time}][{chat.group_info.group_name if chat.group_info else '私聊'}]{chat.user_info.user_nickname}:"
f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]" f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]"
) )

View File

@@ -86,9 +86,12 @@ class CQCode:
else: else:
self.translated_segments = Seg(type="text", data="[图片]") self.translated_segments = Seg(type="text", data="[图片]")
elif self.type == "at": elif self.type == "at":
user_nickname = get_user_nickname(self.params.get("qq", "")) if self.params.get("qq") == "all":
self.translated_segments = Seg( self.translated_segments = Seg(type="text", data="@[全体成员]")
type="text", data=f"[@{user_nickname or '某人'}]" else:
user_nickname = get_user_nickname(self.params.get("qq", ""))
self.translated_segments = Seg(
type="text", data=f"[@{user_nickname or '某人'}]"
) )
elif self.type == "reply": elif self.type == "reply":
reply_segments = self.translate_reply() reply_segments = self.translate_reply()

View File

@@ -36,9 +36,9 @@ class EmojiManager:
def __init__(self): def __init__(self):
self._scan_task = None self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_emotion_judge, max_tokens=60, self.llm_emotion_judge = LLM_request(
temperature=0.8) # 更高的温度更少的token后续可以根据情绪来调整温度 model=global_config.llm_emotion_judge, max_tokens=60, temperature=0.8
) # 更高的温度更少的token后续可以根据情绪来调整温度
def _ensure_emoji_dir(self): def _ensure_emoji_dir(self):
"""确保表情存储目录存在""" """确保表情存储目录存在"""
@@ -75,23 +75,20 @@ class EmojiManager:
没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。
""" """
if 'emoji' not in db.list_collection_names(): if "emoji" not in db.list_collection_names():
db.create_collection('emoji') db.create_collection("emoji")
db.emoji.create_index([('embedding', '2dsphere')]) db.emoji.create_index([("embedding", "2dsphere")])
db.emoji.create_index([('filename', 1)], unique=True) db.emoji.create_index([("filename", 1)], unique=True)
def record_usage(self, emoji_id: str): def record_usage(self, emoji_id: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
self._ensure_db() self._ensure_db()
db.emoji.update_one( db.emoji.update_one({"_id": emoji_id}, {"$inc": {"usage_count": 1}})
{'_id': emoji_id},
{'$inc': {'usage_count': 1}}
)
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.error(f"记录表情使用失败: {str(e)}")
async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str,str]]: async def get_emoji_for_text(self, text: str) -> Optional[Tuple[str, str]]:
"""根据文本内容获取相关表情包 """根据文本内容获取相关表情包
Args: Args:
text: 输入文本 text: 输入文本
@@ -118,7 +115,7 @@ class EmojiManager:
try: try:
# 获取所有表情包 # 获取所有表情包
all_emojis = list(db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) all_emojis = list(db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1}))
if not all_emojis: if not all_emojis:
logger.warning("数据库中没有任何表情包") logger.warning("数据库中没有任何表情包")
@@ -137,15 +134,14 @@ class EmojiManager:
# 计算所有表情包与输入文本的相似度 # 计算所有表情包与输入文本的相似度
emoji_similarities = [ emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', []))) (emoji, cosine_similarity(text_embedding, emoji.get("embedding", []))) for emoji in all_emojis
for emoji in all_emojis
] ]
# 按相似度降序排序 # 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True) emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包 # 获取前3个最相似的表情包
top_10_emojis = emoji_similarities[:10 if len(emoji_similarities) > 10 else len(emoji_similarities)] top_10_emojis = emoji_similarities[: 10 if len(emoji_similarities) > 10 else len(emoji_similarities)]
if not top_10_emojis: if not top_10_emojis:
logger.warning("未找到匹配的表情包") logger.warning("未找到匹配的表情包")
@@ -154,17 +150,15 @@ class EmojiManager:
# 从前3个中随机选择一个 # 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_10_emojis) selected_emoji, similarity = random.choice(top_10_emojis)
if selected_emoji and 'path' in selected_emoji: if selected_emoji and "path" in selected_emoji:
# 更新使用次数 # 更新使用次数
db.emoji.update_one( db.emoji.update_one({"_id": selected_emoji["_id"]}, {"$inc": {"usage_count": 1}})
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
logger.success( logger.success(
f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})") f"找到匹配的表情包: {selected_emoji.get('description', '无描述')} (相似度: {similarity:.4f})"
)
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了 # 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
return selected_emoji['path'], "[ %s ]" % selected_emoji.get('description', '无描述') return selected_emoji["path"], "[ %s ]" % selected_emoji.get("description", "无描述")
except Exception as search_error: except Exception as search_error:
logger.error(f"搜索表情包失败: {str(search_error)}") logger.error(f"搜索表情包失败: {str(search_error)}")
@@ -176,7 +170,6 @@ class EmojiManager:
logger.error(f"获取表情包失败: {str(e)}") logger.error(f"获取表情包失败: {str(e)}")
return None return None
async def _get_emoji_discription(self, image_base64: str) -> str: async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签使用image_manager的描述生成功能""" """获取表情包的标签使用image_manager的描述生成功能"""
@@ -184,7 +177,7 @@ class EmojiManager:
# 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀 # 使用image_manager获取描述去掉前后的方括号和"表情包:"前缀
description = await image_manager.get_emoji_description(image_base64) description = await image_manager.get_emoji_description(image_base64)
# 去掉[表情包xxx]的格式,只保留描述内容 # 去掉[表情包xxx]的格式,只保留描述内容
description = description.strip('[]').replace('表情包:', '') description = description.strip("[]").replace("表情包:", "")
return description return description
except Exception as e: except Exception as e:
@@ -193,7 +186,7 @@ class EmojiManager:
async def _check_emoji(self, image_base64: str, image_format: str) -> str: async def _check_emoji(self, image_base64: str, image_format: str) -> str:
try: try:
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' prompt = f'这是一个表情包,请回答这个表情包是否满足"{global_config.EMOJI_CHECK_PROMPT}"的要求,是则回答是,否则回答否,不要出现任何其他内容'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
logger.debug(f"输出描述: {content}") logger.debug(f"输出描述: {content}")
@@ -205,9 +198,9 @@ class EmojiManager:
async def _get_kimoji_for_text(self, text: str): async def _get_kimoji_for_text(self, text: str):
try: try:
prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出"一种什么样的感觉"中间的形容词部分。'
content, _ = await self.llm_emotion_judge.generate_response_async(prompt,temperature=1.5) content, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=1.5)
logger.info(f"输出描述: {content}") logger.info(f"输出描述: {content}")
return content return content
@@ -222,8 +215,9 @@ class EmojiManager:
os.makedirs(emoji_dir, exist_ok=True) os.makedirs(emoji_dir, exist_ok=True)
# 获取所有支持的图片文件 # 获取所有支持的图片文件
files_to_process = [f for f in os.listdir(emoji_dir) if files_to_process = [
f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))] f for f in os.listdir(emoji_dir) if f.lower().endswith((".jpg", ".jpeg", ".png", ".gif"))
]
for filename in files_to_process: for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename) image_path = os.path.join(emoji_dir, filename)
@@ -238,35 +232,31 @@ class EmojiManager:
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 检查是否已经注册过 # 检查是否已经注册过
existing_emoji = db['emoji'].find_one({'filename': filename}) existing_emoji = db["emoji"].find_one({"hash": image_hash})
description = None description = None
if existing_emoji: if existing_emoji:
# 即使表情包已存在也检查是否需要同步到images集合 # 即使表情包已存在也检查是否需要同步到images集合
description = existing_emoji.get('discription') description = existing_emoji.get("discription")
# 检查是否在images集合中存在 # 检查是否在images集合中存在
existing_image = db.images.find_one({'hash': image_hash}) existing_image = db.images.find_one({"hash": image_hash})
if not existing_image: if not existing_image:
# 同步到images集合 # 同步到images集合
image_doc = { image_doc = {
'hash': image_hash, "hash": image_hash,
'path': image_path, "path": image_path,
'type': 'emoji', "type": "emoji",
'description': description, "description": description,
'timestamp': int(time.time()) "timestamp": int(time.time()),
} }
db.images.update_one( db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合 # 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji') image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"同步已存在的表情包到images集合: {filename}") logger.success(f"同步已存在的表情包到images集合: {filename}")
continue continue
# 检查是否在images集合中已有描述 # 检查是否在images集合中已有描述
existing_description = image_manager._get_description_from_db(image_hash, 'emoji') existing_description = image_manager._get_description_from_db(image_hash, "emoji")
if existing_description: if existing_description:
description = existing_description description = existing_description
@@ -274,11 +264,9 @@ class EmojiManager:
# 获取表情包的描述 # 获取表情包的描述
description = await self._get_emoji_discription(image_base64) description = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK: if global_config.EMOJI_CHECK:
check = await self._check_emoji(image_base64, image_format) check = await self._check_emoji(image_base64, image_format)
if '' not in check: if "" not in check:
os.remove(image_path) os.remove(image_path)
logger.info(f"描述: {description}") logger.info(f"描述: {description}")
@@ -295,35 +283,30 @@ class EmojiManager:
# 准备数据库记录 # 准备数据库记录
emoji_record = { emoji_record = {
'filename': filename, "filename": filename,
'path': image_path, "path": image_path,
'embedding': embedding, "embedding": embedding,
'discription': description, "discription": description,
'hash': image_hash, "hash": image_hash,
'timestamp': int(time.time()) "timestamp": int(time.time()),
} }
# 保存到emoji数据库 # 保存到emoji数据库
db['emoji'].insert_one(emoji_record) db["emoji"].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}") logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {description}") logger.info(f"描述: {description}")
# 保存到images数据库 # 保存到images数据库
image_doc = { image_doc = {
'hash': image_hash, "hash": image_hash,
'path': image_path, "path": image_path,
'type': 'emoji', "type": "emoji",
'description': description, "description": description,
'timestamp': int(time.time()) "timestamp": int(time.time()),
} }
db.images.update_one( db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
# 保存描述到image_descriptions集合 # 保存描述到image_descriptions集合
image_manager._save_description_to_db(image_hash, description, 'emoji') image_manager._save_description_to_db(image_hash, description, "emoji")
logger.success(f"同步保存到images集合: {filename}") logger.success(f"同步保存到images集合: {filename}")
else: else:
logger.warning(f"跳过表情包: {filename}") logger.warning(f"跳过表情包: {filename}")
@@ -351,28 +334,35 @@ class EmojiManager:
for emoji in all_emojis: for emoji in all_emojis:
try: try:
if 'path' not in emoji: if "path" not in emoji:
logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现无效记录缺少path字段ID: {emoji.get('_id', 'unknown')}")
db.emoji.delete_one({'_id': emoji['_id']}) db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1 removed_count += 1
continue continue
if 'embedding' not in emoji: if "embedding" not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
db.emoji.delete_one({'_id': emoji['_id']}) db.emoji.delete_one({"_id": emoji["_id"]})
removed_count += 1 removed_count += 1
continue continue
# 检查文件是否存在 # 检查文件是否存在
if not os.path.exists(emoji['path']): if not os.path.exists(emoji["path"]):
logger.warning(f"表情包文件已被删除: {emoji['path']}") logger.warning(f"表情包文件已被删除: {emoji['path']}")
# 从数据库中删除记录 # 从数据库中删除记录
result = db.emoji.delete_one({'_id': emoji['_id']}) result = db.emoji.delete_one({"_id": emoji["_id"]})
if result.deleted_count > 0: if result.deleted_count > 0:
logger.debug(f"成功删除数据库记录: {emoji['_id']}") logger.debug(f"成功删除数据库记录: {emoji['_id']}")
removed_count += 1 removed_count += 1
else: else:
logger.error(f"删除数据库记录失败: {emoji['_id']}") logger.error(f"删除数据库记录失败: {emoji['_id']}")
continue
if "hash" not in emoji:
logger.warning(f"发现缺失记录缺少hash字段ID: {emoji.get('_id', 'unknown')}")
hash = hashlib.md5(open(emoji["path"], "rb").read()).hexdigest()
db.emoji.update_one({"_id": emoji["_id"]}, {"$set": {"hash": hash}})
except Exception as item_error: except Exception as item_error:
logger.error(f"处理表情包记录时出错: {str(item_error)}") logger.error(f"处理表情包记录时出错: {str(item_error)}")
continue continue
@@ -398,5 +388,3 @@ class EmojiManager:
# 创建全局单例 # 创建全局单例
emoji_manager = EmojiManager() emoji_manager = EmojiManager()

View File

@@ -23,8 +23,8 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass @dataclass
class Message(MessageBase): class Message(MessageBase):
chat_stream: ChatStream=None chat_stream: ChatStream = None
reply: Optional['Message'] = None reply: Optional["Message"] = None
detailed_plain_text: str = "" detailed_plain_text: str = ""
processed_plain_text: str = "" processed_plain_text: str = ""
@@ -35,7 +35,7 @@ class Message(MessageBase):
chat_stream: ChatStream, chat_stream: ChatStream,
user_info: UserInfo, user_info: UserInfo,
message_segment: Optional[Seg] = None, message_segment: Optional[Seg] = None,
reply: Optional['MessageRecv'] = None, reply: Optional["MessageRecv"] = None,
detailed_plain_text: str = "", detailed_plain_text: str = "",
processed_plain_text: str = "", processed_plain_text: str = "",
): ):
@@ -45,15 +45,11 @@ class Message(MessageBase):
message_id=message_id, message_id=message_id,
time=time, time=time,
group_info=chat_stream.group_info, group_info=chat_stream.group_info,
user_info=user_info user_info=user_info,
) )
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None)
message_info=message_info,
message_segment=message_segment,
raw_message=None
)
self.chat_stream = chat_stream self.chat_stream = chat_stream
# 文本处理相关属性 # 文本处理相关属性
@@ -74,41 +70,38 @@ class MessageRecv(Message):
Args: Args:
message_dict: MessageCQ序列化后的字典 message_dict: MessageCQ序列化后的字典
""" """
self.message_info = BaseMessageInfo.from_dict(message_dict.get('message_info', {})) self.message_info = BaseMessageInfo.from_dict(message_dict.get("message_info", {}))
message_segment = message_dict.get('message_segment', {}) message_segment = message_dict.get("message_segment", {})
if message_segment.get('data','') == '[json]': if message_segment.get("data", "") == "[json]":
# 提取json消息中的展示信息 # 提取json消息中的展示信息
pattern = r'\[CQ:json,data=(?P<json_data>.+?)\]' pattern = r"\[CQ:json,data=(?P<json_data>.+?)\]"
match = re.search(pattern, message_dict.get('raw_message','')) match = re.search(pattern, message_dict.get("raw_message", ""))
raw_json = html.unescape(match.group('json_data')) raw_json = html.unescape(match.group("json_data"))
try: try:
json_message = json.loads(raw_json) json_message = json.loads(raw_json)
except json.JSONDecodeError: except json.JSONDecodeError:
json_message = {} json_message = {}
message_segment['data'] = json_message.get('prompt','') message_segment["data"] = json_message.get("prompt", "")
self.message_segment = Seg.from_dict(message_dict.get('message_segment', {})) self.message_segment = Seg.from_dict(message_dict.get("message_segment", {}))
self.raw_message = message_dict.get('raw_message') self.raw_message = message_dict.get("raw_message")
# 处理消息内容 # 处理消息内容
self.processed_plain_text = "" # 初始化为空字符串 self.processed_plain_text = "" # 初始化为空字符串
self.detailed_plain_text = "" # 初始化为空字符串 self.detailed_plain_text = "" # 初始化为空字符串
self.is_emoji=False self.is_emoji = False
def update_chat_stream(self, chat_stream: ChatStream):
def update_chat_stream(self,chat_stream:ChatStream): self.chat_stream = chat_stream
self.chat_stream=chat_stream
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本 """处理消息内容,生成纯文本和详细文本
这个方法必须在创建实例后显式调用,因为它包含异步操作。 这个方法必须在创建实例后显式调用,因为它包含异步操作。
""" """
self.processed_plain_text = await self._process_message_segments( self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.message_segment
)
self.detailed_plain_text = self._generate_detailed_text() self.detailed_plain_text = self._generate_detailed_text()
async def _process_message_segments(self, segment: Seg) -> str: async def _process_message_segments(self, segment: Seg) -> str:
@@ -157,16 +150,12 @@ class MessageRecv(Message):
else: else:
return f"[{seg.type}:{str(seg.data)}]" return f"[{seg.type}:{str(seg.data)}]"
except Exception as e: except Exception as e:
logger.error( logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}"
)
return f"[处理失败的{seg.type}消息]" return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str: def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息""" """生成详细文本,包含时间和用户信息"""
time_str = time.strftime( time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
"%m-%d %H:%M:%S", time.localtime(self.message_info.time)
)
user_info = self.message_info.user_info user_info = self.message_info.user_info
name = ( name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})" f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -257,16 +246,12 @@ class MessageProcessBase(Message):
else: else:
return f"[{seg.type}:{str(seg.data)}]" return f"[{seg.type}:{str(seg.data)}]"
except Exception as e: except Exception as e:
logger.error( logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}"
)
return f"[处理失败的{seg.type}消息]" return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str: def _generate_detailed_text(self) -> str:
"""生成详细文本,包含时间和用户信息""" """生成详细文本,包含时间和用户信息"""
time_str = time.strftime( time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.message_info.time))
"%m-%d %H:%M:%S", time.localtime(self.message_info.time)
)
user_info = self.message_info.user_info user_info = self.message_info.user_info
name = ( name = (
f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})" f"{user_info.user_nickname}(ta的昵称:{user_info.user_cardname},ta的id:{user_info.user_id})"
@@ -330,10 +315,11 @@ class MessageSending(MessageProcessBase):
self.is_head = is_head self.is_head = is_head
self.is_emoji = is_emoji self.is_emoji = is_emoji
def set_reply(self, reply: Optional["MessageRecv"]) -> None: def set_reply(self, reply: Optional["MessageRecv"] = None) -> None:
"""设置回复消息""" """设置回复消息"""
if reply: if reply:
self.reply = reply self.reply = reply
if self.reply:
self.reply_to_message_id = self.reply.message_info.message_id self.reply_to_message_id = self.reply.message_info.message_id
self.message_segment = Seg( self.message_segment = Seg(
type="seglist", type="seglist",
@@ -346,9 +332,7 @@ class MessageSending(MessageProcessBase):
async def process(self) -> None: async def process(self) -> None:
"""处理消息内容,生成纯文本和详细文本""" """处理消息内容,生成纯文本和详细文本"""
if self.message_segment: if self.message_segment:
self.processed_plain_text = await self._process_message_segments( self.processed_plain_text = await self._process_message_segments(self.message_segment)
self.message_segment
)
self.detailed_plain_text = self._generate_detailed_text() self.detailed_plain_text = self._generate_detailed_text()
@classmethod @classmethod
@@ -377,10 +361,7 @@ class MessageSending(MessageProcessBase):
def is_private_message(self) -> bool: def is_private_message(self) -> bool:
"""判断是否为私聊消息""" """判断是否为私聊消息"""
return ( return self.message_info.group_info is None or self.message_info.group_info.group_id is None
self.message_info.group_info is None
or self.message_info.group_info.group_id is None
)
@dataclass @dataclass

View File

@@ -65,6 +65,8 @@ class GroupInfo:
Returns: Returns:
GroupInfo: 新的实例 GroupInfo: 新的实例
""" """
if data.get('group_id') is None:
return None
return cls( return cls(
platform=data.get('platform'), platform=data.get('platform'),
group_id=data.get('group_id'), group_id=data.get('group_id'),
@@ -129,8 +131,8 @@ class BaseMessageInfo:
Returns: Returns:
BaseMessageInfo: 新的实例 BaseMessageInfo: 新的实例
""" """
group_info = GroupInfo(**data.get('group_info', {})) group_info = GroupInfo.from_dict(data.get('group_info', {}))
user_info = UserInfo(**data.get('user_info', {})) user_info = UserInfo.from_dict(data.get('user_info', {}))
return cls( return cls(
platform=data.get('platform'), platform=data.get('platform'),
message_id=data.get('message_id'), message_id=data.get('message_id'),
@@ -173,7 +175,7 @@ class MessageBase:
Returns: Returns:
MessageBase: 新的实例 MessageBase: 新的实例
""" """
message_info = BaseMessageInfo(**data.get('message_info', {})) message_info = BaseMessageInfo.from_dict(data.get('message_info', {}))
message_segment = Seg(**data.get('message_segment', {})) message_segment = Seg(**data.get('message_segment', {}))
raw_message = data.get('raw_message',None) raw_message = data.get('raw_message',None)
return cls( return cls(

View File

@@ -8,12 +8,14 @@ from .cq_code import cq_code_tool
from .utils_cq import parse_cq_code from .utils_cq import parse_cq_code
from .utils_user import get_groupname from .utils_user import get_groupname
from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase from .message_base import Seg, GroupInfo, UserInfo, BaseMessageInfo, MessageBase
# 禁用SSL警告 # 禁用SSL警告
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
#这个类是消息数据类,用于存储和管理消息数据。 # 这个类是消息数据类,用于存储和管理消息数据。
#它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。 # 它定义了消息的属性包括群组ID、用户ID、消息ID、原始消息内容、纯文本内容和时间戳。
#它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。 # 它还定义了两个辅助属性keywords用于提取消息的关键词is_plain_text用于判断消息是否为纯文本。
@dataclass @dataclass
class MessageCQ(MessageBase): class MessageCQ(MessageBase):
@@ -24,27 +26,17 @@ class MessageCQ(MessageBase):
- user_id: 发送者/接收者ID - user_id: 发送者/接收者ID
- platform: 平台标识(默认为"qq" - platform: 平台标识(默认为"qq"
""" """
def __init__( def __init__(
self, self, message_id: int, user_info: UserInfo, group_info: Optional[GroupInfo] = None, platform: str = "qq"
message_id: int,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
platform: str = "qq"
): ):
# 构造基础消息信息 # 构造基础消息信息
message_info = BaseMessageInfo( message_info = BaseMessageInfo(
platform=platform, platform=platform, message_id=message_id, time=int(time.time()), group_info=group_info, user_info=user_info
message_id=message_id,
time=int(time.time()),
group_info=group_info,
user_info=user_info
) )
# 调用父类初始化message_segment 由子类设置 # 调用父类初始化message_segment 由子类设置
super().__init__( super().__init__(message_info=message_info, message_segment=None, raw_message=None)
message_info=message_info,
message_segment=None,
raw_message=None
)
@dataclass @dataclass
class MessageRecvCQ(MessageCQ): class MessageRecvCQ(MessageCQ):
@@ -80,7 +72,7 @@ class MessageRecvCQ(MessageCQ):
start = 0 start = 0
while True: while True:
cq_start = message.find('[CQ:', start) cq_start = message.find("[CQ:", start)
if cq_start == -1: if cq_start == -1:
if start < len(message): if start < len(message):
text = message[start:].strip() text = message[start:].strip()
@@ -93,20 +85,20 @@ class MessageRecvCQ(MessageCQ):
if text: if text:
cq_code_dict_list.append(parse_cq_code(text)) cq_code_dict_list.append(parse_cq_code(text))
cq_end = message.find(']', cq_start) cq_end = message.find("]", cq_start)
if cq_end == -1: if cq_end == -1:
text = message[cq_start:].strip() text = message[cq_start:].strip()
if text: if text:
cq_code_dict_list.append(parse_cq_code(text)) cq_code_dict_list.append(parse_cq_code(text))
break break
cq_code = message[cq_start:cq_end + 1] cq_code = message[cq_start : cq_end + 1]
cq_code_dict_list.append(parse_cq_code(cq_code)) cq_code_dict_list.append(parse_cq_code(cq_code))
start = cq_end + 1 start = cq_end + 1
# 转换CQ码为Seg对象 # 转换CQ码为Seg对象
for code_item in cq_code_dict_list: for code_item in cq_code_dict_list:
message_obj = cq_code_tool.cq_from_dict_to_class(code_item,msg=self,reply=reply_message) message_obj = cq_code_tool.cq_from_dict_to_class(code_item, msg=self, reply=reply_message)
if message_obj.translated_segments: if message_obj.translated_segments:
segments.append(message_obj.translated_segments) segments.append(message_obj.translated_segments)
@@ -115,59 +107,58 @@ class MessageRecvCQ(MessageCQ):
return segments[0] return segments[0]
# 否则返回seglist类型的Seg # 否则返回seglist类型的Seg
return Seg(type='seglist', data=segments) return Seg(type="seglist", data=segments)
def to_dict(self) -> Dict: def to_dict(self) -> Dict:
"""转换为字典格式,包含所有必要信息""" """转换为字典格式,包含所有必要信息"""
base_dict = super().to_dict() base_dict = super().to_dict()
return base_dict return base_dict
@dataclass @dataclass
class MessageSendCQ(MessageCQ): class MessageSendCQ(MessageCQ):
"""QQ发送消息类用于将Seg对象转换为raw_message""" """QQ发送消息类用于将Seg对象转换为raw_message"""
def __init__( def __init__(self, data: Dict):
self,
data: Dict
):
# 调用父类初始化 # 调用父类初始化
message_info = BaseMessageInfo.from_dict(data.get('message_info', {})) message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
message_segment = Seg.from_dict(data.get('message_segment', {})) message_segment = Seg.from_dict(data.get("message_segment", {}))
super().__init__( super().__init__(
message_info.message_id, message_info.message_id,
message_info.user_info, message_info.user_info,
message_info.group_info if message_info.group_info else None, message_info.group_info if message_info.group_info else None,
message_info.platform message_info.platform,
) )
self.message_segment = message_segment self.message_segment = message_segment
self.raw_message = self._generate_raw_message() self.raw_message = self._generate_raw_message()
def _generate_raw_message(self, ) -> str: def _generate_raw_message(
self,
) -> str:
"""将Seg对象转换为raw_message""" """将Seg对象转换为raw_message"""
segments = [] segments = []
# 处理消息段 # 处理消息段
if self.message_segment.type == 'seglist': if self.message_segment.type == "seglist":
for seg in self.message_segment.data: for seg in self.message_segment.data:
segments.append(self._seg_to_cq_code(seg)) segments.append(self._seg_to_cq_code(seg))
else: else:
segments.append(self._seg_to_cq_code(self.message_segment)) segments.append(self._seg_to_cq_code(self.message_segment))
return ''.join(segments) return "".join(segments)
def _seg_to_cq_code(self, seg: Seg) -> str: def _seg_to_cq_code(self, seg: Seg) -> str:
"""将单个Seg对象转换为CQ码字符串""" """将单个Seg对象转换为CQ码字符串"""
if seg.type == 'text': if seg.type == "text":
return str(seg.data) return str(seg.data)
elif seg.type == 'image': elif seg.type == "image":
return cq_code_tool.create_image_cq_base64(seg.data) return cq_code_tool.create_image_cq_base64(seg.data)
elif seg.type == 'emoji': elif seg.type == "emoji":
return cq_code_tool.create_emoji_cq_base64(seg.data) return cq_code_tool.create_emoji_cq_base64(seg.data)
elif seg.type == 'at': elif seg.type == "at":
return f"[CQ:at,qq={seg.data}]" return f"[CQ:at,qq={seg.data}]"
elif seg.type == 'reply': elif seg.type == "reply":
return cq_code_tool.create_reply_cq(int(seg.data)) return cq_code_tool.create_reply_cq(int(seg.data))
else: else:
return f"[{seg.data}]" return f"[{seg.data}]"

View File

@@ -13,9 +13,11 @@ from nonebot import get_driver
from ...common.database import db from ...common.database import db
from ..chat.config import global_config from ..chat.config import global_config
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
class ImageManager: class ImageManager:
_instance = None _instance = None
IMAGE_DIR = "data" # 图像存储根目录 IMAGE_DIR = "data" # 图像存储根目录
@@ -40,20 +42,20 @@ class ImageManager:
def _ensure_image_collection(self): def _ensure_image_collection(self):
"""确保images集合存在并创建索引""" """确保images集合存在并创建索引"""
if 'images' not in db.list_collection_names(): if "images" not in db.list_collection_names():
db.create_collection('images') db.create_collection("images")
# 创建索引 # 创建索引
db.images.create_index([('hash', 1)], unique=True) db.images.create_index([("hash", 1)], unique=True)
db.images.create_index([('url', 1)]) db.images.create_index([("url", 1)])
db.images.create_index([('path', 1)]) db.images.create_index([("path", 1)])
def _ensure_description_collection(self): def _ensure_description_collection(self):
"""确保image_descriptions集合存在并创建索引""" """确保image_descriptions集合存在并创建索引"""
if 'image_descriptions' not in db.list_collection_names(): if "image_descriptions" not in db.list_collection_names():
db.create_collection('image_descriptions') db.create_collection("image_descriptions")
# 创建索引 # 创建索引
db.image_descriptions.create_index([('hash', 1)], unique=True) db.image_descriptions.create_index([("hash", 1)], unique=True)
db.image_descriptions.create_index([('type', 1)]) db.image_descriptions.create_index([("type", 1)])
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述 """从数据库获取图片描述
@@ -65,11 +67,8 @@ class ImageManager:
Returns: Returns:
Optional[str]: 描述文本如果不存在则返回None Optional[str]: 描述文本如果不存在则返回None
""" """
result= db.image_descriptions.find_one({ result = db.image_descriptions.find_one({"hash": image_hash, "type": description_type})
'hash': image_hash, return result["description"] if result else None
'type': description_type
})
return result['description'] if result else None
def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None: def _save_description_to_db(self, image_hash: str, description: str, description_type: str) -> None:
"""保存图片描述到数据库 """保存图片描述到数据库
@@ -80,77 +79,11 @@ class ImageManager:
description_type: 描述类型 ('emoji''image') description_type: 描述类型 ('emoji''image')
""" """
db.image_descriptions.update_one( db.image_descriptions.update_one(
{'hash': image_hash, 'type': description_type}, {"hash": image_hash, "type": description_type},
{ {"$set": {"description": description, "timestamp": int(time.time())}},
'$set': { upsert=True,
'description': description,
'timestamp': int(time.time())
}
},
upsert=True
) )
async def save_image(self,
image_data: Union[str, bytes],
url: str = None,
description: str = None,
is_base64: bool = False) -> Optional[str]:
"""保存图像
Args:
image_data: 图像数据(base64字符串或字节)
url: 图像URL
description: 图像描述
is_base64: image_data是否为base64格式
Returns:
str: 保存后的文件路径,失败返回None
"""
try:
# 转换为字节格式
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return None
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return None
# 计算哈希值
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查重
existing = db.images.find_one({'hash': image_hash})
if existing:
return existing['path']
# 生成文件名和路径
timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, filename)
# 保存文件
with open(file_path, "wb") as f:
f.write(image_bytes)
# 保存到数据库
image_doc = {
'hash': image_hash,
'path': file_path,
'url': url,
'description': description,
'timestamp': timestamp
}
db.images.insert_one(image_doc)
return file_path
except Exception as e:
logger.error(f"保存图像失败: {str(e)}")
return None
async def get_image_by_url(self, url: str) -> Optional[str]: async def get_image_by_url(self, url: str) -> Optional[str]:
"""根据URL获取图像路径(带查重) """根据URL获取图像路径(带查重)
Args: Args:
@@ -160,9 +93,9 @@ class ImageManager:
""" """
try: try:
# 先查找是否已存在 # 先查找是否已存在
existing = db.images.find_one({'url': url}) existing = db.images.find_one({"url": url})
if existing: if existing:
return existing['path'] return existing["path"]
# 下载图像 # 下载图像
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@@ -176,63 +109,6 @@ class ImageManager:
logger.error(f"获取图像失败: {str(e)}") logger.error(f"获取图像失败: {str(e)}")
return None return None
async def get_base64_by_url(self, url: str) -> Optional[str]:
"""根据URL获取base64(带查重)
Args:
url: 图像URL
Returns:
str: base64字符串,失败返回None
"""
try:
image_path = await self.get_image_by_url(url)
if not image_path:
return None
with open(image_path, 'rb') as f:
image_bytes = f.read()
return base64.b64encode(image_bytes).decode('utf-8')
except Exception as e:
logger.error(f"获取base64失败: {str(e)}")
return None
def check_url_exists(self, url: str) -> bool:
"""检查URL是否已存在
Args:
url: 图像URL
Returns:
bool: 是否存在
"""
return db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在
Args:
image_data: 图像数据(base64或字节)
is_base64: 是否为base64格式
Returns:
bool: 是否存在
"""
try:
if is_base64:
if isinstance(image_data, str):
image_bytes = base64.b64decode(image_data)
else:
return False
else:
if isinstance(image_data, bytes):
image_bytes = image_data
else:
return False
image_hash = hashlib.md5(image_bytes).hexdigest()
return db.images.find_one({'hash': image_hash}) is not None
except Exception as e:
logger.error(f"检查哈希失败: {str(e)}")
return False
async def get_emoji_description(self, image_base64: str) -> str: async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能""" """获取表情包描述,带查重和保存功能"""
try: try:
@@ -242,7 +118,7 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'emoji') cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description: if cached_description:
logger.info(f"缓存表情包描述: {cached_description}") logger.info(f"缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
@@ -251,12 +127,19 @@ class ImageManager:
prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感" prompt = "这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感"
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
# 根据配置决定是否保存图片 # 根据配置决定是否保存图片
if global_config.EMOJI_SAVE: if global_config.EMOJI_SAVE:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR, 'emoji',filename) if not os.path.exists(os.path.join(self.IMAGE_DIR, "emoji")):
os.makedirs(os.path.join(self.IMAGE_DIR, "emoji"))
file_path = os.path.join(self.IMAGE_DIR, "emoji", filename)
try: try:
# 保存文件 # 保存文件
@@ -265,23 +148,19 @@ class ImageManager:
# 保存到数据库 # 保存到数据库
image_doc = { image_doc = {
'hash': image_hash, "hash": image_hash,
'path': file_path, "path": file_path,
'type': 'emoji', "type": "emoji",
'description': description, "description": description,
'timestamp': timestamp "timestamp": timestamp,
} }
db.images.update_one( db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存表情包: {file_path}") logger.success(f"保存表情包: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存表情包文件失败: {str(e)}") logger.error(f"保存表情包文件失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库
self._save_description_to_db(image_hash, description, 'emoji') self._save_description_to_db(image_hash, description, "emoji")
return f"[表情包:{description}]" return f"[表情包:{description}]"
except Exception as e: except Exception as e:
@@ -298,14 +177,20 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() image_format = Image.open(io.BytesIO(image_bytes)).format.lower()
# 查询缓存的描述 # 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, 'image') cached_description = self._get_description_from_db(image_hash, "image")
if cached_description: if cached_description:
print("图片描述缓存中") print("图片描述缓存中")
return f"[图片:{cached_description}]" return f"[图片:{cached_description}]"
# 调用AI获取描述 # 调用AI获取描述
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" prompt = (
"请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.info(f"缓存图片描述: {cached_description}")
return f"[图片:{cached_description}]"
print(f"描述是{description}") print(f"描述是{description}")
@@ -318,7 +203,9 @@ class ImageManager:
# 生成文件名和路径 # 生成文件名和路径
timestamp = int(time.time()) timestamp = int(time.time())
filename = f"{timestamp}_{image_hash[:8]}.{image_format}" filename = f"{timestamp}_{image_hash[:8]}.{image_format}"
file_path = os.path.join(self.IMAGE_DIR,'image', filename) if not os.path.exists(os.path.join(self.IMAGE_DIR, "image")):
os.makedirs(os.path.join(self.IMAGE_DIR, "image"))
file_path = os.path.join(self.IMAGE_DIR, "image", filename)
try: try:
# 保存文件 # 保存文件
@@ -327,23 +214,19 @@ class ImageManager:
# 保存到数据库 # 保存到数据库
image_doc = { image_doc = {
'hash': image_hash, "hash": image_hash,
'path': file_path, "path": file_path,
'type': 'image', "type": "image",
'description': description, "description": description,
'timestamp': timestamp "timestamp": timestamp,
} }
db.images.update_one( db.images.update_one({"hash": image_hash}, {"$set": image_doc}, upsert=True)
{'hash': image_hash},
{'$set': image_doc},
upsert=True
)
logger.success(f"保存图片: {file_path}") logger.success(f"保存图片: {file_path}")
except Exception as e: except Exception as e:
logger.error(f"保存图片文件失败: {str(e)}") logger.error(f"保存图片文件失败: {str(e)}")
# 保存描述到数据库 # 保存描述到数据库
self._save_description_to_db(image_hash, description, 'image') self._save_description_to_db(image_hash, description, "image")
return f"[图片:{description}]" return f"[图片:{description}]"
except Exception as e: except Exception as e:
@@ -351,7 +234,6 @@ class ImageManager:
return "[图片]" return "[图片]"
# 创建全局单例 # 创建全局单例
image_manager = ImageManager() image_manager = ImageManager()
@@ -364,9 +246,9 @@ def image_path_to_base64(image_path: str) -> str:
str: base64编码的图片数据 str: base64编码的图片数据
""" """
try: try:
with open(image_path, 'rb') as f: with open(image_path, "rb") as f:
image_data = f.read() image_data = f.read()
return base64.b64encode(image_data).decode('utf-8') return base64.b64encode(image_data).decode("utf-8")
except Exception as e: except Exception as e:
logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}") logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}")
return None return None

View File

@@ -132,7 +132,7 @@ class LLM_request:
# 常见Error Code Mapping # 常见Error Code Mapping
error_code_mapping = { error_code_mapping = {
400: "参数不正确", 400: "参数不正确",
401: "API key 错误,认证失败", 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env.prod中的配置是否正确哦~",
402: "账号余额不足", 402: "账号余额不足",
403: "需要实名,或余额不足", 403: "需要实名,或余额不足",
404: "Not Found", 404: "Not Found",

View File

@@ -23,7 +23,7 @@ CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
#定义你要用的api的base_url #定义你要用的api的key(需要去对应网站申请哦)
DEEP_SEEK_KEY= DEEP_SEEK_KEY=
CHAT_ANY_WHERE_KEY= CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY= SILICONFLOW_KEY=