数据库重构

This commit is contained in:
雅诺狐
2025-08-16 23:43:45 +08:00
parent 0f0619762b
commit d46d689c43
21 changed files with 834 additions and 1007 deletions

View File

@@ -29,102 +29,6 @@ def get_string_field(max_length=255, **kwargs):
return String(max_length, **kwargs)
else:
return Text(**kwargs)
class SessionProxy:
"""线程安全的Session代理类自动管理session生命周期"""
def __init__(self):
self._local = threading.local()
def _get_current_session(self):
"""获取当前线程的session如果没有则创建新的"""
if not hasattr(self._local, 'session') or self._local.session is None:
_, SessionLocal = initialize_database()
self._local.session = SessionLocal()
return self._local.session
def _close_current_session(self):
"""关闭当前线程的session"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
self._local.session.close()
except:
pass
finally:
self._local.session = None
def __getattr__(self, name):
"""代理所有session方法"""
session = self._get_current_session()
attr = getattr(session, name)
# 如果是方法,需要特殊处理一些关键方法
if callable(attr):
if name in ['commit', 'rollback']:
def wrapper(*args, **kwargs):
try:
result = attr(*args, **kwargs)
if name == 'commit':
# commit后不要清除session只是刷新状态
pass # 保持session活跃
return result
except Exception:
try:
if session and hasattr(session, 'rollback'):
session.rollback()
except:
pass
# 发生错误时重新创建session
self._close_current_session()
raise
return wrapper
elif name == 'close':
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
self._close_current_session()
return result
return wrapper
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
def wrapper(*args, **kwargs):
try:
return attr(*args, **kwargs)
except Exception as e:
# 如果是连接相关错误重新创建session再试一次
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
logger.warning(f"Session问题重新创建session: {e}")
self._close_current_session()
new_session = self._get_current_session()
new_attr = getattr(new_session, name)
return new_attr(*args, **kwargs)
raise
return wrapper
return attr
def new_session(self):
"""强制创建新的session关闭当前的创建新的"""
self._close_current_session()
return self._get_current_session()
def ensure_fresh_session(self):
"""确保使用新鲜的session如果当前session有问题则重新创建"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
# 测试session是否还可用
self._local.session.execute("SELECT 1")
except Exception:
# session有问题重新创建
self._close_current_session()
return self._get_current_session()
# 创建全局session代理实例
_global_session_proxy = SessionProxy()
def get_session():
"""返回线程安全的session代理自动管理生命周期"""
return _global_session_proxy
class ChatStreams(Base):
@@ -482,6 +386,22 @@ class MaiZoneScheduleStatus(Base):
)
class BanUser(Base):
"""被禁用用户模型"""
__tablename__ = 'ban_users'
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(get_string_field(50), nullable=False, index=True)
violation_num = Column(Integer, nullable=False, default=0)
reason = Column(Text, nullable=False)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (
Index('idx_violation_num', 'violation_num'),
Index('idx_banuser_user_id', 'user_id'),
)
# 数据库引擎和会话管理
_engine = None
_SessionLocal = None
@@ -593,7 +513,7 @@ def get_db_session():
_, SessionLocal = initialize_database()
session = SessionLocal()
yield session
session.commit()
# session.commit()
except Exception:
if session:
session.rollback()
@@ -601,6 +521,7 @@ def get_db_session():
finally:
if session:
session.close()
def get_engine():