diff --git a/databases/core.py b/databases/core.py index cb201be3..098b7259 100644 --- a/databases/core.py +++ b/databases/core.py @@ -9,6 +9,8 @@ from sqlalchemy import text from sqlalchemy.sql import ClauseElement +from sqlalchemy.sql.dml import ValuesBase +from sqlalchemy.sql.expression import type_coerce from databases.importer import import_from_string from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend @@ -294,11 +296,51 @@ def _build_query( query = text(query) return query.bindparams(**values) if values is not None else query - elif values: + + # 2 paths where we apply column defaults: + # - values are supplied (the object must be a ValuesBase) + # - values is None but the object is a ValuesBase + if values is not None and not isinstance(query, ValuesBase): + raise TypeError("values supplied but query doesn't support .values()") + + if values is not None or isinstance(query, ValuesBase): + values = Connection._apply_column_defaults(query, values) return query.values(**values) return query + @staticmethod + def _apply_column_defaults(query: ValuesBase, values: dict = None) -> dict: + """Add default values from the table of a query.""" + new_values = {} + values = values or {} + + for column in query.table.c: + if column.name in values: + continue + + if column.default: + default = column.default + + if default.is_sequence: # pragma: no cover + # TODO: support sequences + continue + elif default.is_callable: + value = default.arg(FakeExecutionContext()) + elif default.is_clause_element: # pragma: no cover + # TODO: implement clause element + # For this, the _build_query method needs to + # become an instance method so that it can access + # self._connection. + continue + else: + value = default.arg + + new_values[column.name] = value + + new_values.update(values) + return new_values + class Transaction: def __init__( @@ -489,3 +531,20 @@ def __repr__(self) -> str: def __eq__(self, other: typing.Any) -> bool: return str(self) == str(other) + + +class FakeExecutionContext: + """ + This is an object that raises an error when one of its properties are + attempted to be accessed. Because we're not _really_ using SQLAlchemy + (besides using its query builder), we can't pass a real ExecutionContext + to ColumnDefault objects. This class makes it so that any attempts to + access the execution context argument by a column default callable + blows up loudly and clearly. + """ + + def __getattr__(self, _: str) -> typing.NoReturn: # pragma: no cover + raise NotImplementedError( + "Databases does not have a real SQLAlchemy ExecutionContext " + "implementation." + ) diff --git a/tests/test_databases.py b/tests/test_databases.py index bc5382bd..f9efee09 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -70,6 +70,17 @@ def process_result_value(self, value, dialect): sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)), ) +# Used to test column default values +timestamps = sqlalchemy.Table( + "timestamps", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column( + "timestamp", sqlalchemy.DateTime, default=datetime.datetime.now, nullable=False + ), + sqlalchemy.Column("priority", sqlalchemy.Integer, default=0, nullable=False), +) + @pytest.fixture(autouse=True, scope="module") def create_test_database(): @@ -925,3 +936,33 @@ async def test_column_names(database_url, select_query): assert sorted(results[0].keys()) == ["completed", "id", "text"] assert results[0]["text"] == "example1" assert results[0]["completed"] == True + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_column_defaults(database_url): + """ + Test correct usage of column defaults. + """ + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + # with just defaults + query = timestamps.insert() + await database.execute(query) + results = await database.fetch_all(query=timestamps.select()) + assert len(results) == 1 + await database.execute(timestamps.delete()) + + # with default value overridden + dt = datetime.datetime.now() + dt -= datetime.timedelta(seconds=10, microseconds=dt.microsecond) + values = {"timestamp": dt} + await database.execute(query, values) + results = await database.fetch_all(timestamps.select()) + assert len(results) == 1 + assert results[0]["timestamp"] == dt + + # testing invalid passing of values with non-ValuesBase + # argument + with pytest.raises(TypeError, match=r".*support \.values\(\).*"): + await database.execute(timestamps.select(), {})