Files
Mofox-Core/emoji_reviewer.py

383 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import re
import warnings
import gradio as gr
import os
import signal
import sys
import requests
import tomli
from dotenv import load_dotenv
from src.common.database import db
try:
    # Preferred path: the project's own logger factory.
    from src.common.logger import get_module_logger

    logger = get_module_logger("emoji_reviewer")
except ImportError:
    # Fallback path: configure a plain loguru logger by hand.
    from loguru import logger

    # Make sure the log directory exists before the file sink is attached.
    log_dir = "logs/emoji_reviewer"
    if not os.path.exists(log_dir):
        os.makedirs(log_dir, exist_ok=True)

    # Shared line format for both the console sink and the file sink.
    _fallback_log_format = "{time:MM-DD HH:mm} | emoji_reviewer | {message}"

    # Drop loguru's default handler, then attach the console and daily file sinks.
    logger.remove()
    logger.add(sys.stderr, format=_fallback_log_format)
    logger.add(
        "logs/emoji_reviewer/{time:YYYY-MM-DD}.log",
        rotation="00:00",
        format=_fallback_log_format,
    )

    logger.warning("检测到src.common.logger并未导入将使用默认loguru作为日志记录器")
    logger.warning("如果你是用的是低版本(0.5.13)麦麦,请忽略此警告")
# Silence gradio's version notice.
warnings.filterwarnings("ignore", message="IMPORTANT: You are using gradio version.*")

# Both the bot configuration and the .env file are resolved relative to this script.
root_dir = os.path.dirname(os.path.abspath(__file__))
bot_config_path = os.path.join(root_dir, "config/bot_config.toml")

# Read the embedding model name/provider from bot_config.toml; any problem is
# fatal. sys.exit() is used rather than the site-provided exit() helper, which
# is meant for the interactive shell and is not guaranteed to exist.
if os.path.exists(bot_config_path):
    with open(bot_config_path, "rb") as f:
        try:
            toml_dict = tomli.load(f)
            embedding_config = toml_dict['model']['embedding']
            embedding_name = embedding_config["name"]
            embedding_provider = embedding_config["provider"]
        except tomli.TOMLDecodeError as e:
            # Not every tomli release exposes lineno/colno/msg on the
            # exception, so fall back gracefully instead of crashing while
            # reporting the original problem.
            err_line = getattr(e, "lineno", "?")
            err_col = getattr(e, "colno", "?")
            err_msg = getattr(e, "msg", str(e))
            logger.critical(f"配置文件bot_config.toml填写有误请检查第{err_line}行第{err_col}处:{err_msg}")
            sys.exit(1)
        except KeyError:
            logger.critical("配置文件bot_config.toml缺少model.embedding设置请补充后再编辑表情包")
            sys.exit(1)
else:
    logger.critical(f"没有找到配置文件{bot_config_path}")
    sys.exit(1)

# The .env file is mandatory as well; it is loaded into the process environment.
env_path = os.path.join(root_dir, ".env")
if not os.path.exists(env_path):
    logger.critical(f"没有找到环境变量文件{env_path}")
    sys.exit(1)
load_dotenv(env_path)
# Values offered by every tag-filter dropdown: no filtering, keep only tagged
# documents ("包括"), or keep only untagged documents ("排除").
tags_choices = ["", "包括", "排除"]

# Tag key -> (label shown in the UI, default value of that tag's filter dropdown).
tags = dict(
    reviewed=("已审查", "排除"),
    blacklist=("黑名单", "排除"),
)

# Not referenced anywhere else in the visible part of this file.
format_choices = ["包括", ""]

# File extensions with their own filter checkbox, plus a catch-all "其它" entry.
_known_extensions = ("jpg", "jpeg", "png", "gif")
formats = [*_known_extensions, "其它"]
def signal_handler(signum, frame):
    """Log the shutdown request and terminate the process with status 0.

    Has the ``(signum, frame)`` signature required of a ``signal.signal``
    handler; neither argument is used.
    """
    logger.info("收到终止信号,正在关闭 Gradio 服务器...")
    raise SystemExit(0)
# Install the Ctrl+C (SIGINT) handler.
signal.signal(signal.SIGINT, signal_handler)

