初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
committed by Windpicker-owo
parent ef7a3aee23
commit 23ee3767ef
77 changed files with 10000 additions and 7525 deletions

View File

@@ -12,9 +12,10 @@ import binascii
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from sqlalchemy import select
from src.common.database.database import db
from src.common.database.sqlalchemy_database_api import get_session
from src.common.database.sqlalchemy_models import Emoji, Images
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
@@ -29,6 +30,8 @@ EMOJI_DIR = os.path.join(BASE_DIR, "emoji") # 表情包存储目录
EMOJI_REGISTERED_DIR = os.path.join(BASE_DIR, "emoji_registed") # 已注册的表情包注册目录
MAX_EMOJI_FOR_PROMPT = 20 # 最大允许的表情包描述数量于图片替换的 prompt 中
session = get_session()
"""
还没经过测试,有些地方数据库和内存数据同步可能不完全
@@ -151,7 +154,7 @@ class MaiEmoji:
# 准备数据库记录 for emoji collection
emotion_str = ",".join(self.emotion) if self.emotion else ""
Emoji.create(
emoji = Emoji(
emoji_hash=self.hash,
full_path=self.full_path,
format=self.format,
@@ -165,6 +168,8 @@ class MaiEmoji:
usage_count=self.usage_count,
last_used_time=self.last_used_time,
)
session.add(emoji)
session.commit()
logger.info(f"[注册] 表情包信息保存到数据库: {self.filename} ({self.emotion})")
@@ -200,7 +205,7 @@ class MaiEmoji:
# 2. 删除数据库记录
try:
will_delete_emoji = Emoji.get(Emoji.emoji_hash == self.hash)
will_delete_emoji = session.execute(select(Emoji).where(Emoji.emoji_hash == self.hash)).scalar_one_or_none()
result = will_delete_emoji.delete_instance() # Returns the number of rows deleted.
except Emoji.DoesNotExist: # type: ignore
logger.warning(f"[删除] 数据库中未找到哈希值为 {self.hash} 的表情包记录。")
@@ -248,7 +253,6 @@ def _emoji_objects_to_readable_list(emoji_objects: List["MaiEmoji"]) -> List[str
def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji_objects = []
load_errors = 0
# data is now an iterable of Peewee Emoji model instances
emoji_data_list = list(data)
for emoji_data in emoji_data_list: # emoji_data is an Emoji model instance
@@ -393,12 +397,17 @@ class EmojiManager:
def initialize(self) -> None:
"""初始化数据库连接和表情目录"""
peewee_db.connect(reuse_if_open=True)
if peewee_db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
Emoji.create_table(safe=True) # Ensures table exists
self._initialized = True
try:
db.connect(reuse_if_open=True)
if db.is_closed():
raise RuntimeError("数据库连接失败")
_ensure_emoji_dir()
self._initialized = True # 标记为已初始化
logger.info("EmojiManager初始化成功")
except Exception as e:
logger.error(f"EmojiManager初始化失败: {e}")
self._initialized = False
raise
def _ensure_db(self) -> None:
"""确保数据库已初始化"""
@@ -410,7 +419,7 @@ class EmojiManager:
def record_usage(self, emoji_hash: str) -> None:
"""记录表情使用次数"""
try:
emoji_update = Emoji.get(Emoji.emoji_hash == emoji_hash)
emoji_update = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
emoji_update.usage_count += 1
emoji_update.last_used_time = time.time() # Update last used time
emoji_update.save() # Persist changes to DB
@@ -644,10 +653,10 @@ class EmojiManager:
"""获取所有表情包并初始化为MaiEmoji类对象更新 self.emoji_objects"""
try:
self._ensure_db()
logger.debug("[数据库] 开始加载所有表情包记录 (Peewee)...")
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_peewee_instances = Emoji.select()
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = session.execute(stmt = select(Emoji)).scalars().all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
# 更新内存中的列表和数量
self.emoji_objects = emoji_objects
@@ -675,15 +684,15 @@ class EmojiManager:
self._ensure_db()
if emoji_hash:
query = Emoji.select().where(Emoji.emoji_hash == emoji_hash)
session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalars().all()
else:
logger.warning(
"[查询] 未提供 hash将尝试加载所有表情包建议使用 get_all_emoji_from_db 更新管理器状态。"
)
query = Emoji.select()
query = session.execute(select(Emoji)).scalars().all()
emoji_peewee_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_peewee_instances)
emoji_instances = query
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
if load_errors > 0:
logger.warning(f"[查询] 加载过程中出现 {load_errors} 个错误。")
@@ -760,7 +769,7 @@ class EmojiManager:
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
emoji_record = session.execute(select(Emoji).where(Emoji.emoji_hash == emoji_hash)).scalar_one_or_none()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
return emoji_record.description
@@ -921,9 +930,10 @@ class EmojiManager:
# 尝试从Images表获取已有的详细描述可能在收到表情包时已生成
existing_description = None
try:
from src.common.database.database_model import Images
# from src.common.database.database_model_compat import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
stmt = select(Images).where((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
existing_image = session.execute(stmt).scalar_one_or_none()
if existing_image and existing_image.description:
existing_description = existing_image.description
logger.info(f"[复用描述] 找到已有详细描述: {existing_description[:50]}...")