初始化
This commit is contained in:
@@ -12,9 +12,10 @@ import binascii
|
||||
from typing import Optional, Tuple, List, Any
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.database.database_model import Emoji
|
||||
from src.common.database.database import db as peewee_db
|
||||
from sqlalchemy import select
|
||||
from src.common.database.database import db
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.database.sqlalchemy_models import Emoji, Images
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
|
||||
@@ -29,6 +30,8 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
|
||||
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
|
||||
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
|
||||
|
||||
session = get_session()
|
||||
|
||||
"""
|
||||
还没经过测试,有些地方数据库和内存数据同步可能不完全
|
||||
|
||||
@@ -151,7 +154,7 @@ class MaiEmoji:
|
||||
# 准备数据库记录 for emoji collection
|
||||
emotion_str = ",".join(self.emotion) if self.emotion else ""
|
||||
|
||||
Emoji.create(
|
||||
emoji = Emoji(
|
||||
emoji_hash=self.hash,
|
||||
full_path=self.full_path,
|
||||
format=self.format,
|
||||
@@ -165,6 +168,8 @@ class MaiEmoji:
|
||||
usage_count=self.usage_count,
|
||||
last_used_time=self.last_used_time,
|
||||
)
|
||||
session.add(emoji)
|
||||
session.commit()
|
||||
|
||||
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
|
||||
|
||||
@@ -200,7 +205,7 @@ class MaiEmoji:
|
||||
|
||||
# 2. 删除数据库记录
|
||||
try:
|
||||
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
|
||||
will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
|
||||
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
|
||||
except Emoji.DoesNotExist: # type: ignore
|
||||
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
|
||||
@@ -248,7 +253,6 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
|
||||
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
|
||||
emoji_objects = []
|
||||
load_errors = 0
|
||||
# data is now an iterable of Peewee Emoji model instances
|
||||
emoji_data_list = list(data)
|
||||
|
||||
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
|
||||
@@ -393,12 +397,17 @@ class EmojiManager:
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""初始化数据库连接和表情目录"""
|
||||
peewee_db.connect(reuse_if_open=True)
|
||||
if peewee_db.is_closed():
|
||||
raise RuntimeError("数据库连接失败")
|
||||
_ensure_emoji_dir()
|
||||
Emoji.create_table(safe=True) # Ensures table exists
|
||||
self._initialized = True
|
||||
try:
|
||||
db.connect(reuse_if_open=True)
|
||||
if db.is_closed():
|
||||
raise RuntimeError("数据库连接失败")
|
||||
_ensure_emoji_dir()
|
||||
self._initialized = True # 标记为已初始化
|
||||
logger.info("EmojiManager初始化成功")
|
||||
except Exception as e:
|
||||
logger.error(f"EmojiManager初始化失败: {e}")
|
||||
self._initialized = False
|
||||
raise
|
||||
|
||||
def _ensure_db(self) -> None:
|
||||
"""确保数据库已初始化"""
|
||||
@@ -410,7 +419,7 @@ class EmojiManager:
|
||||
def record_usage(self, emoji_hash: str) -> None:
|
||||
"""记录表情使用次数"""
|
||||
try:
|
||||
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
|
||||
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
emoji_update.usage_count += 1
|
||||
emoji_update.last_used_time = time.time() # Update last used time
|
||||
emoji_update.save() # Persist changes to DB
|
||||
@@ -644,10 +653,10 @@ class EmojiManager:
|
||||
"""获取所有表情包并初始化为MaiEmoji类对象,更新 self.emoji_objects"""
|
||||
try:
|
||||
self._ensure_db()
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...")
|
||||
logger.debug("[数据库] 开始加载所有表情包记录 ...")
|
||||
|
||||
emoji_peewee_instances = Emoji.select()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
emoji_instances = session.execute(stmt = select(Emoji)).scalars().all()
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
# 更新内存中的列表和数量
|
||||
self.emoji_objects = emoji_objects
|
||||
@@ -675,15 +684,15 @@ class EmojiManager:
|
||||
self._ensure_db()
|
||||
|
||||
if emoji_hash:
|
||||
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
|
||||
session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
|
||||
else:
|
||||
logger.warning(
|
||||
"[查询] 未提供 hash,将尝试加载所有表情包,建议使用 get_all_emoji_from_db 更新管理器状态。"
|
||||
)
|
||||
query = Emoji.select()
|
||||
query = session.execute(select(Emoji)).scalars().all()
|
||||
|
||||
emoji_peewee_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
|
||||
emoji_instances = query
|
||||
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
|
||||
|
||||
if load_errors > 0:
|
||||
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
|
||||
@@ -760,7 +769,7 @@ class EmojiManager:
|
||||
# 如果内存中没有,从数据库查找
|
||||
self._ensure_db()
|
||||
try:
|
||||
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
|
||||
emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
|
||||
if emoji_record and emoji_record.description:
|
||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
|
||||
return emoji_record.description
|
||||
@@ -921,9 +930,10 @@ class EmojiManager:
|
||||
# 尝试从Images表获取已有的详细描述(可能在收到表情包时已生成)
|
||||
existing_description = None
|
||||
try:
|
||||
from src.common.database.database_model import Images
|
||||
# from src.common.database.database_model_compat import Images
|
||||
|
||||
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
|
||||
existing_image = session.execute(stmt).scalar_one_or_none()
|
||||
if existing_image and existing_image.description:
|
||||
existing_description = existing_image.description
|
||||
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")
|
||||
|
||||
Reference in New Issue
Block a user