feat: 表情包匹配从情绪匹配改成嵌入相似度匹配
This commit is contained in:
@@ -174,7 +174,7 @@ class ChatBot:
|
|||||||
|
|
||||||
bot_response_time = tinking_time_point
|
bot_response_time = tinking_time_point
|
||||||
if random() < global_config.emoji_chance:
|
if random() < global_config.emoji_chance:
|
||||||
emoji_path = await emoji_manager.get_emoji_for_emotion(emotion)
|
emoji_path = await emoji_manager.get_emoji_for_text(response)
|
||||||
if emoji_path:
|
if emoji_path:
|
||||||
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
emoji_cq = CQCode.create_emoji_cq(emoji_path)
|
||||||
|
|
||||||
|
|||||||
@@ -15,11 +15,12 @@ import time
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
import traceback
|
||||||
|
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
from ..chat.config import global_config
|
from ..chat.config import global_config
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
from utils import get_embedding
|
from ..chat.utils import get_embedding
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -39,7 +40,7 @@ class EmojiManager:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.db = Database.get_instance()
|
self.db = Database.get_instance()
|
||||||
self._scan_task = None
|
self._scan_task = None
|
||||||
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=50)
|
self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
|
||||||
|
|
||||||
def _ensure_emoji_dir(self):
|
def _ensure_emoji_dir(self):
|
||||||
"""确保表情存储目录存在"""
|
"""确保表情存储目录存在"""
|
||||||
@@ -99,44 +100,43 @@ class EmojiManager:
|
|||||||
logger.error("无法获取文本的embedding")
|
logger.error("无法获取文本的embedding")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 使用embedding进行相似度搜索,获取最相似的3个表情包
|
|
||||||
pipeline = [
|
|
||||||
{
|
|
||||||
"$search": {
|
|
||||||
"index": "default",
|
|
||||||
"knnBeta": {
|
|
||||||
"vector": text_embedding,
|
|
||||||
"path": "embedding",
|
|
||||||
"k": 3
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 获取搜索结果
|
# 获取所有表情包
|
||||||
results = list(self.db.db.emoji.aggregate(pipeline))
|
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1}))
|
||||||
|
|
||||||
if not results:
|
if not all_emojis:
|
||||||
logger.warning("未找到匹配的表情包,尝试随机选择")
|
logger.warning("数据库中没有任何表情包")
|
||||||
# 如果没有匹配的表情,随机选择一个
|
|
||||||
try:
|
|
||||||
emoji = self.db.db.emoji.aggregate([
|
|
||||||
{'$sample': {'size': 1}}
|
|
||||||
]).next()
|
|
||||||
if emoji and 'path' in emoji:
|
|
||||||
# 更新使用次数
|
|
||||||
self.db.db.emoji.update_one(
|
|
||||||
{'_id': emoji['_id']},
|
|
||||||
{'$inc': {'usage_count': 1}}
|
|
||||||
)
|
|
||||||
return emoji['path']
|
|
||||||
except StopIteration:
|
|
||||||
logger.error("数据库中没有任何表情")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 从最相似的3个表情包中随机选择一个
|
# 计算余弦相似度并排序
|
||||||
selected_emoji = random.choice(results)
|
def cosine_similarity(v1, v2):
|
||||||
|
if not v1 or not v2:
|
||||||
|
return 0
|
||||||
|
dot_product = sum(a * b for a, b in zip(v1, v2))
|
||||||
|
norm_v1 = sum(a * a for a in v1) ** 0.5
|
||||||
|
norm_v2 = sum(b * b for b in v2) ** 0.5
|
||||||
|
if norm_v1 == 0 or norm_v2 == 0:
|
||||||
|
return 0
|
||||||
|
return dot_product / (norm_v1 * norm_v2)
|
||||||
|
|
||||||
|
# 计算所有表情包与输入文本的相似度
|
||||||
|
emoji_similarities = [
|
||||||
|
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
|
||||||
|
for emoji in all_emojis
|
||||||
|
]
|
||||||
|
|
||||||
|
# 按相似度降序排序
|
||||||
|
emoji_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# 获取前3个最相似的表情包
|
||||||
|
top_3_emojis = emoji_similarities[:3]
|
||||||
|
|
||||||
|
if not top_3_emojis:
|
||||||
|
logger.warning("未找到匹配的表情包")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 从前3个中随机选择一个
|
||||||
|
selected_emoji, similarity = random.choice(top_3_emojis)
|
||||||
|
|
||||||
if selected_emoji and 'path' in selected_emoji:
|
if selected_emoji and 'path' in selected_emoji:
|
||||||
# 更新使用次数
|
# 更新使用次数
|
||||||
@@ -144,7 +144,7 @@ class EmojiManager:
|
|||||||
{'_id': selected_emoji['_id']},
|
{'_id': selected_emoji['_id']},
|
||||||
{'$inc': {'usage_count': 1}}
|
{'$inc': {'usage_count': 1}}
|
||||||
)
|
)
|
||||||
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')}")
|
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
|
||||||
return selected_emoji['path']
|
return selected_emoji['path']
|
||||||
|
|
||||||
except Exception as search_error:
|
except Exception as search_error:
|
||||||
@@ -286,6 +286,14 @@ class EmojiManager:
|
|||||||
logger.error(f"扫描表情包失败: {str(e)}")
|
logger.error(f"扫描表情包失败: {str(e)}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
async def _periodic_scan(self, interval_MINS: int = 10):
|
||||||
|
"""定期扫描新表情包"""
|
||||||
|
while True:
|
||||||
|
print(f"\033[1;36m[表情包]\033[0m 开始扫描新表情包...")
|
||||||
|
await self.scan_new_emojis()
|
||||||
|
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
|
||||||
|
|
||||||
|
|
||||||
def check_emoji_file_integrity(self):
|
def check_emoji_file_integrity(self):
|
||||||
"""检查表情包文件完整性
|
"""检查表情包文件完整性
|
||||||
如果文件已被删除,则从数据库中移除对应记录
|
如果文件已被删除,则从数据库中移除对应记录
|
||||||
|
|||||||
Reference in New Issue
Block a user