重构数据库模块实现延迟初始化

- 使用Global Object Pattern设计模式
- 实现数据库连接的延迟初始化
- 添加类型注解支持IDE类型推导
- 确保环境变量在bot.py加载后再连接数据库

🤖 Generated with [Claude Code](https://claude.ai/code)
Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
晴猫
2025-03-12 22:27:05 +09:00
parent ae0481ff29
commit 49082267bb

View File

@@ -1,73 +1,53 @@
from typing import Optional import os
from typing import cast
from pymongo import MongoClient from pymongo import MongoClient
from pymongo.database import Database as MongoDatabase from pymongo.database import Database
class Database: _client = None
_instance: Optional["Database"] = None _db = None
def __init__(
self,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
):
if uri and uri.startswith("mongodb://"):
# 优先使用URI连接
self.client = MongoClient(uri)
elif username and password:
# 如果有用户名和密码,使用认证连接
self.client = MongoClient(
host, port, username=username, password=password, authSource=auth_source
)
else:
# 否则使用无认证连接
self.client = MongoClient(host, port)
self.db: MongoDatabase = self.client[db_name]
@classmethod
def initialize(
cls,
host: str,
port: int,
db_name: str,
username: Optional[str] = None,
password: Optional[str] = None,
auth_source: Optional[str] = None,
uri: Optional[str] = None,
) -> MongoDatabase:
if cls._instance is None:
cls._instance = cls(
host, port, db_name, username, password, auth_source, uri
)
return cls._instance.db
@classmethod
def get_instance(cls) -> MongoDatabase:
if cls._instance is None:
raise RuntimeError("Database not initialized")
return cls._instance.db
#测试用 def __create_database_instance():
uri = os.getenv("MONGODB_URI")
host = os.getenv("MONGODB_HOST", "127.0.0.1")
port = int(os.getenv("MONGODB_PORT", "27017"))
db_name = os.getenv("DATABASE_NAME", "MegBot")
username = os.getenv("MONGODB_USERNAME")
password = os.getenv("MONGODB_PASSWORD")
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
def get_random_group_messages(self, group_id: str, limit: int = 5): if uri and uri.startswith("mongodb://"):
# 先随机获取一条消息 # 优先使用URI连接
random_message = list(self.db.messages.aggregate([ return MongoClient(uri)
{"$match": {"group_id": group_id}},
{"$sample": {"size": 1}}
]))[0]
# 获取该消息之后的消息 if username and password:
subsequent_messages = list(self.db.messages.find({ # 如果有用户名和密码,使用认证连接
"group_id": group_id, return MongoClient(
"time": {"$gt": random_message["time"]} host, port, username=username, password=password, authSource=auth_source
}).sort("time", 1).limit(limit)) )
# 将随机消息和后续消息合并 # 否则使用无认证连接
messages = [random_message] + subsequent_messages return MongoClient(host, port)
return messages
def get_db():
"""获取数据库连接实例,延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
return _db
class DBWrapper:
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
def __getattr__(self, name):
return getattr(get_db(), name)
def __getitem__(self, key):
return get_db()[key]
# 全局数据库访问点
db: Database = DBWrapper()