数据库重构
This commit is contained in:
@@ -15,7 +15,7 @@ from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus, get_session
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges, Schedule, MaiZoneScheduleStatus
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
@@ -41,38 +41,9 @@ MODEL_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session():
|
||||
"""数据库会话上下文管理器,自动处理事务和连接错误"""
|
||||
session = None
|
||||
max_retries = 3
|
||||
retry_delay = 1.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
session = get_session()
|
||||
yield session
|
||||
session.commit()
|
||||
break
|
||||
except (DisconnectionError, OperationalError) as e:
|
||||
logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
|
||||
if session:
|
||||
session.rollback()
|
||||
session.close()
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(retry_delay * (attempt + 1))
|
||||
else:
|
||||
raise
|
||||
except Exception:
|
||||
if session:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
|
||||
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
|
||||
def build_filters(session, model_class: Type[Base], filters: Dict[str, Any]):
|
||||
"""构建查询过滤条件"""
|
||||
conditions = []
|
||||
|
||||
@@ -296,6 +267,7 @@ async def db_save(
|
||||
# 创建新记录
|
||||
new_record = model_class(**data)
|
||||
session.add(new_record)
|
||||
session.commit()
|
||||
session.flush()
|
||||
|
||||
# 转换为字典格式返回
|
||||
@@ -415,8 +387,3 @@ async def store_action_info(
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# 兼容性函数,方便从Peewee迁移
|
||||
def get_model_class(model_name: str) -> Optional[Type[Base]]:
|
||||
"""根据模型名称获取模型类"""
|
||||
return MODEL_MAPPING.get(model_name)
|
||||
|
||||
Reference in New Issue
Block a user