import json import re import warnings import gradio as gr import os import signal import sys import requests import tomli from dotenv import load_dotenv from src.common.database import db try: from src.common.logger import get_module_logger logger = get_module_logger("emoji_reviewer") except ImportError: from loguru import logger # 检查并创建日志目录 log_dir = "logs/emoji_reviewer" if not os.path.exists(log_dir): os.makedirs(log_dir, exist_ok=True) # 配置控制台输出格式 logger.remove() # 移除默认的处理器 logger.add(sys.stderr, format="{time:MM-DD HH:mm} | emoji_reviewer | {message}") # 添加控制台输出 logger.add( "logs/emoji_reviewer/{time:YYYY-MM-DD}.log", rotation="00:00", format="{time:MM-DD HH:mm} | emoji_reviewer | {message}" ) logger.warning("检测到src.common.logger并未导入,将使用默认loguru作为日志记录器") logger.warning("如果你是用的是低版本(0.5.13)麦麦,请忽略此警告") # 忽略 gradio 版本警告 warnings.filterwarnings("ignore", message="IMPORTANT: You are using gradio version.*") root_dir = os.path.dirname(os.path.abspath(__file__)) bot_config_path = os.path.join(root_dir, "config/bot_config.toml") if os.path.exists(bot_config_path): with open(bot_config_path, "rb") as f: try: toml_dict = tomli.load(f) embedding_config = toml_dict['model']['embedding'] embedding_name = embedding_config["name"] embedding_provider = embedding_config["provider"] except tomli.TOMLDecodeError as e: logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}") exit(1) except KeyError: logger.critical("配置文件bot_config.toml缺少model.embedding设置,请补充后再编辑表情包") exit(1) else: logger.critical(f"没有找到配置文件{bot_config_path}") exit(1) env_path = os.path.join(root_dir, ".env") if not os.path.exists(env_path): logger.critical(f"没有找到环境变量文件{env_path}") exit(1) load_dotenv(env_path) tags_choices = ["无", "包括", "排除"] tags = { "reviewed": ("已审查", "排除"), "blacklist": ("黑名单", "排除"), } format_choices = ["包括", "无"] formats = ["jpg", "jpeg", "png", "gif", "其它"] def signal_handler(signum, frame): """处理 Ctrl+C 信号""" logger.info("收到终止信号,正在关闭 Gradio 服务器...") sys.exit(0) # 注册信号处理器 signal.signal(signal.SIGINT, signal_handler) required_fields = ["_id", "path", "description", "hash", *tags.keys()] # 修复拼写错误的时候记得把这里的一起改了 emojis_db = list(db.emoji.find({}, {k: 1 for k in required_fields})) emoji_filtered = [] emoji_show = None max_num = 20 neglect_update = 0 async def get_embedding(text): try: base_url = os.environ.get(f"{embedding_provider}_BASE_URL") if base_url.endswith('/'): url = base_url + 'embeddings' else: url = base_url + '/embeddings' key = os.environ.get(f"{embedding_provider}_KEY") headers = { "Authorization": f"Bearer {key}", "Content-Type": "application/json" } payload = { "model": embedding_name, "input": text, "encoding_format": "float" } response = requests.post(url, headers=headers, data=json.dumps(payload)) if response.status_code == 200: result = response.json() embedding = result["data"][0]["embedding"] return embedding else: return f"网络错误{response.status_code}" except Exception: return None def set_max_num(slider): global max_num max_num = slider def filter_emojis(tag_filters, format_filters): global emoji_filtered e_filtered = emojis_db format_include = [] for format, value in format_filters.items(): if value: format_include.append(format) if len(format_include) == 0: return [] for tag, value in tag_filters.items(): if value == "包括": e_filtered = [d for d in e_filtered if tag in d] elif value == "排除": e_filtered = [d for d in e_filtered if tag not in d] if '其它' in format_include: exclude = [f for f in formats if f not in format_include] if exclude: ff = '|'.join(exclude) compiled_pattern = re.compile(rf"\.({ff})$", re.IGNORECASE) e_filtered = [d for d in e_filtered if not compiled_pattern.search(d.get("path", ""), re.IGNORECASE)] else: ff = '|'.join(format_include) compiled_pattern = re.compile(rf"\.({ff})$", re.IGNORECASE) e_filtered = [d for d in e_filtered if compiled_pattern.search(d.get("path", ""), re.IGNORECASE)] emoji_filtered = e_filtered def update_gallery(from_latest, *filter_values): global emoji_filtered tf = filter_values[:len(tags)] ff = filter_values[len(tags):] filter_emojis({k: v for k, v in zip(tags.keys(), tf)}, {k: v for k, v in zip(formats, ff)}) if from_latest: emoji_filtered.reverse() if len(emoji_filtered) > max_num: info = f"已筛选{len(emoji_filtered)}个表情包中的{max_num}个。" emoji_filtered = emoji_filtered[:max_num] else: info = f"已筛选{len(emoji_filtered)}个表情包。" global emoji_show emoji_show = None return [gr.update(value=[], selected_index=None, allow_preview=False), info] def update_gallery2(): thumbnails = [e.get("path", "") for e in emoji_filtered] return gr.update(value=thumbnails, allow_preview=True) def on_select(evt: gr.SelectData, *tag_values): new_index = evt.index print(new_index) global emoji_show, neglect_update if new_index is None: emoji_show = None targets = [] for current_value in tag_values: if current_value: neglect_update += 1 targets.append(False) else: targets.append(gr.update()) return [ gr.update(selected_index=new_index), "", *targets ] else: emoji_show = emoji_filtered[new_index] targets = [] neglect_update = 0 for current_value, tag in zip(tag_values, tags.keys()): target = tag in emoji_show if current_value != target: neglect_update += 1 targets.append(target) else: targets.append(gr.update()) return [ gr.update(selected_index=new_index), emoji_show.get("description", ""), *targets ] def desc_change(desc, edited): if emoji_show and desc != emoji_show.get("description", ""): if edited: return [gr.update(), True] else: return ["(尚未保存)", True] if edited: return ["", False] else: return [gr.update(), False] def revert_desc(): if emoji_show: return emoji_show.get("description", "") else: return "" async def save_desc(desc): if emoji_show: try: yield ["正在构建embedding,请勿关闭页面...", gr.update(interactive=False), gr.update(interactive=False)] embedding = await get_embedding(desc) if embedding is None or isinstance(embedding, str): yield [ f"获取embeddings失败!{embedding}", gr.update(interactive=True), gr.update(interactive=True) ] else: e_id = emoji_show["_id"] update_dict = {"$set": {"embedding": embedding, "description": desc}} db.emoji.update_one({"_id": e_id}, update_dict) e_hash = emoji_show["hash"] update_dict = {"$set": {"description": desc}} db.images.update_one({"hash": e_hash}, update_dict) db.image_descriptions.update_one({"hash": e_hash}, update_dict) emoji_show["description"] = desc logger.info(f'Update description and embeddings: {e_id}(hash={hash})') yield ["保存完成", gr.update(value=desc, interactive=True), gr.update(interactive=True)] except Exception as e: yield [ f"出现异常: {e}", gr.update(interactive=True), gr.update(interactive=True) ] else: yield ["没有选中表情包", gr.update()] def change_tag(*tag_values): if not emoji_show: return gr.update() global neglect_update if neglect_update > 0: neglect_update -= 1 return gr.update() set_dict = {} unset_dict = {} e_id = emoji_show["_id"] for value, tag in zip(tag_values, tags.keys()): if value: if tag not in emoji_show: set_dict[tag] = True emoji_show[tag] = True logger.info(f'Add tag "{tag}" to {e_id}') else: if tag in emoji_show: unset_dict[tag] = "" del emoji_show[tag] logger.info(f'Delete tag "{tag}" from {e_id}') update_dict = {"$set": set_dict, "$unset": unset_dict} db.emoji.update_one({"_id": e_id}, update_dict) return "已更新标签状态" with gr.Blocks(title="MaimBot表情包审查器") as app: desc_edit = gr.State(value=False) gr.Markdown( value=""" # MaimBot表情包审查器 """ ) gr.Markdown(value="---") # 添加分割线 gr.Markdown(value=""" ## 审查器说明\n 该审查器用于人工修正识图模型对表情包的识别偏差,以及管理表情包黑名单:\n 每一个表情包都有描述以及“已审查”和“黑名单”两个标签。描述可以编辑并保存。“黑名单”标签可以禁止麦麦使用该表情包。\n 作者:遗世紫丁香(HexatomicRing) """) gr.Markdown(value="---") with gr.Row(): with gr.Column(scale=2): info_label = gr.Markdown("") gallery = gr.Gallery(label="表情包列表", columns=4, rows=6) description = gr.Textbox(label="描述", interactive=True) description_label = gr.Markdown("") tag_boxes = { tag: gr.Checkbox(label=name, interactive=True) for tag, (name, _) in tags.items() } with gr.Row(): revert_btn = gr.Button("还原描述") save_btn = gr.Button("保存描述") with gr.Column(scale=1): max_num_slider = gr.Slider(label="最大显示数量", minimum=1, maximum=500, value=max_num, interactive=True) check_from_latest = gr.Checkbox(label="由新到旧", interactive=True) tag_filters = { tag: gr.Dropdown(tags_choices, value=value, label=f"{name}筛选") for tag, (name, value) in tags.items() } gr.Markdown(value="---") gr.Markdown(value="格式筛选:") format_filters = { f: gr.Checkbox(label=f, value=True) for f in formats } refresh_btn = gr.Button("刷新筛选") filters = list(tag_filters.values()) + list(format_filters.values()) max_num_slider.change(set_max_num, max_num_slider, None) description.change(desc_change, [description, desc_edit], [description_label, desc_edit]) for component in filters: component.change( fn=update_gallery, inputs=[check_from_latest, *filters], outputs=[gallery, info_label], preprocess=False ).then( fn=update_gallery2, inputs=None, outputs=gallery) refresh_btn.click( fn=update_gallery, inputs=[check_from_latest, *filters], outputs=[gallery, info_label], preprocess=False ).then( fn=update_gallery2, inputs=None, outputs=gallery) gallery.select(fn=on_select, inputs=list(tag_boxes.values()), outputs=[gallery, description, *tag_boxes.values()]) revert_btn.click(fn=revert_desc, inputs=None, outputs=description) save_btn.click(fn=save_desc, inputs=description, outputs=[description_label, description, save_btn]) for box in tag_boxes.values(): box.change(fn=change_tag, inputs=list(tag_boxes.values()), outputs=description_label) app.load( fn=update_gallery, inputs=[check_from_latest, *filters], outputs=[gallery, info_label], preprocess=False ).then( fn=update_gallery2, inputs=None, outputs=gallery) app.queue().launch( server_name="0.0.0.0", inbrowser=True, share=False, server_port=7001, debug=True, quiet=True, )