feat: 表情包匹配从情绪匹配改成嵌入相似度匹配

This commit is contained in:
tcmofashi
2025-03-05 10:43:08 +08:00
parent 669f9e400a
commit 97bddb83e7
2 changed files with 47 additions and 39 deletions

View File

@@ -174,7 +174,7 @@ class ChatBot:
bot_response_time = tinking_time_point bot_response_time = tinking_time_point
if random() < global_config.emoji_chance: if random() < global_config.emoji_chance:
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion) emoji_path = await emoji_manager.get_emoji_for_text(response)
if emoji_path: if emoji_path:
emoji_cq = CQCode.create_emoji_cq(emoji_path) emoji_cq = CQCode.create_emoji_cq(emoji_path)

View File

@@ -15,11 +15,12 @@ import time
from PIL import Image from PIL import Image
import io import io
from loguru import logger from loguru import logger
import traceback
from nonebot import get_driver from nonebot import get_driver
from ..chat.config import global_config from ..chat.config import global_config
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from utils import get_embedding from ..chat.utils import get_embedding
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
@@ -39,7 +40,7 @@ class EmojiManager:
def __init__(self): def __init__(self):
self.db = Database.get_instance() self.db = Database.get_instance()
self._scan_task = None self._scan_task = None
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50) self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
def _ensure_emoji_dir(self): def _ensure_emoji_dir(self):
"""确保表情存储目录存在""" """确保表情存储目录存在"""
@@ -99,44 +100,43 @@ class EmojiManager:
logger.error("无法获取文本的embedding") logger.error("无法获取文本的embedding")
return None return None
# 使用embedding进行相似度搜索获取最相似的3个表情包
pipeline = [
{
"$search": {
"index": "default",
"knnBeta": {
"vector": text_embedding,
"path": "embedding",
"k": 3
}
}
}
]
try: try:
# 获取搜索结果 # 获取所有表情包
results = list(self.db.db.emoji.aggregate(pipeline)) all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1}))
if not results: if not all_emojis:
logger.warning("未找到匹配的表情包,尝试随机选择") logger.warning("数据库中没有任何表情包")
# 如果没有匹配的表情,随机选择一个 return None
try:
emoji = self.db.db.emoji.aggregate([
{'$sample': {'size': 1}}
]).next()
if emoji and 'path' in emoji:
# 更新使用次数
self.db.db.emoji.update_one(
{'_id': emoji['_id']},
{'$inc': {'usage_count': 1}}
)
return emoji['path']
except StopIteration:
logger.error("数据库中没有任何表情")
return None
# 从最相似的3个表情包中随机选择一个 # 计算余弦相似度并排序
selected_emoji = random.choice(results) def cosine_similarity(v1, v2):
if not v1 or not v2:
return 0
dot_product = sum(a * b for a, b in zip(v1, v2))
norm_v1 = sum(a * a for a in v1) ** 0.5
norm_v2 = sum(b * b for b in v2) ** 0.5
if norm_v1 == 0 or norm_v2 == 0:
return 0
return dot_product / (norm_v1 * norm_v2)
# 计算所有表情包与输入文本的相似度
emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
for emoji in all_emojis
]
# 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3]
if not top_3_emojis:
logger.warning("未找到匹配的表情包")
return None
# 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis)
if selected_emoji and 'path' in selected_emoji: if selected_emoji and 'path' in selected_emoji:
# 更新使用次数 # 更新使用次数
@@ -144,7 +144,7 @@ class EmojiManager:
{'_id': selected_emoji['_id']}, {'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')}") logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
return selected_emoji['path'] return selected_emoji['path']
except Exception as search_error: except Exception as search_error:
@@ -286,6 +286,14 @@ class EmojiManager:
logger.error(f"扫描表情包失败: {str(e)}") logger.error(f"扫描表情包失败: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def _periodic_scan(self, interval_MINS: int = 10):
"""定期扫描新表情包"""
while True:
print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
await self.scan_new_emojis()
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
def check_emoji_file_integrity(self): def check_emoji_file_integrity(self):
"""检查表情包文件完整性 """检查表情包文件完整性
如果文件已被删除,则从数据库中移除对应记录 如果文件已被删除,则从数据库中移除对应记录