Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement column defaults for INSERT/UPDATE #206

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from sqlalchemy import text
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.dml import ValuesBase
from sqlalchemy.sql.expression import type_coerce

GarbageHamburger marked this conversation as resolved.
Show resolved Hide resolved
from databases.importer import import_from_string
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
Expand Down Expand Up @@ -294,11 +296,51 @@ def _build_query(
query = text(query)

return query.bindparams(**values) if values is not None else query
elif values:

# 2 paths where we apply column defaults:
# - values are supplied (the object must be a ValuesBase)
# - values is None but the object is a ValuesBase
if values is not None and not isinstance(query, ValuesBase):
raise TypeError("values supplied but query doesn't support .values()")

if values is not None or isinstance(query, ValuesBase):
values = Connection._apply_column_defaults(query, values)
Comment on lines +299 to +307
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# 2 paths where we apply column defaults:
# - values are supplied (the object must be a ValuesBase)
# - values is None but the object is a ValuesBase
if values is not None and not isinstance(query, ValuesBase):
raise TypeError("values supplied but query doesn't support .values()")
if values is not None or isinstance(query, ValuesBase):
values = Connection._apply_column_defaults(query, values)
elif values:
assert isinstance(query, ValuesBase)
values = self._apply_column_defaults(query, values)

return query.values(**values)

return query

@staticmethod
def _apply_column_defaults(query: ValuesBase, values: dict = None) -> dict:
"""Add default values from the table of a query."""
new_values = {}
values = values or {}

for column in query.table.c:
if column.name in values:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

values can have columns as keys, in such cases this will be always false

continue

if column.default:
default = column.default

if default.is_sequence: # pragma: no cover
# TODO: support sequences
continue
Comment on lines +325 to +327
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

assert not default.is_sequence, "sequences are not supported, PRs welcome"

elif default.is_callable:
value = default.arg(FakeExecutionContext())
elif default.is_clause_element: # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif default.is_clause_element: # pragma: no cover
assert not default.is_clause_element, "clause defaults are not supported, PRs welcome"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • can you please group assertions together

# TODO: implement clause element
# For this, the _build_query method needs to
# become an instance method so that it can access
# self._connection.
continue
else:
value = default.arg

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

new_values[column.name] = value

new_values.update(values)
return new_values


class Transaction:
def __init__(
Expand Down Expand Up @@ -489,3 +531,20 @@ def __repr__(self) -> str:

def __eq__(self, other: typing.Any) -> bool:
return str(self) == str(other)


class FakeExecutionContext:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class FakeExecutionContext:
class DummyExecutionContext:

(it's not completely fake, after all)

"""
This is an object that raises an error when one of its properties are
attempted to be accessed. Because we're not _really_ using SQLAlchemy
(besides using its query builder), we can't pass a real ExecutionContext
to ColumnDefault objects. This class makes it so that any attempts to
access the execution context argument by a column default callable
blows up loudly and clearly.
"""

def __getattr__(self, _: str) -> typing.NoReturn: # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should cover this by another test which tests raising NotImplementedError

raise NotImplementedError(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw: my own custom implementation of this (yep, I have a similar hack in my prod), I pass the current values to this context so that it becomes much more useful.

"Databases does not have a real SQLAlchemy ExecutionContext "
"implementation."
)
41 changes: 41 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,17 @@ def process_result_value(self, value, dialect):
sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)),
)

# Used to test column default values
timestamps = sqlalchemy.Table(
"timestamps",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column(
"timestamp", sqlalchemy.DateTime, default=datetime.datetime.now, nullable=False
),
sqlalchemy.Column("priority", sqlalchemy.Integer, default=0, nullable=False),
)


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
Expand Down Expand Up @@ -925,3 +936,33 @@ async def test_column_names(database_url, select_query):
assert sorted(results[0].keys()) == ["completed", "id", "text"]
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_column_defaults(database_url):
"""
Test correct usage of column defaults.
"""
async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
# with just defaults
query = timestamps.insert()
await database.execute(query)
results = await database.fetch_all(query=timestamps.select())
assert len(results) == 1
await database.execute(timestamps.delete())

# with default value overridden
dt = datetime.datetime.now()
dt -= datetime.timedelta(seconds=10, microseconds=dt.microsecond)
values = {"timestamp": dt}
await database.execute(query, values)
results = await database.fetch_all(timestamps.select())
assert len(results) == 1
assert results[0]["timestamp"] == dt

# testing invalid passing of values with non-ValuesBase
# argument
with pytest.raises(TypeError, match=r".*support \.values\(\).*"):
await database.execute(timestamps.select(), {})