# Document fields fetched from the emoji collection. If a misspelled field name
# is ever corrected elsewhere, remember to update it here as well.
required_fields = ["_id", "path", "description", "hash"] + list(tags)

# In-memory snapshot of the emoji collection, restricted to the fields above.
emojis_db = list(db.emoji.find({}, dict.fromkeys(required_fields, 1)))

# Mutable UI state shared by the callbacks below.
emoji_filtered = []  # documents currently shown in the gallery
emoji_show = None  # document currently selected in the gallery, if any
max_num = 20  # maximum number of gallery entries
neglect_update = 0  # pending checkbox change events that change_tag should skip
async def get_embedding(text):
    """Request an embedding for ``text`` from the configured provider.

    The endpoint and API key are read from the ``<provider>_BASE_URL`` and
    ``<provider>_KEY`` environment variables, where ``<provider>`` is the
    module-level ``embedding_provider``; the model is ``embedding_name``.

    Returns:
        * the value of ``data[0].embedding`` from the JSON response on HTTP 200;
        * a ``str`` describing the problem when the base URL is not configured
          or the server answers with another status code;
        * ``None`` when sending the request or parsing the response raises.
    """
    base_url = os.environ.get(f"{embedding_provider}_BASE_URL")
    if not base_url:
        # Report the real cause instead of failing later on ``None.endswith``.
        logger.error(f"Environment variable {embedding_provider}_BASE_URL is not set")
        return f"未配置环境变量{embedding_provider}_BASE_URL"

    url = (base_url if base_url.endswith('/') else base_url + '/') + 'embeddings'
    key = os.environ.get(f"{embedding_provider}_KEY")
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": embedding_name,
        "input": text,
        "encoding_format": "float",
    }

    try:
        # NOTE(review): requests.post is a blocking call inside a coroutine;
        # the timeout at least bounds how long it can stall the caller.
        response = requests.post(url, headers=headers, data=json.dumps(payload), timeout=60)
        if response.status_code != 200:
            return f"网络错误{response.status_code}"
        return response.json()["data"][0]["embedding"]
    except Exception as e:
        # Best-effort contract: callers treat None as failure, but the cause is
        # now logged rather than silently dropped.
        logger.error(f"Failed to fetch embedding: {e!r}")
        return None
def set_max_num(slider):
    """Store the slider value as the maximum number of gallery entries.

    ``max_num`` is later used as a slice bound, which must be an integer, so
    the incoming numeric value is coerced to ``int`` and kept at 1 or above.
    """
    global max_num
    max_num = max(1, int(slider))
def filter_emojis(tag_filters, format_filters):
    """Filter ``emojis_db`` by tag presence and file extension.

    The result is published in the module-level ``emoji_filtered``.

    Args:
        tag_filters: mapping of tag key to ``"包括"`` (keep only documents that
            have the tag), ``"排除"`` (keep only documents without it) or any
            other value (tag ignored).
        format_filters: mapping of format name (an entry of ``formats``) to a
            truthy/falsy flag telling whether that format is selected.

    Returns:
        ``[]`` when no format is selected (``emoji_filtered`` is then emptied
        too), otherwise ``None``.
    """
    global emoji_filtered

    selected_formats = [name for name, enabled in format_filters.items() if enabled]
    if not selected_formats:
        # Publish the empty result as well, so readers of the global do not
        # keep seeing the previous selection.
        emoji_filtered = []
        return []

    # Work on a copy: the published list must never be emojis_db itself,
    # otherwise in-place operations on the result would corrupt the snapshot.
    candidates = list(emojis_db)
    for tag, mode in tag_filters.items():
        if mode == "包括":
            candidates = [doc for doc in candidates if tag in doc]
        elif mode == "排除":
            candidates = [doc for doc in candidates if tag not in doc]

    if '其它' in selected_formats:
        # "Other" selected: keep everything except the explicitly unselected
        # extensions.
        rejected = [name for name in formats if name not in selected_formats]
        if rejected:
            alternation = '|'.join(re.escape(name) for name in rejected)
            extension_re = re.compile(rf"\.({alternation})$", re.IGNORECASE)
            # Pattern.search's second positional argument is a start position,
            # not a flag; case-insensitivity already comes from compile().
            candidates = [doc for doc in candidates if not extension_re.search(doc.get("path") or "")]
    else:
        # "Other" not selected: keep only the selected extensions.
        alternation = '|'.join(re.escape(name) for name in selected_formats)
        extension_re = re.compile(rf"\.({alternation})$", re.IGNORECASE)
        candidates = [doc for doc in candidates if extension_re.search(doc.get("path") or "")]

    emoji_filtered = candidates
