重构数据库模块实现延迟初始化
- 使用Global Object Pattern设计模式 - 实现数据库连接的延迟初始化 - 添加类型注解支持IDE类型推导 - 确保环境变量在bot.py加载后再连接数据库 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,73 +1,53 @@
|
||||
from typing import Optional
|
||||
import os
|
||||
from typing import cast
|
||||
from pymongo import MongoClient
|
||||
from pymongo.database import Database as MongoDatabase
|
||||
from pymongo.database import Database
|
||||
|
||||
class Database:
|
||||
_instance: Optional["Database"] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str,
|
||||
port: int,
|
||||
db_name: str,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
auth_source: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
):
|
||||
if uri and uri.startswith("mongodb://"):
|
||||
# 优先使用URI连接
|
||||
self.client = MongoClient(uri)
|
||||
elif username and password:
|
||||
# 如果有用户名和密码,使用认证连接
|
||||
self.client = MongoClient(
|
||||
host, port, username=username, password=password, authSource=auth_source
|
||||
)
|
||||
else:
|
||||
# 否则使用无认证连接
|
||||
self.client = MongoClient(host, port)
|
||||
self.db: MongoDatabase = self.client[db_name]
|
||||
|
||||
@classmethod
|
||||
def initialize(
|
||||
cls,
|
||||
host: str,
|
||||
port: int,
|
||||
db_name: str,
|
||||
username: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
auth_source: Optional[str] = None,
|
||||
uri: Optional[str] = None,
|
||||
) -> MongoDatabase:
|
||||
if cls._instance is None:
|
||||
cls._instance = cls(
|
||||
host, port, db_name, username, password, auth_source, uri
|
||||
)
|
||||
return cls._instance.db
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> MongoDatabase:
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("Database not initialized")
|
||||
return cls._instance.db
|
||||
_client = None
|
||||
_db = None
|
||||
|
||||
|
||||
#测试用
|
||||
|
||||
def get_random_group_messages(self, group_id: str, limit: int = 5):
|
||||
# 先随机获取一条消息
|
||||
random_message = list(self.db.messages.aggregate([
|
||||
{"$match": {"group_id": group_id}},
|
||||
{"$sample": {"size": 1}}
|
||||
]))[0]
|
||||
|
||||
# 获取该消息之后的消息
|
||||
subsequent_messages = list(self.db.messages.find({
|
||||
"group_id": group_id,
|
||||
"time": {"$gt": random_message["time"]}
|
||||
}).sort("time", 1).limit(limit))
|
||||
|
||||
# 将随机消息和后续消息合并
|
||||
messages = [random_message] + subsequent_messages
|
||||
|
||||
return messages
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
host = os.getenv("MONGODB_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("MONGODB_PORT", "27017"))
|
||||
db_name = os.getenv("DATABASE_NAME", "MegBot")
|
||||
username = os.getenv("MONGODB_USERNAME")
|
||||
password = os.getenv("MONGODB_PASSWORD")
|
||||
auth_source = os.getenv("MONGODB_AUTH_SOURCE")
|
||||
|
||||
if uri and uri.startswith("mongodb://"):
|
||||
# 优先使用URI连接
|
||||
return MongoClient(uri)
|
||||
|
||||
if username and password:
|
||||
# 如果有用户名和密码,使用认证连接
|
||||
return MongoClient(
|
||||
host, port, username=username, password=password, authSource=auth_source
|
||||
)
|
||||
|
||||
# 否则使用无认证连接
|
||||
return MongoClient(host, port)
|
||||
|
||||
|
||||
def get_db():
|
||||
"""获取数据库连接实例,延迟初始化。"""
|
||||
global _client, _db
|
||||
if _client is None:
|
||||
_client = __create_database_instance()
|
||||
_db = _client[os.getenv("DATABASE_NAME", "MegBot")]
|
||||
return _db
|
||||
|
||||
|
||||
class DBWrapper:
|
||||
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(get_db(), name)
|
||||
|
||||
def __getitem__(self, key):
|
||||
return get_db()[key]
|
||||
|
||||
|
||||
# 全局数据库访问点
|
||||
db: Database = DBWrapper()
|
||||
|
||||
Reference in New Issue
Block a user