diff --git a/src/common/database.py b/src/common/database.py index c6cead225..ca73dc468 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,73 +1,53 @@ -from typing import Optional +import os +from typing import cast from pymongo import MongoClient -from pymongo.database import Database as MongoDatabase +from pymongo.database import Database -class Database: - _instance: Optional["Database"] = None - - def __init__( - self, - host: str, - port: int, - db_name: str, - username: Optional[str] = None, - password: Optional[str] = None, - auth_source: Optional[str] = None, - uri: Optional[str] = None, - ): - if uri and uri.startswith("mongodb://"): - # 优先使用URI连接 - self.client = MongoClient(uri) - elif username and password: - # 如果有用户名和密码,使用认证连接 - self.client = MongoClient( - host, port, username=username, password=password, authSource=auth_source - ) - else: - # 否则使用无认证连接 - self.client = MongoClient(host, port) - self.db: MongoDatabase = self.client[db_name] - - @classmethod - def initialize( - cls, - host: str, - port: int, - db_name: str, - username: Optional[str] = None, - password: Optional[str] = None, - auth_source: Optional[str] = None, - uri: Optional[str] = None, - ) -> MongoDatabase: - if cls._instance is None: - cls._instance = cls( - host, port, db_name, username, password, auth_source, uri - ) - return cls._instance.db - - @classmethod - def get_instance(cls) -> MongoDatabase: - if cls._instance is None: - raise RuntimeError("Database not initialized") - return cls._instance.db +_client = None +_db = None - #测试用 - - def get_random_group_messages(self, group_id: str, limit: int = 5): - # 先随机获取一条消息 - random_message = list(self.db.messages.aggregate([ - {"$match": {"group_id": group_id}}, - {"$sample": {"size": 1}} - ]))[0] - - # 获取该消息之后的消息 - subsequent_messages = list(self.db.messages.find({ - "group_id": group_id, - "time": {"$gt": random_message["time"]} - }).sort("time", 1).limit(limit)) - - # 将随机消息和后续消息合并 - messages = [random_message] + subsequent_messages - - return messages \ No newline at end of file +def __create_database_instance(): + uri = os.getenv("MONGODB_URI") + host = os.getenv("MONGODB_HOST", "127.0.0.1") + port = int(os.getenv("MONGODB_PORT", "27017")) + db_name = os.getenv("DATABASE_NAME", "MegBot") + username = os.getenv("MONGODB_USERNAME") + password = os.getenv("MONGODB_PASSWORD") + auth_source = os.getenv("MONGODB_AUTH_SOURCE") + + if uri and uri.startswith("mongodb://"): + # 优先使用URI连接 + return MongoClient(uri) + + if username and password: + # 如果有用户名和密码,使用认证连接 + return MongoClient( + host, port, username=username, password=password, authSource=auth_source + ) + + # 否则使用无认证连接 + return MongoClient(host, port) + + +def get_db(): + """获取数据库连接实例,延迟初始化。""" + global _client, _db + if _client is None: + _client = __create_database_instance() + _db = _client[os.getenv("DATABASE_NAME", "MegBot")] + return _db + + +class DBWrapper: + """数据库代理类,保持接口兼容性同时实现懒加载。""" + + def __getattr__(self, name): + return getattr(get_db(), name) + + def __getitem__(self, key): + return get_db()[key] + + +# 全局数据库访问点 +db: Database = DBWrapper()