重构数据库模块实现延迟初始化
- 使用Global Object Pattern设计模式 - 实现数据库连接的延迟初始化 - 添加类型注解支持IDE类型推导 - 确保环境变量在bot.py加载后再连接数据库 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,73 +1,53 @@
|
|||||||
from typing import Optional
|
import os
|
||||||
|
from typing import cast
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from pymongo.database import Database as MongoDatabase
|
from pymongo.database import Database
|
||||||
|
|
||||||
class Database:
|
_client = None
|
||||||
_instance: Optional["Database"] = None
|
_db = None
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
host: str,
|
|
||||||
port: int,
|
|
||||||
db_name: str,
|
|
||||||
username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
auth_source: Optional[str] = None,
|
|
||||||
uri: Optional[str] = None,
|
|
||||||
):
|
|
||||||
if uri and uri.startswith("mongodb://"):
|
|
||||||
# 优先使用URI连接
|
|
||||||
self.client = MongoClient(uri)
|
|
||||||
elif username and password:
|
|
||||||
# 如果有用户名和密码,使用认证连接
|
|
||||||
self.client = MongoClient(
|
|
||||||
host, port, username=username, password=password, authSource=auth_source
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 否则使用无认证连接
|
|
||||||
self.client = MongoClient(host, port)
|
|
||||||
self.db: MongoDatabase = self.client[db_name]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def initialize(
|
|
||||||
cls,
|
|
||||||
host: str,
|
|
||||||
port: int,
|
|
||||||
db_name: str,
|
|
||||||
username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None,
|
|
||||||
auth_source: Optional[str] = None,
|
|
||||||
uri: Optional[str] = None,
|
|
||||||
) -> MongoDatabase:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = cls(
|
|
||||||
host, port, db_name, username, password, auth_source, uri
|
|
||||||
)
|
|
||||||
return cls._instance.db
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls) -> MongoDatabase:
|
|
||||||
if cls._instance is None:
|
|
||||||
raise RuntimeError("Database not initialized")
|
|
||||||
return cls._instance.db
|
|
||||||
|
|
||||||
|
|
||||||
#测试用
|
def __create_database_instance():
|
||||||
|
uri = os.getenv("MONGODB_URI")
|
||||||
|
host = os.getenv("MONGODB_HOST", "127.0.0.1")
|
||||||
|
port = int(os.getenv("MONGODB_PORT", "27017"))
|
||||||
|
db_name = os.getenv("DATABASE_NAME", "MegBot")
|
||||||
|
username = os.getenv("MONGODB_USERNAME")
|
||||||
|
password = os.getenv("MONGODB_PASSWORD")
|
||||||
|
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
|
||||||
|
|
||||||
def get_random_group_messages(self, group_id: str, limit: int = 5):
|
if uri and uri.startswith("mongodb://"):
|
||||||
# 先随机获取一条消息
|
# 优先使用URI连接
|
||||||
random_message = list(self.db.messages.aggregate([
|
return MongoClient(uri)
|
||||||
{"$match": {"group_id": group_id}},
|
|
||||||
{"$sample": {"size": 1}}
|
|
||||||
]))[0]
|
|
||||||
|
|
||||||
# 获取该消息之后的消息
|
if username and password:
|
||||||
subsequent_messages = list(self.db.messages.find({
|
# 如果有用户名和密码,使用认证连接
|
||||||
"group_id": group_id,
|
return MongoClient(
|
||||||
"time": {"$gt": random_message["time"]}
|
host, port, username=username, password=password, authSource=auth_source
|
||||||
}).sort("time", 1).limit(limit))
|
)
|
||||||
|
|
||||||
# 将随机消息和后续消息合并
|
# 否则使用无认证连接
|
||||||
messages = [random_message] + subsequent_messages
|
return MongoClient(host, port)
|
||||||
|
|
||||||
return messages
|
|
||||||
|
def get_db():
|
||||||
|
"""获取数据库连接实例,延迟初始化。"""
|
||||||
|
global _client, _db
|
||||||
|
if _client is None:
|
||||||
|
_client = __create_database_instance()
|
||||||
|
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
|
||||||
|
return _db
|
||||||
|
|
||||||
|
|
||||||
|
class DBWrapper:
|
||||||
|
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return getattr(get_db(), name)
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return get_db()[key]
|
||||||
|
|
||||||
|
|
||||||
|
# 全局数据库访问点
|
||||||
|
db: Database = DBWrapper()
|
||||||
|
|||||||
Reference in New Issue
Block a user