diff --git a/sqlalchemy_utils/functions/database.py b/sqlalchemy_utils/functions/database.py index 1be92a85..c9c1a426 100644 --- a/sqlalchemy_utils/functions/database.py +++ b/sqlalchemy_utils/functions/database.py @@ -453,10 +453,11 @@ def _sqlite_file_exists(database): return header[:16] == b'SQLite format 3\x00' -def database_exists(url): +def database_exists(url, default_db=None): """Check if a database exists. :param url: A SQLAlchemy engine URL. + :param default_db: The default database to use instead of requiring standard Performs backend-specific testing to quickly determine if a database exists on the server. :: @@ -481,7 +482,7 @@ def database_exists(url): try: if dialect_name == 'postgresql': text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database - for db in (database, 'postgres', 'template1', 'template0', None): + for db in (database, default_db or 'postgres', 'template1', 'template0', None): url = _set_url_database(url, database=db) engine = sa.create_engine(url) try: @@ -518,7 +519,7 @@ def database_exists(url): engine.dispose() -def create_database(url, encoding='utf8', template=None): +def create_database(url, encoding='utf8', template=None, default_db=None): """Issue the appropriate CREATE DATABASE statement. :param url: A SQLAlchemy engine URL. @@ -526,6 +527,7 @@ def create_database(url, encoding='utf8', template=None): :param template: The name of the template from which to create the new database. At the moment only supported by PostgreSQL driver. + :param defualt_db: Overwrite the default database used when connecting. To create a database, you can pass a simple URL that would have been passed to ``create_engine``. :: @@ -545,14 +547,17 @@ def create_database(url, encoding='utf8', template=None): dialect_name = url.get_dialect().name dialect_driver = url.get_dialect().driver - if dialect_name == 'postgresql': - url = _set_url_database(url, database="postgres") - elif dialect_name == 'mssql': - url = _set_url_database(url, database="master") - elif dialect_name == 'cockroachdb': - url = _set_url_database(url, database="defaultdb") - elif not dialect_name == 'sqlite': - url = _set_url_database(url, database=None) + if default_db != None: + if dialect_name == 'postgresql': + url = _set_url_database(url, database="postgres") + elif dialect_name == 'mssql': + url = _set_url_database(url, database="master") + elif dialect_name == 'cockroachdb': + url = _set_url_database(url, database="defaultdb") + elif not dialect_name == 'sqlite': + url = _set_url_database(url, database=None) + else: + url = _set_url_database(url, database=default_db) if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) \ or (dialect_name == 'postgresql' and dialect_driver in {