diff --git a/emoji_reviewer.py b/emoji_reviewer.py new file mode 100644 index 000000000..96da18ea8 --- /dev/null +++ b/emoji_reviewer.py @@ -0,0 +1,364 @@ +import json +import re +import warnings +import gradio as gr +import os +import signal +import sys +import requests +import tomli + +from dotenv import load_dotenv +from src.common.database import db + +try: + from src.common.logger import get_module_logger + + logger = get_module_logger("emoji_reviewer") +except ImportError: + from loguru import logger + + # 检查并创建日志目录 + log_dir = "logs/emoji_reviewer" + if not os.path.exists(log_dir): + os.makedirs(log_dir, exist_ok=True) + # 配置控制台输出格式 + logger.remove() # 移除默认的处理器 + logger.add(sys.stderr, format="{time:MM-DD HH:mm} | emoji_reviewer | {message}") # 添加控制台输出 + logger.add("logs/emoji_reviewer/{time:YYYY-MM-DD}.log", rotation="00:00", format="{time:MM-DD HH:mm} | emoji_reviewer | {message}") + logger.warning("检测到src.common.logger并未导入,将使用默认loguru作为日志记录器") + logger.warning("如果你是用的是低版本(0.5.13)麦麦,请忽略此警告") +# 忽略 gradio 版本警告 +warnings.filterwarnings("ignore", message="IMPORTANT: You are using gradio version.*") + +root_dir = os.path.dirname(os.path.abspath(__file__)) +bot_config_path = os.path.join(root_dir, "config/bot_config.toml") +if os.path.exists(bot_config_path): + with open(bot_config_path, "rb") as f: + try: + toml_dict = tomli.load(f) + embedding_config = toml_dict['model']['embedding'] + embedding_name = embedding_config["name"] + embedding_provider = embedding_config["provider"] + except tomli.TOMLDecodeError as e: + logger.critical(f"配置文件bot_config.toml填写有误,请检查第{e.lineno}行第{e.colno}处:{e.msg}") + exit(1) + except KeyError as e: + logger.critical(f"配置文件bot_config.toml缺少model.embedding设置,请补充后再编辑表情包") + exit(1) +else: + logger.critical(f"没有找到配置文件{bot_config_path}") + exit(1) +env_path = os.path.join(root_dir, ".env.prod") +if not os.path.exists(env_path): + logger.critical(f"没有找到环境变量文件{env_path}") + exit(1) +load_dotenv(env_path) + +tags_choices = ["无", "包括", "排除"] +tags = { + "reviewed": ("已审查", "排除"), + "blacklist": ("黑名单", "排除"), +} +format_choices = ["包括", "无"] +formats = ["jpg", "png", "gif"] + + +def signal_handler(signum, frame): + """处理 Ctrl+C 信号""" + logger.info("收到终止信号,正在关闭 Gradio 服务器...") + sys.exit(0) + + +# 注册信号处理器 +signal.signal(signal.SIGINT, signal_handler) +required_fields = ["_id", "path", "description", "hash", *tags.keys()] # 修复拼写错误的时候记得把这里的一起改了 + +emojis_db = list(db.emoji.find({}, {k: 1 for k in required_fields})) +emoji_filtered = [] +emoji_show = None + +max_num = 20 +neglect_update = 0 + + +async def get_embedding(text): + try: + base_url = os.environ.get(f"{embedding_provider}_BASE_URL") + if base_url.endswith('/'): + url = base_url + 'embeddings' + else: + url = base_url + '/embeddings' + key = os.environ.get(f"{embedding_provider}_KEY") + headers = { + "Authorization": f"Bearer {key}", + "Content-Type": "application/json" + } + payload = { + "model": embedding_name, + "input": text, + "encoding_format": "float" + } + response = requests.post(url, headers=headers, data=json.dumps(payload)) + if response.status_code == 200: + result = response.json() + embedding = result["data"][0]["embedding"] + return embedding + else: + return f"网络错误{response.status_code}" + except: + return None + + +def set_max_num(slider): + global max_num + max_num = slider + + +def filter_emojis(tag_filters, format_filters): + global emoji_filtered + e_filtered = emojis_db + + format_include = [] + for format, value in format_filters.items(): + if value: + format_include.append(format) + + if len(format_include) == 0: + return [] + + for tag, value in tag_filters.items(): + if value == "包括": + e_filtered = [d for d in e_filtered if tag in d] + elif value == "排除": + e_filtered = [d for d in e_filtered if tag not in d] + + if len(format_include) > 0: + ff = '|'.join(format_include) + pattern = rf"\.({ff})$" + e_filtered = [d for d in e_filtered if re.search(pattern, d.get("path", ""), re.IGNORECASE)] + + emoji_filtered = e_filtered + + +def update_gallery(from_latest, *filter_values): + global emoji_filtered + tf = filter_values[:len(tags)] + ff = filter_values[len(tags):] + filter_emojis({k: v for k, v in zip(tags.keys(), tf)}, {k: v for k, v in zip(formats, ff)}) + if from_latest: + emoji_filtered.reverse() + if len(emoji_filtered) > max_num: + info = f"已筛选{len(emoji_filtered)}个表情包中的{max_num}个。" + emoji_filtered = emoji_filtered[:max_num] + else: + info = f"已筛选{len(emoji_filtered)}个表情包。" + global emoji_show + emoji_show = None + return [gr.update(value=[], selected_index=None, allow_preview=False), info] + + +def update_gallery2(): + thumbnails = [e.get("path", "") for e in emoji_filtered] + return gr.update(value=thumbnails, allow_preview=True) + + +def on_select(evt: gr.SelectData, *tag_values): + new_index = evt.index + print(new_index) + global emoji_show, neglect_update + if new_index is None: + emoji_show = None + targets = [] + for current_value, tag in zip(tag_values, tags.keys()): + if current_value: + neglect_update += 1 + targets.append(False) + else: + targets.append(gr.update()) + return [ + gr.update(selected_index=new_index), + "", + *targets + ] + else: + emoji_show = emoji_filtered[new_index] + targets = [] + neglect_update = 0 + for current_value, tag in zip(tag_values, tags.keys()): + target = tag in emoji_show + if current_value != target: + neglect_update += 1 + targets.append(target) + else: + targets.append(gr.update()) + return [ + gr.update(selected_index=new_index), + emoji_show.get("description", ""), + *targets + ] + + +def desc_change(desc, edited): + if emoji_show and desc != emoji_show.get("description", ""): + if edited: + return [gr.update(), True] + else: + return ["(尚未保存)", True] + if edited: + return ["", False] + else: + return [gr.update(), False] + + +def revert_desc(): + if emoji_show: + return emoji_show.get("description", "") + else: + return "" + + +async def save_desc(desc): + if emoji_show: + try: + yield ["正在构建embedding,请勿关闭页面...", gr.update(interactive=False), gr.update(interactive=False)] + embedding = await get_embedding(desc) + if embedding is None or isinstance(embedding, str): + yield [f"获取embeddings失败!{embedding}", gr.update(interactive=True), gr.update(interactive=True)] + else: + e_id = emoji_show["_id"] + update_dict = {"$set": {"embedding": embedding, "description": desc}} + db.emoji.update_one({"_id": e_id}, update_dict) + + e_hash = emoji_show["hash"] + update_dict = {"$set": {"description": desc}} + db.images.update_one({"hash": e_hash}, update_dict) + db.image_descriptions.update_one({"hash": e_hash}, update_dict) + emoji_show["description"] = desc + + logger.info(f'Update description and embeddings: {e_id}(hash={hash})') + yield ["保存完成", gr.update(value=desc, interactive=True), gr.update(interactive=True)] + except Exception as e: + yield [f"出现异常: {e}", gr.update(interactive=True), gr.update(interactive=True)] + + else: + yield ["没有选中表情包", gr.update()] + + +def change_tag(*tag_values): + if not emoji_show: + return gr.update() + global neglect_update + if neglect_update > 0: + neglect_update -= 1 + return gr.update() + set_dict = {} + unset_dict = {} + e_id = emoji_show["_id"] + for value, tag in zip(tag_values, tags.keys()): + if value: + if tag not in emoji_show: + set_dict[tag] = True + emoji_show[tag] = True + logger.info(f'Add tag "{tag}" to {e_id}') + else: + if tag in emoji_show: + unset_dict[tag] = "" + del emoji_show[tag] + logger.info(f'Delete tag "{tag}" from {e_id}') + + update_dict = {"$set": set_dict, "$unset": unset_dict} + db.emoji.update_one({"_id": e_id}, update_dict) + return "已更新标签状态" + + +with gr.Blocks(title="MaimBot表情包审查器") as app: + desc_edit = gr.State(value=False) + gr.Markdown( + value=""" + # MaimBot表情包审查器 + """ + ) + gr.Markdown(value="---") # 添加分割线 + gr.Markdown(value=""" + ## 审查器说明\n + 该审查器用于人工修正识图模型对表情包的识别偏差,以及管理表情包黑名单:\n + 每一个表情包都有描述以及“已审查”和“黑名单”两个标签。描述可以编辑并保存。“黑名单”标签可以禁止麦麦使用该表情包。\n + 作者:遗世紫丁香(HexatomicRing) + """) + gr.Markdown(value="---") + + with gr.Row(): + with gr.Column(scale=2): + info_label = gr.Markdown("") + gallery = gr.Gallery(label="表情包列表", columns=4, rows=6) + description = gr.Textbox(label="描述", interactive=True) + description_label = gr.Markdown("") + tag_boxes = { + tag: gr.Checkbox(label=name, interactive=True) + for tag, (name, _) in tags.items() + } + + with gr.Row(): + revert_btn = gr.Button("还原描述") + save_btn = gr.Button("保存描述") + + with gr.Column(scale=1): + max_num_slider = gr.Slider(label="最大显示数量", minimum=1, maximum=500, value=max_num, interactive=True) + check_from_latest = gr.Checkbox(label="由新到旧", interactive=True) + tag_filters = { + tag: gr.Dropdown(tags_choices, value=value, label=f"{name}筛选") + for tag, (name, value) in tags.items() + } + gr.Markdown(value="---") + gr.Markdown(value="格式筛选:") + format_filters = { + f: gr.Checkbox(label=f, value=True) + for f in formats + } + refresh_btn = gr.Button("刷新筛选") + filters = list(tag_filters.values()) + list(format_filters.values()) + + max_num_slider.change(set_max_num, max_num_slider, None) + description.change(desc_change, [description, desc_edit], [description_label, desc_edit]) + for component in filters: + component.change( + fn=update_gallery, + inputs=[check_from_latest, *filters], + outputs=[gallery, info_label], + preprocess=False + ).then( + fn=update_gallery2, + inputs=None, + outputs=gallery) + refresh_btn.click( + fn=update_gallery, + inputs=[check_from_latest, *filters], + outputs=[gallery, info_label], + preprocess=False + ).then( + fn=update_gallery2, + inputs=None, + outputs=gallery) + gallery.select(fn=on_select, inputs=list(tag_boxes.values()), outputs=[gallery, description, *tag_boxes.values()]) + revert_btn.click(fn=revert_desc, inputs=None, outputs=description) + save_btn.click(fn=save_desc, inputs=description, outputs=[description_label, description, save_btn]) + for k, v in tag_boxes.items(): + v.change(fn=change_tag, inputs=list(tag_boxes.values()), outputs=description_label) + app.load( + fn=update_gallery, + inputs=[check_from_latest, *filters], + outputs=[gallery, info_label], + preprocess=False + ).then( + fn=update_gallery2, + inputs=None, + outputs=gallery) + app.queue().launch( + server_name="0.0.0.0", + inbrowser=True, + share=False, + server_port=7001, + debug=True, + quiet=True, + ) diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 57c2b0b85..8dea30b60 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -118,7 +118,9 @@ class EmojiManager: try: # 获取所有表情包 - all_emojis = list(db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1})) + all_emojis = [e for e in + db.emoji.find({}, {"_id": 1, "path": 1, "embedding": 1, "description": 1, "blacklist": 1}) + if 'blacklist' not in e] if not all_emojis: logger.warning("数据库中没有任何表情包") @@ -173,7 +175,7 @@ class EmojiManager: logger.error(f"[错误] 获取表情包失败: {str(e)}") return None - async def _get_emoji_discription(self, image_base64: str) -> str: + async def _get_emoji_description(self, image_base64: str) -> str: """获取表情包的标签,使用image_manager的描述生成功能""" try: @@ -248,20 +250,14 @@ class EmojiManager: if existing_emoji_by_path["_id"] != existing_emoji_by_hash["_id"]: logger.error(f"[错误] 表情包已存在但记录不一致: {filename}") db.emoji.delete_one({"_id": existing_emoji_by_path["_id"]}) - db.emoji.update_one( - {"_id": existing_emoji_by_hash["_id"]}, {"$set": {"path": image_path, "filename": filename}} - ) - existing_emoji_by_hash["path"] = image_path - existing_emoji_by_hash["filename"] = filename - existing_emoji = existing_emoji_by_hash + db.emoji.delete_one({"_id": existing_emoji_by_hash["_id"]}) + existing_emoji = None + else: + existing_emoji = existing_emoji_by_hash elif existing_emoji_by_hash: logger.error(f"[错误] 表情包hash已存在但path不存在: {filename}") - db.emoji.update_one( - {"_id": existing_emoji_by_hash["_id"]}, {"$set": {"path": image_path, "filename": filename}} - ) - existing_emoji_by_hash["path"] = image_path - existing_emoji_by_hash["filename"] = filename - existing_emoji = existing_emoji_by_hash + db.emoji.delete_one({"_id": existing_emoji_by_hash["_id"]}) + existing_emoji = None elif existing_emoji_by_path: logger.error(f"[错误] 表情包path已存在但hash不存在: {filename}") db.emoji.delete_one({"_id": existing_emoji_by_path["_id"]}) @@ -273,7 +269,7 @@ class EmojiManager: if existing_emoji: # 即使表情包已存在,也检查是否需要同步到images集合 - description = existing_emoji.get("discription") + description = existing_emoji.get("description") # 检查是否在images集合中存在 existing_image = db.images.find_one({"hash": image_hash}) if not existing_image: @@ -298,7 +294,7 @@ class EmojiManager: description = existing_description else: # 获取表情包的描述 - description = await self._get_emoji_discription(image_base64) + description = await self._get_emoji_description(image_base64) if global_config.EMOJI_CHECK: check = await self._check_emoji(image_base64, image_format) @@ -316,7 +312,7 @@ class EmojiManager: "filename": filename, "path": image_path, "embedding": embedding, - "discription": description, + "description": description, "hash": image_hash, "timestamp": int(time.time()), } @@ -399,6 +395,11 @@ class EmojiManager: db.emoji.delete_one({"_id": emoji["_id"]}) removed_count += 1 + # 修复拼写错误 + if "discription" in emoji: + desc = emoji["discription"] + db.emoji.update_one({"_id": emoji["_id"]}, {"$unset": {"discription": ""}, "$set": {"description": desc}}) + except Exception as item_error: logger.error(f"[错误] 处理表情包记录时出错: {str(item_error)}") continue diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index d6ba165ee..b26b29549 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -97,14 +97,28 @@ class ScheduleGenerator: reg = r"\{(.|\r|\n)+\}" matched = re.search(reg, schedule_text)[0] schedule_dict = json.loads(matched) + self._check_schedule_validity(schedule_dict) return schedule_dict except json.JSONDecodeError: logger.exception("解析日程失败: {}".format(schedule_text)) return False + except ValueError as e: + logger.exception(f"解析日程失败: {str(e)}") + return False except Exception as e: logger.exception(f"解析日程发生错误:{str(e)}") return False + def _check_schedule_validity(self, schedule_dict: Dict[str, str]): + """检查日程是否合法""" + if not schedule_dict: + return + for time_str in schedule_dict.keys(): + try: + self._parse_time(time_str) + except ValueError: + raise ValueError("日程时间格式不正确") from None + def _parse_time(self, time_str: str) -> str: """解析时间字符串,转换为时间""" return datetime.datetime.strptime(time_str, "%H:%M")