def update_gallery(from_latest, *filter_values):
    """Recompute the filtered emoji list and clear the gallery.

    Args:
        from_latest: when truthy, the filtered list is shown in reverse order.
        *filter_values: the tag-filter values (one per entry of ``tags``)
            followed by the format flags (one per entry of ``formats``).

    Returns:
        A two-item list: a gallery update that empties the gallery and disables
        preview, and the info text describing how many emojis matched.

    Side effects:
        Rebinds ``emoji_filtered`` (truncated to ``max_num`` entries) and
        resets ``emoji_show`` to ``None``.
    """
    global emoji_filtered, emoji_show

    tag_count = len(tags)
    tag_part = filter_values[:tag_count]
    format_part = filter_values[tag_count:]
    result = filter_emojis(dict(zip(tags.keys(), tag_part)), dict(zip(formats, format_part)))
    if result is not None:
        # filter_emojis hands back a list only for the "no format selected"
        # case; use it as the result so a previous selection is not shown.
        emoji_filtered = result

    if from_latest:
        # Build a reversed copy rather than reversing in place: the filtered
        # list may be the very same object as emojis_db.
        emoji_filtered = emoji_filtered[::-1]

    matched = len(emoji_filtered)
    if matched > max_num:
        info = f"已筛选{matched}个表情包中的{max_num}个。"
        emoji_filtered = emoji_filtered[:max_num]
    else:
        info = f"已筛选{matched}个表情包。"

    emoji_show = None
    return [gr.update(value=[], selected_index=None, allow_preview=False), info]
def update_gallery2():
    """Fill the gallery with the paths of the currently filtered emojis.

    Returns a gallery update carrying one path per entry of ``emoji_filtered``
    (an empty string where a document has no ``path``) with preview enabled.
    """
    image_paths = []
    for emoji in emoji_filtered:
        image_paths.append(emoji.get("path", ""))
    return gr.update(value=image_paths, allow_preview=True)
def on_select(evt: gr.SelectData, *tag_values):
    """Handle a gallery selection change.

    Args:
        evt: selection event; ``evt.index`` is the selected position in
            ``emoji_filtered`` or ``None``.
        *tag_values: current values of the tag checkboxes, in ``tags`` order.

    Returns:
        A list with a gallery update, the description text for the textbox and
        one entry per tag checkbox (a new boolean, or a no-op update when the
        checkbox already shows the right value).

    Side effects:
        Rebinds ``emoji_show`` and counts in ``neglect_update`` every checkbox
        this handler changes, so ``change_tag`` can skip the resulting change
        events.
    """
    global emoji_show, neglect_update

    new_index = evt.index
    # Routed through the logger instead of a bare print to stdout.
    logger.debug(f"Gallery selection index: {new_index}")

    if new_index is None:
        # Nothing selected: clear the description and untick every checkbox.
        emoji_show = None
        targets = []
        for current_value in tag_values:
            if current_value:
                neglect_update += 1
                targets.append(False)
            else:
                targets.append(gr.update())
        return [
            gr.update(selected_index=new_index),
            "",
            *targets
        ]

    emoji_show = emoji_filtered[new_index]
    targets = []
    neglect_update = 0
    for current_value, tag in zip(tag_values, tags.keys()):
        # A tag counts as set when its key is present in the document.
        target = tag in emoji_show
        if current_value != target:
            neglect_update += 1
            targets.append(target)
        else:
            targets.append(gr.update())
    return [
        gr.update(selected_index=new_index),
        emoji_show.get("description", ""),
        *targets
    ]
