From 49082267bb5427d0f2affbd15b6942785b94c591 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=99=B4=E7=8C=AB?= Date: Wed, 12 Mar 2025 22:27:05 +0900 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E6=95=B0=E6=8D=AE=E5=BA=93?= =?UTF-8?q?=E6=A8=A1=E5=9D=97=E5=AE=9E=E7=8E=B0=E5=BB=B6=E8=BF=9F=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用Global Object Pattern设计模式 - 实现数据库连接的延迟初始化 - 添加类型注解支持IDE类型推导 - 确保环境变量在bot.py加载后再连接数据库 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- src/common/database.py | 118 +++++++++++++++++------------------------ 1 file changed, 49 insertions(+), 69 deletions(-) diff --git a/src/common/database.py b/src/common/database.py index c6cead225..ca73dc468 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,73 +1,53 @@ -from typing import Optional +import os +from typing import cast from pymongo import MongoClient -from pymongo.database import Database as MongoDatabase +from pymongo.database import Database -class Database: - _instance: Optional["Database"] = None - - def __init__( - self, - host: str, - port: int, - db_name: str, - username: Optional[str] = None, - password: Optional[str] = None, - auth_source: Optional[str] = None, - uri: Optional[str] = None, - ): - if uri and uri.startswith("mongodb://"): - # 优先使用URI连接 - self.client = MongoClient(uri) - elif username and password: - # 如果有用户名和密码,使用认证连接 - self.client = MongoClient( - host, port, username=username, password=password, authSource=auth_source - ) - else: - # 否则使用无认证连接 - self.client = MongoClient(host, port) - self.db: MongoDatabase = self.client[db_name] - - @classmethod - def initialize( - cls, - host: str, - port: int, - db_name: str, - username: Optional[str] = None, - password: Optional[str] = None, - auth_source: Optional[str] = None, - uri: Optional[str] = None, - ) -> MongoDatabase: - if cls._instance is None: - cls._instance = cls( - host, port, db_name, username, password, auth_source, uri - ) - return cls._instance.db - - @classmethod - def get_instance(cls) -> MongoDatabase: - if cls._instance is None: - raise RuntimeError("Database not initialized") - return cls._instance.db +_client = None +_db = None - #测试用 - - def get_random_group_messages(self, group_id: str, limit: int = 5): - # 先随机获取一条消息 - random_message = list(self.db.messages.aggregate([ - {"$match": {"group_id": group_id}}, - {"$sample": {"size": 1}} - ]))[0] - - # 获取该消息之后的消息 - subsequent_messages = list(self.db.messages.find({ - "group_id": group_id, - "time": {"$gt": random_message["time"]} - }).sort("time", 1).limit(limit)) - - # 将随机消息和后续消息合并 - messages = [random_message] + subsequent_messages - - return messages \ No newline at end of file +def __create_database_instance(): + uri = os.getenv("MONGODB_URI") + host = os.getenv("MONGODB_HOST", "127.0.0.1") + port = int(os.getenv("MONGODB_PORT", "27017")) + db_name = os.getenv("DATABASE_NAME", "MegBot") + username = os.getenv("MONGODB_USERNAME") + password = os.getenv("MONGODB_PASSWORD") + auth_source = os.getenv("MONGODB_AUTH_SOURCE") + + if uri and uri.startswith("mongodb://"): + # 优先使用URI连接 + return MongoClient(uri) + + if username and password: + # 如果有用户名和密码,使用认证连接 + return MongoClient( + host, port, username=username, password=password, authSource=auth_source + ) + + # 否则使用无认证连接 + return MongoClient(host, port) + + +def get_db(): + """获取数据库连接实例,延迟初始化。""" + global _client, _db + if _client is None: + _client = __create_database_instance() + _db = _client[os.getenv("DATABASE_NAME", "MegBot")] + return _db + + +class DBWrapper: + """数据库代理类,保持接口兼容性同时实现懒加载。""" + + def __getattr__(self, name): + return getattr(get_db(), name) + + def __getitem__(self, key): + return get_db()[key] + + +# 全局数据库访问点 +db: Database = DBWrapper()