Merge pull request #324 from tcmofashi/debug

fix: 修复图像重复hash问题
This commit is contained in:
tcmofashi
2025-03-13 14:54:37 +08:00
committed by GitHub
2 changed files with 37 additions and 45 deletions

View File

@@ -44,18 +44,23 @@ class ImageManager:
"""确保images集合存在并创建索引"""
if "images" not in db.list_collection_names():
db.create_collection("images")
# 创建索引
db.images.create_index([("hash", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
# 删除旧索引
db.images.drop_indexes()
# 创建新的复合索引
db.images.create_index([("hash", 1), ("type", 1)], unique=True)
db.images.create_index([("url", 1)])
db.images.create_index([("path", 1)])
def _ensure_description_collection(self):
    """Ensure the ``image_descriptions`` collection exists with its index.

    When the collection is missing it is created, any indexes on it are
    dropped, and a single unique compound index on ``(hash, type)`` is built.
    The compound key lets one image hash carry a separate description per
    description type; the earlier unique index on ``hash`` alone is no
    longer created because it would reject the second type for a hash.
    """
    # NOTE(review): the nesting under the existence check is reconstructed
    # from a flattened diff view -- confirm against the real file.
    if "image_descriptions" not in db.list_collection_names():
        db.create_collection("image_descriptions")
        # Clear out old indexes before building the new compound one.
        db.image_descriptions.drop_indexes()
        # Unique per (hash, type) pair rather than per hash.
        db.image_descriptions.create_index([("hash", 1), ("type", 1)], unique=True)
def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
"""从数据库获取图片描述
@@ -78,36 +83,21 @@ class ImageManager:
description: 描述文本
description_type: 描述类型 ('emoji' 或 'image')
"""
db.image_descriptions.update_one(
{"hash": image_hash, "type": description_type},
{"$set": {"description": description, "timestamp": int(time.time())}},
upsert=True,
)
async def get_image_by_url(self, url: str) -> Optional[str]:
"""根据URL获取图像路径(带查重)
Args:
url: 图像URL
Returns:
str: 本地文件路径,不存在返回None
"""
try:
# 先查找是否已存在
existing = db.images.find_one({"url": url})
if existing:
return existing["path"]
# 下载图像
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
if resp.status == 200:
image_bytes = await resp.read()
return await self.save_image(image_bytes, url=url)
return None
db.image_descriptions.update_one(
{"hash": image_hash, "type": description_type},
{
"$set": {
"description": description,
"timestamp": int(time.time()),
"hash": image_hash, # 确保hash字段存在
"type": description_type, # 确保type字段存在
}
},
upsert=True,
)
except Exception as e:
logger.error(f"获取图像失败: {str(e)}")
return None
logger.error(f"保存描述到数据库失败: {str(e)}")
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,带查重和保存功能"""
@@ -129,7 +119,7 @@ class ImageManager:
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但找到缓存表情包描述: {cached_description}")
logger.warning(f"虽然生成了描述,但找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
# 根据配置决定是否保存图片
@@ -170,7 +160,6 @@ class ImageManager:
async def get_image_description(self, image_base64: str) -> str:
"""获取普通图片描述,带查重和保存功能"""
try:
print("处理图片中")
# 计算图片哈希
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
@@ -179,7 +168,7 @@ class ImageManager:
# 查询缓存的描述
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
print("图片描述缓存中")
logger.info(f"图片描述缓存中 {cached_description}")
return f"[图片:{cached_description}]"
# 调用AI获取描述
@@ -187,12 +176,13 @@ class ImageManager:
"请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
)
description, _ = await self._llm.generate_response_for_image(prompt, image_base64, image_format)
cached_description = self._get_description_from_db(image_hash, "emoji")
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.info(f"缓存图片描述: {cached_description}")
logger.warning(f"虽然生成了描述,但是找到缓存图片描述 {cached_description}")
return f"[图片:{cached_description}]"
print(f"描述是{description}")
logger.info(f"描述是{description}")
if description is None:
logger.warning("AI未能生成图片描述")

View File

@@ -5,14 +5,16 @@ from .relationship_manager import relationship_manager
def get_user_nickname(user_id: int) -> str:
    """Return the nickname to display for ``user_id``.

    Args:
        user_id: User id; coerced with ``int()`` before it is compared.

    Returns:
        ``global_config.BOT_NICKNAME`` when the id equals
        ``global_config.BOT_QQ``; otherwise the value returned by
        ``relationship_manager.get_name(user_id)``.
    """
    if int(user_id) == int(global_config.BOT_QQ):
        return global_config.BOT_NICKNAME
    return relationship_manager.get_name(user_id)
def get_user_cardname(user_id: int) -> str:
    """Return the card name to display for ``user_id``.

    Args:
        user_id: User id; coerced with ``int()`` before it is compared.

    Returns:
        ``global_config.BOT_NICKNAME`` when the id equals
        ``global_config.BOT_QQ``; an empty string for every other user.
    """
    if int(user_id) == int(global_config.BOT_QQ):
        return global_config.BOT_NICKNAME
    return ""
def get_groupname(group_id: int) -> str:
    """Return the display name of a group.

    Args:
        group_id: Group id.

    Returns:
        The id formatted as a string; no name lookup is performed.
    """
    return f"{group_id}"