def desc_change(desc, edited):
    """Track whether the description textbox differs from the stored text.

    Args:
        desc: current textbox content.
        edited: previous value of the "edited" state flag.

    Returns:
        ``[label, edited]``: the label is ``"(尚未保存)"`` when the text has
        just become different from the selected emoji's description, ``""``
        when it has just stopped being different, and a no-op update when the
        flag did not change.
    """
    differs = bool(emoji_show) and desc != emoji_show.get("description", "")
    if differs:
        label = gr.update() if edited else "(尚未保存)"
    else:
        label = "" if edited else gr.update()
    return [label, differs]
def revert_desc():
    """Return the selected emoji's stored description, or "" if none is selected."""
    return emoji_show.get("description", "") if emoji_show else ""
async def save_desc(desc):
    """Persist a new description (and its embedding) for the selected emoji.

    Async generator used as a streaming callback. Every yielded list has three
    items, matching the three outputs this callback is wired to: the status
    label, the description textbox and the save button.

    Steps: disable the textbox and button, fetch an embedding for ``desc``,
    then either report the failure or write the description/embedding to
    ``db.emoji`` and the description to ``db.images`` and
    ``db.image_descriptions`` (both matched by hash), and update the cached
    document.
    """
    if not emoji_show:
        # One value per wired output; the textbox and button are left untouched.
        yield ["没有选中表情包", gr.update(), gr.update()]
        return

    try:
        yield ["正在构建embedding请勿关闭页面...", gr.update(interactive=False), gr.update(interactive=False)]
        embedding = await get_embedding(desc)
        if embedding is None or isinstance(embedding, str):
            # get_embedding signals failure with None or an error string.
            yield [
                f"<span style='color: red;'>获取embeddings失败{embedding}</span>",
                gr.update(interactive=True),
                gr.update(interactive=True)
            ]
        else:
            e_id = emoji_show["_id"]
            update_dict = {"$set": {"embedding": embedding, "description": desc}}
            db.emoji.update_one({"_id": e_id}, update_dict)

            e_hash = emoji_show["hash"]
            update_dict = {"$set": {"description": desc}}
            db.images.update_one({"hash": e_hash}, update_dict)
            db.image_descriptions.update_one({"hash": e_hash}, update_dict)

            emoji_show["description"] = desc
            # Log the emoji's hash, not the builtin ``hash`` function.
            logger.info(f'Update description and embeddings: {e_id}(hash={e_hash})')
            yield ["保存完成", gr.update(value=desc, interactive=True), gr.update(interactive=True)]
    except Exception as e:
        # Always re-enable the controls, whatever went wrong.
        yield [
            f"<span style='color: red;'>出现异常: {e}</span>",
            gr.update(interactive=True),
            gr.update(interactive=True)
        ]
def change_tag(*tag_values):
    """Synchronise the selected emoji's tags with the checkbox values.

    Args:
        *tag_values: checkbox values in ``tags`` order; a truthy value means
            the tag should be present on the document.

    Returns:
        A no-op update when no emoji is selected or when this change event was
        caused by ``on_select`` (tracked through ``neglect_update``); otherwise
        the status text ``"已更新标签状态"``.
    """
    global neglect_update

    if not emoji_show:
        return gr.update()
    if neglect_update > 0:
        # Change event triggered programmatically by on_select: consume one
        # pending skip and do nothing.
        neglect_update -= 1
        return gr.update()

    e_id = emoji_show["_id"]
    set_dict = {}
    unset_dict = {}
    for value, tag in zip(tag_values, tags.keys()):
        present = tag in emoji_show
        if value and not present:
            set_dict[tag] = True
        elif not value and present:
            unset_dict[tag] = ""

    # Only send operators that actually carry fields: an empty $set/$unset
    # operand is rejected by MongoDB servers older than 5.0.
    update_dict = {}
    if set_dict:
        update_dict["$set"] = set_dict
    if unset_dict:
        update_dict["$unset"] = unset_dict

    if update_dict:
        db.emoji.update_one({"_id": e_id}, update_dict)
        # Mirror the change into the cached document only after the database
        # write went through, so a failed write cannot desynchronise the two.
        for tag in set_dict:
            emoji_show[tag] = True
            logger.info(f'Add tag "{tag}" to {e_id}')
        for tag in unset_dict:
            del emoji_show[tag]
            logger.info(f'Delete tag "{tag}" from {e_id}')
    return "已更新标签状态"
