fix: 修复相同图片不同类型的重复hash问题,修复@全体
This commit is contained in:
@@ -44,18 +44,23 @@ class ImageManager:
|
|||||||
"""确保images集合存在并创建索引"""
|
"""确保images集合存在并创建索引"""
|
||||||
if "images" not in db.list_collection_names():
|
if "images" not in db.list_collection_names():
|
||||||
db.create_collection("images")
|
db.create_collection("images")
|
||||||
# 创建索引
|
|
||||||
db.images.create_index([("hash", 1)], unique=True)
|
# 删除旧索引
|
||||||
db.images.create_index([("url", 1)])
|
db.images.drop_indexes()
|
||||||
db.images.create_index([("path", 1)])
|
# 创建新的复合索引
|
||||||
|
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||||
|
db.images.create_index([("url", 1)])
|
||||||
|
db.images.create_index([("path", 1)])
|
||||||
|
|
||||||
def _ensure_description_collection(self):
|
def _ensure_description_collection(self):
|
||||||
"""确保image_descriptions集合存在并创建索引"""
|
"""确保image_descriptions集合存在并创建索引"""
|
||||||
if "image_descriptions" not in db.list_collection_names():
|
if "image_descriptions" not in db.list_collection_names():
|
||||||
db.create_collection("image_descriptions")
|
db.create_collection("image_descriptions")
|
||||||
# 创建索引
|
|
||||||
db.image_descriptions.create_index([("hash", 1)], unique=True)
|
# 删除旧索引
|
||||||
db.image_descriptions.create_index([("type", 1)])
|
db.image_descriptions.drop_indexes()
|
||||||
|
# 创建新的复合索引
|
||||||
|
db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
|
||||||
|
|
||||||
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
|
||||||
"""从数据库获取图片描述
|
"""从数据库获取图片描述
|
||||||
@@ -78,36 +83,21 @@ class ImageManager:
|
|||||||
description: 描述文本
|
description: 描述文本
|
||||||
description_type: 描述类型 ('emoji' 或 'image')
|
description_type: 描述类型 ('emoji' 或 'image')
|
||||||
"""
|
"""
|
||||||
db.image_descriptions.update_one(
|
|
||||||
{"hash": image_hash, "type": description_type},
|
|
||||||
{"$set": {"description": description, "timestamp": int(time.time())}},
|
|
||||||
upsert=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_image_by_url(self, url: str) -> Optional[str]:
|
|
||||||
"""根据URL获取图像路径(带查重)
|
|
||||||
Args:
|
|
||||||
url: 图像URL
|
|
||||||
Returns:
|
|
||||||
str: 本地文件路径,不存在返回None
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 先查找是否已存在
|
db.image_descriptions.update_one(
|
||||||
existing = db.images.find_one({"url": url})
|
{"hash": image_hash, "type": description_type},
|
||||||
if existing:
|
{
|
||||||
return existing["path"]
|
"$set": {
|
||||||
|
"description": description,
|
||||||
# 下载图像
|
"timestamp": int(time.time()),
|
||||||
async with aiohttp.ClientSession() as session:
|
"hash": image_hash, # 确保hash字段存在
|
||||||
async with session.get(url) as resp:
|
"type": description_type, # 确保type字段存在
|
||||||
if resp.status == 200:
|
}
|
||||||
image_bytes = await resp.read()
|
},
|
||||||
return await self.save_image(image_bytes, url=url)
|
upsert=True,
|
||||||
return None
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取图像失败: {str(e)}")
|
logger.error(f"保存描述到数据库失败: {str(e)}")
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_emoji_description(self, image_base64: str) -> str:
|
async def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""获取表情包描述,带查重和保存功能"""
|
"""获取表情包描述,带查重和保存功能"""
|
||||||
@@ -129,7 +119,7 @@ class ImageManager:
|
|||||||
|
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
cached_description = self._get_description_from_db(image_hash, "emoji")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
logger.warning(f"虽然生成了描述,但找到缓存表情包描述: {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
|
||||||
return f"[表情包:{cached_description}]"
|
return f"[表情包:{cached_description}]"
|
||||||
|
|
||||||
# 根据配置决定是否保存图片
|
# 根据配置决定是否保存图片
|
||||||
@@ -170,7 +160,6 @@ class ImageManager:
|
|||||||
async def get_image_description(self, image_base64: str) -> str:
|
async def get_image_description(self, image_base64: str) -> str:
|
||||||
"""获取普通图片描述,带查重和保存功能"""
|
"""获取普通图片描述,带查重和保存功能"""
|
||||||
try:
|
try:
|
||||||
print("处理图片中")
|
|
||||||
# 计算图片哈希
|
# 计算图片哈希
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
@@ -179,7 +168,7 @@ class ImageManager:
|
|||||||
# 查询缓存的描述
|
# 查询缓存的描述
|
||||||
cached_description = self._get_description_from_db(image_hash, "image")
|
cached_description = self._get_description_from_db(image_hash, "image")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
print("图片描述缓存中")
|
logger.info(f"图片描述缓存中 {cached_description}")
|
||||||
return f"[图片:{cached_description}]"
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
@@ -187,12 +176,13 @@ class ImageManager:
|
|||||||
"请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
"请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||||
)
|
)
|
||||||
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
|
||||||
cached_description = self._get_description_from_db(image_hash, "emoji")
|
|
||||||
|
cached_description = self._get_description_from_db(image_hash, "image")
|
||||||
if cached_description:
|
if cached_description:
|
||||||
logger.info(f"缓存图片描述: {cached_description}")
|
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
|
||||||
return f"[图片:{cached_description}]"
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
print(f"描述是{description}")
|
logger.info(f"描述是{description}")
|
||||||
|
|
||||||
if description is None:
|
if description is None:
|
||||||
logger.warning("AI未能生成图片描述")
|
logger.warning("AI未能生成图片描述")
|
||||||
|
|||||||
@@ -3,16 +3,20 @@ from .relationship_manager import relationship_manager
|
|||||||
|
|
||||||
|
|
||||||
def get_user_nickname(user_id: int) -> str:
|
def get_user_nickname(user_id: int) -> str:
|
||||||
|
if user_id == "all":
|
||||||
|
return "全体成员"
|
||||||
if int(user_id) == int(global_config.BOT_QQ):
|
if int(user_id) == int(global_config.BOT_QQ):
|
||||||
return global_config.BOT_NICKNAME
|
return global_config.BOT_NICKNAME
|
||||||
# print(user_id)
|
# print(user_id)
|
||||||
return relationship_manager.get_name(user_id)
|
return relationship_manager.get_name(user_id)
|
||||||
|
|
||||||
|
|
||||||
def get_user_cardname(user_id: int) -> str:
|
def get_user_cardname(user_id: int) -> str:
|
||||||
if int(user_id) == int(global_config.BOT_QQ):
|
if int(user_id) == int(global_config.BOT_QQ):
|
||||||
return global_config.BOT_NICKNAME
|
return global_config.BOT_NICKNAME
|
||||||
# print(user_id)
|
# print(user_id)
|
||||||
return ''
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def get_groupname(group_id: int) -> str:
|
def get_groupname(group_id: int) -> str:
|
||||||
return f"群{group_id}"
|
return f"群{group_id}"
|
||||||
|
|||||||
Reference in New Issue
Block a user