Skip to content
27 changes: 16 additions & 11 deletions sqlalchemy_utils/functions/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,10 +453,11 @@ def _sqlite_file_exists(database):
return header[:16] == b'SQLite format 3\x00'


def database_exists(url):
def database_exists(url, default_db=None):
"""Check if a database exists.

:param url: A SQLAlchemy engine URL.
:param default_db: The default database to use instead of requiring standard

Performs backend-specific testing to quickly determine if a database
exists on the server. ::
Expand All @@ -481,7 +482,7 @@ def database_exists(url):
try:
if dialect_name == 'postgresql':
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
for db in (database, 'postgres', 'template1', 'template0', None):
for db in (database, default_db or 'postgres', 'template1', 'template0', None):
url = _set_url_database(url, database=db)
engine = sa.create_engine(url)
try:
Expand Down Expand Up @@ -518,14 +519,15 @@ def database_exists(url):
engine.dispose()


def create_database(url, encoding='utf8', template=None):
def create_database(url, encoding='utf8', template=None, default_db=None):
"""Issue the appropriate CREATE DATABASE statement.

:param url: A SQLAlchemy engine URL.
:param encoding: The encoding to create the database as.
:param template:
The name of the template from which to create the new database. At the
moment only supported by PostgreSQL driver.
:param defualt_db: Overwrite the default database used when connecting.

To create a database, you can pass a simple URL that would have
been passed to ``create_engine``. ::
Expand All @@ -545,14 +547,17 @@ def create_database(url, encoding='utf8', template=None):
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver

if dialect_name == 'postgresql':
url = _set_url_database(url, database="postgres")
elif dialect_name == 'mssql':
url = _set_url_database(url, database="master")
elif dialect_name == 'cockroachdb':
url = _set_url_database(url, database="defaultdb")
elif not dialect_name == 'sqlite':
url = _set_url_database(url, database=None)
if default_db != None:
if dialect_name == 'postgresql':
url = _set_url_database(url, database="postgres")
elif dialect_name == 'mssql':
url = _set_url_database(url, database="master")
elif dialect_name == 'cockroachdb':
url = _set_url_database(url, database="defaultdb")
elif not dialect_name == 'sqlite':
url = _set_url_database(url, database=None)
else:
url = _set_url_database(url, database=default_db)

if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) \
or (dialect_name == 'postgresql' and dialect_driver in {
Expand Down