with gr.Blocks(title="MaimBot表情包审查器") as app:
    # Tracks whether the description textbox currently holds unsaved edits.
    desc_edit = gr.State(value=False)

    # ---- Header and usage notes ----
    gr.Markdown(value="\n# MaimBot表情包审查器\n")
    gr.Markdown(value="---")  # horizontal rule
    gr.Markdown(
        value=(
            "\n"
            "## 审查器说明\n\n"
            "该审查器用于人工修正识图模型对表情包的识别偏差,以及管理表情包黑名单:\n\n"
            "每一个表情包都有描述以及“已审查”和“黑名单”两个标签。描述可以编辑并保存。“黑名单”标签可以禁止麦麦使用该表情包。\n\n"
            "作者:遗世紫丁香(HexatomicRing)\n"
        )
    )
    gr.Markdown(value="---")

    with gr.Row():
        # ---- Left column: gallery, description editor and tag checkboxes ----
        with gr.Column(scale=2):
            info_label = gr.Markdown("")
            gallery = gr.Gallery(label="表情包列表", columns=4, rows=6)
            description = gr.Textbox(label="描述", interactive=True)
            description_label = gr.Markdown("")
            tag_boxes = {
                tag: gr.Checkbox(label=name, interactive=True)
                for tag, (name, _) in tags.items()
            }
            with gr.Row():
                revert_btn = gr.Button("还原描述")
                save_btn = gr.Button("保存描述")

        # ---- Right column: display limit, ordering and filters ----
        with gr.Column(scale=1):
            max_num_slider = gr.Slider(label="最大显示数量", minimum=1, maximum=500, value=max_num, interactive=True)
            check_from_latest = gr.Checkbox(label="由新到旧", interactive=True)
            tag_filters = {
                tag: gr.Dropdown(tags_choices, value=value, label=f"{name}筛选")
                for tag, (name, value) in tags.items()
            }
            gr.Markdown(value="---")
            gr.Markdown(value="格式筛选:")
            format_filters = {
                f: gr.Checkbox(label=f, value=True)
                for f in formats
            }
            refresh_btn = gr.Button("刷新筛选")

    # Filter components in the order update_gallery expects: tag filters first,
    # then format checkboxes.
    filters = [*tag_filters.values(), *format_filters.values()]
    tag_box_list = list(tag_boxes.values())

    def _bind_gallery_refresh(register_event):
        """Wire one event to the two-step refresh: clear the gallery, then refill it."""
        register_event(
            fn=update_gallery,
            inputs=[check_from_latest, *filters],
            outputs=[gallery, info_label],
            preprocess=False,
        ).then(
            fn=update_gallery2,
            inputs=None,
            outputs=gallery,
        )

    # ---- Event wiring ----
    max_num_slider.change(set_max_num, max_num_slider, None)
    description.change(desc_change, [description, desc_edit], [description_label, desc_edit])

    # Any filter change, as well as the refresh button, rebuilds the gallery.
    for component in filters:
        _bind_gallery_refresh(component.change)
    _bind_gallery_refresh(refresh_btn.click)

    gallery.select(fn=on_select, inputs=tag_box_list, outputs=[gallery, description, *tag_box_list])
    revert_btn.click(fn=revert_desc, inputs=None, outputs=description)
    save_btn.click(fn=save_desc, inputs=description, outputs=[description_label, description, save_btn])

    for box in tag_box_list:
        box.change(fn=change_tag, inputs=list(tag_boxes.values()), outputs=description_label)

    # Populate the gallery once when the page is first loaded.
    _bind_gallery_refresh(app.load)
# Enable queuing and start the server.
# NOTE(review): server_name "0.0.0.0" binds to all network interfaces — confirm
# that exposing this tool beyond localhost is intended.
_launch_options = {
    "server_name": "0.0.0.0",
    "inbrowser": True,
    "share": False,
    "server_port": 7001,
    "debug": True,
    "quiet": True,
}
app.queue().launch(**_launch_options)