feat: 表情包匹配从情绪匹配改成嵌入相似度匹配
This commit is contained in:
@@ -174,7 +174,7 @@ class ChatBot:
|
||||
|
||||
bot_response_time = tinking_time_point
|
||||
if random() < global_config.emoji_chance:
|
||||
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
||||
emoji_path = await emoji_manager.get_emoji_for_text(response)
|
||||
if emoji_path:
|
||||
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
||||
|
||||
|
||||
@@ -15,11 +15,12 @@ import time
|
||||
from PIL import Image
|
||||
import io
|
||||
from loguru import logger
|
||||
import traceback
|
||||
|
||||
from nonebot import get_driver
|
||||
from ..chat.config import global_config
|
||||
from ..models.utils_model import LLM_request
|
||||
from utils import get_embedding
|
||||
from ..chat.utils import get_embedding
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -39,7 +40,7 @@ class EmojiManager:
|
||||
def __init__(self):
|
||||
self.db = Database.get_instance()
|
||||
self._scan_task = None
|
||||
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50)
|
||||
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
|
||||
|
||||
def _ensure_emoji_dir(self):
|
||||
"""确保表情存储目录存在"""
|
||||
@@ -98,45 +99,44 @@ class EmojiManager:
|
||||
if not text_embedding:
|
||||
logger.error("无法获取文本的embedding")
|
||||
return None
|
||||
|
||||
# 使用embedding进行相似度搜索,获取最相似的3个表情包
|
||||
pipeline = [
|
||||
{
|
||||
"$search": {
|
||||
"index": "default",
|
||||
"knnBeta": {
|
||||
"vector": text_embedding,
|
||||
"path": "embedding",
|
||||
"k": 3
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
# 获取搜索结果
|
||||
results = list(self.db.db.emoji.aggregate(pipeline))
|
||||
# 获取所有表情包
|
||||
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1}))
|
||||
|
||||
if not results:
|
||||
logger.warning("未找到匹配的表情包,尝试随机选择")
|
||||
# 如果没有匹配的表情,随机选择一个
|
||||
try:
|
||||
emoji = self.db.db.emoji.aggregate([
|
||||
{'$sample': {'size': 1}}
|
||||
]).next()
|
||||
if emoji and 'path' in emoji:
|
||||
# 更新使用次数
|
||||
self.db.db.emoji.update_one(
|
||||
{'_id': emoji['_id']},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
return emoji['path']
|
||||
except StopIteration:
|
||||
logger.error("数据库中没有任何表情")
|
||||
return None
|
||||
if not all_emojis:
|
||||
logger.warning("数据库中没有任何表情包")
|
||||
return None
|
||||
|
||||
# 从最相似的3个表情包中随机选择一个
|
||||
selected_emoji = random.choice(results)
|
||||
# 计算余弦相似度并排序
|
||||
def cosine_similarity(v1, v2):
|
||||
if not v1 or not v2:
|
||||
return 0
|
||||
dot_product = sum(a * b for a, b in zip(v1, v2))
|
||||
norm_v1 = sum(a * a for a in v1) ** 0.5
|
||||
norm_v2 = sum(b * b for b in v2) ** 0.5
|
||||
if norm_v1 == 0 or norm_v2 == 0:
|
||||
return 0
|
||||
return dot_product / (norm_v1 * norm_v2)
|
||||
|
||||
# 计算所有表情包与输入文本的相似度
|
||||
emoji_similarities = [
|
||||
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
|
||||
for emoji in all_emojis
|
||||
]
|
||||
|
||||
# 按相似度降序排序
|
||||
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 获取前3个最相似的表情包
|
||||
top_3_emojis = emoji_similarities[:3]
|
||||
|
||||
if not top_3_emojis:
|
||||
logger.warning("未找到匹配的表情包")
|
||||
return None
|
||||
|
||||
# 从前3个中随机选择一个
|
||||
selected_emoji, similarity = random.choice(top_3_emojis)
|
||||
|
||||
if selected_emoji and 'path' in selected_emoji:
|
||||
# 更新使用次数
|
||||
@@ -144,7 +144,7 @@ class EmojiManager:
|
||||
{'_id': selected_emoji['_id']},
|
||||
{'$inc': {'usage_count': 1}}
|
||||
)
|
||||
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')}")
|
||||
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
|
||||
return selected_emoji['path']
|
||||
|
||||
except Exception as search_error:
|
||||
@@ -285,6 +285,14 @@ class EmojiManager:
|
||||
except Exception as e:
|
||||
logger.error(f"扫描表情包失败: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def _periodic_scan(self, interval_MINS: int = 10):
|
||||
"""定期扫描新表情包"""
|
||||
while True:
|
||||
print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
|
||||
await self.scan_new_emojis()
|
||||
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
|
||||
|
||||
|
||||
def check_emoji_file_integrity(self):
|
||||
"""检查表情包文件完整性
|
||||
|
||||
Reference in New Issue
Block a user