From dcf3378183b1a49a07e089b0c2dd640cb2077ad2 Mon Sep 17 00:00:00 2001 From: Ethan Ralph <63806108+GarbageHamburger@users.noreply.github.com> Date: Mon, 4 May 2020 15:44:52 +0000 Subject: [PATCH 1/3] Implement column defaults for INSERT/UPDATE --- databases/core.py | 62 ++++++++++++++++++++++++++++++++++++++++- tests/test_databases.py | 40 ++++++++++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/databases/core.py b/databases/core.py index cb201be3..e3f66a53 100644 --- a/databases/core.py +++ b/databases/core.py @@ -9,6 +9,9 @@ from sqlalchemy import text from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.dml import ValuesBase +from sqlalchemy.sql.expression import type_coerce + from databases.importer import import_from_string from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend @@ -294,11 +297,51 @@ def _build_query( query = text(query) return query.bindparams(**values) if values is not None else query - elif values: + + # 2 paths where we apply column defaults: + # - values are supplied (the object must be a ValuesBase) + # - values is None but the object is a ValuesBase + if values is not None and not isinstance(query, ValuesBase): + raise TypeError("values supplied but query doesn't support .values()") + + if values is not None or isinstance(query, ValuesBase): + values = Connection._apply_column_defaults(query, values) return query.values(**values) return query + @staticmethod + def _apply_column_defaults(query: ValuesBase, values: dict = None) -> dict: + """Add default values from the table of a query.""" + new_values = {} + values = values or {} + + for column in query.table.c: + if column.name in values: + continue + + if column.default: + default = column.default + + if default.is_sequence: # pragma: no cover + # TODO: support sequences + continue + elif default.is_callable: + value = default.arg(FakeExecutionContext()) + elif default.is_clause_element: # pragma: no cover + # TODO: implement clause element + # For this, the _build_query method needs to + # become an instance method so that it can access + # self._connection. + continue + else: + value = default.arg + + new_values[column.name] = value + + new_values.update(values) + return new_values + class Transaction: def __init__( @@ -489,3 +532,20 @@ def __repr__(self) -> str: def __eq__(self, other: typing.Any) -> bool: return str(self) == str(other) + + +class FakeExecutionContext: + """ + This is an object that raises an error when one of its properties are + attempted to be accessed. Because we're not _really_ using SQLAlchemy + (besides using its query builder), we can't pass a real ExecutionContext + to ColumnDefault objects. This class makes it so that any attempts to + access the execution context argument by a column default callable + blows up loudly and clearly. + """ + + def __getattr__(self, _: str) -> typing.NoReturn: # pragma: no cover + raise NotImplementedError( + "Databases does not have a real SQLAlchemy ExecutionContext " + "implementation." + ) diff --git a/tests/test_databases.py b/tests/test_databases.py index bc5382bd..d0287a82 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -70,6 +70,17 @@ def process_result_value(self, value, dialect): sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)), ) +# Used to test column default values +timestamps = sqlalchemy.Table( + "timestamps", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column( + "timestamp", sqlalchemy.DateTime, default=datetime.datetime.now, nullable=False + ), + sqlalchemy.Column("priority", sqlalchemy.Integer, default=0, nullable=False), +) + @pytest.fixture(autouse=True, scope="module") def create_test_database(): @@ -925,3 +936,32 @@ async def test_column_names(database_url, select_query): assert sorted(results[0].keys()) == ["completed", "id", "text"] assert results[0]["text"] == "example1" assert results[0]["completed"] == True + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_column_defaults(database_url): + """ + Test correct usage of column defaults. + """ + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # with just defaults + query = timestamps.insert() + await database.execute(query) + results = await database.fetch_all(query=timestamps.select()) + assert len(results) == 1 + await database.execute(timestamps.delete()) + + # with default value overridden + dt = datetime.datetime.now() - datetime.timedelta(seconds=10) + values = {"timestamp": dt} + await database.execute(query, values) + results = await database.fetch_all(timestamps.select()) + assert len(results) == 1 + assert results[0]["timestamp"] == dt + + # testing invalid passing of values with non-ValuesBase + # argument + with pytest.raises(TypeError, match=r".*support \.values\(\).*"): + await database.execute(timestamps.select(), {}) From f66013e4e4c3a04fe3ae068e02dbb78d010f8156 Mon Sep 17 00:00:00 2001 From: Ethan Ralph <63806108+GarbageHamburger@users.noreply.github.com> Date: Mon, 4 May 2020 16:41:44 +0000 Subject: [PATCH 2/3] Fix column defaults test for MySQL MySQL rounds datetimes to seconds. --- tests/test_databases.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_databases.py b/tests/test_databases.py index d0287a82..f9efee09 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -954,7 +954,8 @@ async def test_column_defaults(database_url): await database.execute(timestamps.delete()) # with default value overridden - dt = datetime.datetime.now() - datetime.timedelta(seconds=10) + dt = datetime.datetime.now() + dt -= datetime.timedelta(seconds=10, microseconds=dt.microsecond) values = {"timestamp": dt} await database.execute(query, values) results = await database.fetch_all(timestamps.select()) From 8803eed1dc21a2305866b861c30072f35b5be9fb Mon Sep 17 00:00:00 2001 From: GarbageHamburger <63806108+GarbageHamburger@users.noreply.github.com> Date: Mon, 4 May 2020 18:05:46 +0000 Subject: [PATCH 3/3] Remove double blank line MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Rafał Pitoń --- databases/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/databases/core.py b/databases/core.py index e3f66a53..098b7259 100644 --- a/databases/core.py +++ b/databases/core.py @@ -12,7 +12,6 @@ from sqlalchemy.sql.dml import ValuesBase from sqlalchemy.sql.expression import type_coerce - from databases.importer import import_from_string from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend