From 97bddb83e71cfcd2a965ddaaef29c3970a4d21df Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Wed, 5 Mar 2025 10:43:08 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=A1=A8=E6=83=85=E5=8C=85=E5=8C=B9?= =?UTF-8?q?=E9=85=8D=E4=BB=8E=E6=83=85=E7=BB=AA=E5=8C=B9=E9=85=8D=E6=94=B9?= =?UTF-8?q?=E6=88=90=E5=B5=8C=E5=85=A5=E7=9B=B8=E4=BC=BC=E5=BA=A6=E5=8C=B9?= =?UTF-8?q?=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/bot.py | 2 +- src/plugins/chat/emoji_manager.py | 84 +++++++++++++++++-------------- 2 files changed, 47 insertions(+), 39 deletions(-) diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 6b0e76db5..f9488b96f 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -174,7 +174,7 @@ class ChatBot: bot_response_time = tinking_time_point if random() < global_config.emoji_chance: - emoji_path = await emoji_manager.get_emoji_for_emotion(emotion) + emoji_path = await emoji_manager.get_emoji_for_text(response) if emoji_path: emoji_cq = CQCode.create_emoji_cq(emoji_path) diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index a0164d065..aa0bc1fb5 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -15,11 +15,12 @@ import time from PIL import Image import io from loguru import logger +import traceback from nonebot import get_driver from ..chat.config import global_config from ..models.utils_model import LLM_request -from utils import get_embedding +from ..chat.utils import get_embedding driver = get_driver() config = driver.config @@ -39,7 +40,7 @@ class EmojiManager: def __init__(self): self.db = Database.get_instance() self._scan_task = None - self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50) + self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) def _ensure_emoji_dir(self): """确保表情存储目录存在""" @@ -98,45 +99,44 @@ class EmojiManager: if not text_embedding: logger.error("无法获取文本的embedding") return None - - # 使用embedding进行相似度搜索,获取最相似的3个表情包 - pipeline = [ - { - "$search": { - "index": "default", - "knnBeta": { - "vector": text_embedding, - "path": "embedding", - "k": 3 - } - } - } - ] try: - # 获取搜索结果 - results = list(self.db.db.emoji.aggregate(pipeline)) + # 获取所有表情包 + all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1})) - if not results: - logger.warning("未找到匹配的表情包,尝试随机选择") - # 如果没有匹配的表情,随机选择一个 - try: - emoji = self.db.db.emoji.aggregate([ - {'$sample': {'size': 1}} - ]).next() - if emoji and 'path' in emoji: - # 更新使用次数 - self.db.db.emoji.update_one( - {'_id': emoji['_id']}, - {'$inc': {'usage_count': 1}} - ) - return emoji['path'] - except StopIteration: - logger.error("数据库中没有任何表情") - return None + if not all_emojis: + logger.warning("数据库中没有任何表情包") + return None - # 从最相似的3个表情包中随机选择一个 - selected_emoji = random.choice(results) + # 计算余弦相似度并排序 + def cosine_similarity(v1, v2): + if not v1 or not v2: + return 0 + dot_product = sum(a * b for a, b in zip(v1, v2)) + norm_v1 = sum(a * a for a in v1) ** 0.5 + norm_v2 = sum(b * b for b in v2) ** 0.5 + if norm_v1 == 0 or norm_v2 == 0: + return 0 + return dot_product / (norm_v1 * norm_v2) + + # 计算所有表情包与输入文本的相似度 + emoji_similarities = [ + (emoji, cosine_similarity(text_embedding, emoji.get('embedding', []))) + for emoji in all_emojis + ] + + # 按相似度降序排序 + emoji_similarities.sort(key=lambda x: x[1], reverse=True) + + # 获取前3个最相似的表情包 + top_3_emojis = emoji_similarities[:3] + + if not top_3_emojis: + logger.warning("未找到匹配的表情包") + return None + + # 从前3个中随机选择一个 + selected_emoji, similarity = random.choice(top_3_emojis) if selected_emoji and 'path' in selected_emoji: # 更新使用次数 @@ -144,7 +144,7 @@ class EmojiManager: {'_id': selected_emoji['_id']}, {'$inc': {'usage_count': 1}} ) - logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')}") + logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})") return selected_emoji['path'] except Exception as search_error: @@ -285,6 +285,14 @@ class EmojiManager: except Exception as e: logger.error(f"扫描表情包失败: {str(e)}") logger.error(traceback.format_exc()) + + async def _periodic_scan(self, interval_MINS: int = 10): + """定期扫描新表情包""" + while True: + print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...") + await self.scan_new_emojis() + await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次 + def check_emoji_file_integrity(self): """检查表情包文件完整性