数据库重构
This commit is contained in:
@@ -29,102 +29,6 @@ def get_string_field(max_length=255, **kwargs):
|
||||
return String(max_length, **kwargs)
|
||||
else:
|
||||
return Text(**kwargs)
|
||||
|
||||
|
||||
|
||||
class SessionProxy:
|
||||
"""线程安全的Session代理类,自动管理session生命周期"""
|
||||
|
||||
def __init__(self):
|
||||
self._local = threading.local()
|
||||
|
||||
def _get_current_session(self):
|
||||
"""获取当前线程的session,如果没有则创建新的"""
|
||||
if not hasattr(self._local, 'session') or self._local.session is None:
|
||||
_, SessionLocal = initialize_database()
|
||||
self._local.session = SessionLocal()
|
||||
return self._local.session
|
||||
|
||||
def _close_current_session(self):
|
||||
"""关闭当前线程的session"""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
self._local.session.close()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
self._local.session = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""代理所有session方法"""
|
||||
session = self._get_current_session()
|
||||
attr = getattr(session, name)
|
||||
|
||||
# 如果是方法,需要特殊处理一些关键方法
|
||||
if callable(attr):
|
||||
if name in ['commit', 'rollback']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
result = attr(*args, **kwargs)
|
||||
if name == 'commit':
|
||||
# commit后不要清除session,只是刷新状态
|
||||
pass # 保持session活跃
|
||||
return result
|
||||
except Exception:
|
||||
try:
|
||||
if session and hasattr(session, 'rollback'):
|
||||
session.rollback()
|
||||
except:
|
||||
pass
|
||||
# 发生错误时重新创建session
|
||||
self._close_current_session()
|
||||
raise
|
||||
return wrapper
|
||||
elif name == 'close':
|
||||
def wrapper(*args, **kwargs):
|
||||
result = attr(*args, **kwargs)
|
||||
self._close_current_session()
|
||||
return result
|
||||
return wrapper
|
||||
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return attr(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# 如果是连接相关错误,重新创建session再试一次
|
||||
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
|
||||
logger.warning(f"Session问题,重新创建session: {e}")
|
||||
self._close_current_session()
|
||||
new_session = self._get_current_session()
|
||||
new_attr = getattr(new_session, name)
|
||||
return new_attr(*args, **kwargs)
|
||||
raise
|
||||
return wrapper
|
||||
|
||||
return attr
|
||||
|
||||
def new_session(self):
|
||||
"""强制创建新的session(关闭当前的,创建新的)"""
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
def ensure_fresh_session(self):
|
||||
"""确保使用新鲜的session(如果当前session有问题则重新创建)"""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
# 测试session是否还可用
|
||||
self._local.session.execute("SELECT 1")
|
||||
except Exception:
|
||||
# session有问题,重新创建
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
# 创建全局session代理实例
|
||||
_global_session_proxy = SessionProxy()
|
||||
|
||||
def get_session():
|
||||
"""返回线程安全的session代理,自动管理生命周期"""
|
||||
return _global_session_proxy
|
||||
|
||||
|
||||
class ChatStreams(Base):
|
||||
@@ -482,6 +386,22 @@ class MaiZoneScheduleStatus(Base):
|
||||
)
|
||||
|
||||
|
||||
class BanUser(Base):
|
||||
"""被禁用用户模型"""
|
||||
__tablename__ = 'ban_users'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||
violation_num = Column(Integer, nullable=False, default=0)
|
||||
reason = Column(Text, nullable=False)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_violation_num', 'violation_num'),
|
||||
Index('idx_banuser_user_id', 'user_id'),
|
||||
)
|
||||
|
||||
|
||||
# 数据库引擎和会话管理
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
@@ -593,7 +513,7 @@ def get_db_session():
|
||||
_, SessionLocal = initialize_database()
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
session.commit()
|
||||
# session.commit()
|
||||
except Exception:
|
||||
if session:
|
||||
session.rollback()
|
||||
@@ -601,6 +521,7 @@ def get_db_session():
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
|
||||
|
||||
def get_engine():
|
||||
|
||||
Reference in New Issue
Block a user