feat: 表情包匹配从情绪匹配改成嵌入相似度匹配

This commit is contained in:
tcmofashi
2025-03-05 10:43:08 +08:00
parent 669f9e400a
commit 97bddb83e7
2 changed files with 47 additions and 39 deletions

View File

@@ -174,7 +174,7 @@ class ChatBot:
bot_response_time = tinking_time_point
if random() < global_config.emoji_chance:
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
emoji_path = await emoji_manager.get_emoji_for_text(response)
if emoji_path:
emoji_cq = CQCode.create_emoji_cq(emoji_path)

View File

@@ -15,11 +15,12 @@ import time
from PIL import Image
import io
from loguru import logger
import traceback
from nonebot import get_driver
from ..chat.config import global_config
from ..models.utils_model import LLM_request
from utils import get_embedding
from ..chat.utils import get_embedding
driver = get_driver()
config = driver.config
@@ -39,7 +40,7 @@ class EmojiManager:
def __init__(self):
self.db = Database.get_instance()
self._scan_task = None
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50)
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
def _ensure_emoji_dir(self):
"""确保表情存储目录存在"""
@@ -98,45 +99,44 @@ class EmojiManager:
if not text_embedding:
logger.error("无法获取文本的embedding")
return None
# 使用embedding进行相似度搜索获取最相似的3个表情包
pipeline = [
{
"$search": {
"index": "default",
"knnBeta": {
"vector": text_embedding,
"path": "embedding",
"k": 3
}
}
}
]
try:
# 获取搜索结果
results = list(self.db.db.emoji.aggregate(pipeline))
# 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1}))
if not results:
logger.warning("未找到匹配的表情包,尝试随机选择")
# 如果没有匹配的表情,随机选择一个
try:
emoji = self.db.db.emoji.aggregate([
{'$sample': {'size': 1}}
]).next()
if emoji and 'path' in emoji:
# 更新使用次数
self.db.db.emoji.update_one(
{'_id': emoji['_id']},
{'$inc': {'usage_count': 1}}
)
return emoji['path']
except StopIteration:
logger.error("数据库中没有任何表情")
return None
if not all_emojis:
logger.warning("数据库中没有任何表情包")
return None
# 从最相似的3个表情包中随机选择一个
selected_emoji = random.choice(results)
# 计算余弦相似度并排序
def cosine_similarity(v1, v2):
if not v1 or not v2:
return 0
dot_product = sum(a * b for a, b in zip(v1, v2))
norm_v1 = sum(a * a for a in v1) ** 0.5
norm_v2 = sum(b * b for b in v2) ** 0.5
if norm_v1 == 0 or norm_v2 == 0:
return 0
return dot_product / (norm_v1 * norm_v2)
# 计算所有表情包与输入文本的相似度
emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
for emoji in all_emojis
]
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3]
if not top_3_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis)
if selected_emoji and 'path' in selected_emoji:
# 更新使用次数
@@ -144,7 +144,7 @@ class EmojiManager:
{'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}}
)
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')}")
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
return selected_emoji['path']
except Exception as search_error:
@@ -285,6 +285,14 @@ class EmojiManager:
except Exception as e:
logger.error(f"扫描表情包失败: {str(e)}")
logger.error(traceback.format_exc())
async def _periodic_scan(self, interval_MINS: int = 10):
"""定期扫描新表情包"""
while True:
print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
await self.scan_new_emojis()
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
def check_emoji_file_integrity(self):
"""检查表情包文件完整性