Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix sqlalchemy default values for insert and update queries #266

Closed
wants to merge 10 commits into from
33 changes: 33 additions & 0 deletions databases/backends/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import typing

from sqlalchemy import ColumnDefault
from sqlalchemy.engine.default import DefaultDialect


class ConstructDefaultParamsMixin:
"""
A mixin to support column default values for insert queries for asyncpg,
aiomysql and aiosqlite
"""

prefetch: typing.List
dialect: DefaultDialect

def construct_params(
self,
params: typing.Optional[typing.Mapping] = None,
_group_number: typing.Any = None,
_check: bool = True,
) -> typing.Dict:
pd = super().construct_params(params, _group_number, _check) # type: ignore

for column in self.prefetch:
pd[column.key] = self._exec_default(column.default)

return pd

def _exec_default(self, default: ColumnDefault) -> typing.Any:
if default.is_callable:
return default.arg(self.dialect)
else:
return default.arg
17 changes: 14 additions & 3 deletions databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,35 @@
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.backends.common import ConstructDefaultParamsMixin
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")


class MySQLCompiler(ConstructDefaultParamsMixin, pymysql.dialect.statement_compiler):
pass


class MySQLBackend(DatabaseBackend):
def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = pymysql.dialect(paramstyle="pyformat")
self._dialect.supports_native_decimal = True
self._dialect = self._get_dialect()
self._pool = None

def _get_dialect(self) -> Dialect:
dialect = pymysql.dialect(paramstyle="pyformat")

dialect.statement_compiler = MySQLCompiler
dialect.supports_native_decimal = True

return dialect

def _get_connection_kwargs(self) -> dict:
url_options = self._database_url.options

Expand Down
10 changes: 10 additions & 0 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,20 @@
from sqlalchemy.sql.schema import Column
from sqlalchemy.types import TypeEngine

from databases.backends.common import ConstructDefaultParamsMixin
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")

_result_processors = {} # type: dict


class APGCompiler_psycopg2(
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems strange to call this APGCompiler_psycopg2, given that the db backend used for postgres in databases is asyncpg, not psycopg2.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right but that's because aiopg is just a wrapper around psycopg2, asyncpg is not.

ConstructDefaultParamsMixin, pypostgresql.dialect.statement_compiler
):
pass


class PostgresBackend(DatabaseBackend):
def __init__(
Expand All @@ -28,6 +37,7 @@ def __init__(
def _get_dialect(self) -> Dialect:
dialect = pypostgresql.dialect(paramstyle="pyformat")

dialect.statement_compiler = APGCompiler_psycopg2
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
Expand Down
19 changes: 15 additions & 4 deletions databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,36 @@
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.backends.common import ConstructDefaultParamsMixin
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

logger = logging.getLogger("databases")


class SQLiteCompiler(ConstructDefaultParamsMixin, pysqlite.dialect.statement_compiler):
pass


class SQLiteBackend(DatabaseBackend):
def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = pysqlite.dialect(paramstyle="qmark")
# aiosqlite does not support decimals
self._dialect.supports_native_decimal = False
self._dialect = self._get_dialect()
self._pool = SQLitePool(self._database_url, **self._options)

def _get_dialect(self) -> Dialect:
dialect = pysqlite.dialect(paramstyle="qmark")

# aiosqlite does not support decimals
dialect.supports_native_decimal = False
dialect.statement_compiler = SQLiteCompiler

return dialect

async def connect(self) -> None:
pass
# assert self._pool is None, "DatabaseBackend is already running"
Expand Down
5 changes: 5 additions & 0 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,11 @@ def _build_query(
elif values:
return query.values(**values)

# for case when `table.insert()` called without `.values()` it has to be
# called to produce `insert_prefetch` for compiled query
if query.__visit_name__ == "insert":
return query.values()

return query


Expand Down
93 changes: 93 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,20 @@ def process_result_value(self, value, dialect):
sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)),
)

# Used to test column default values
default_values = sqlalchemy.Table(
"default_values",
metadata,
sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True),
sqlalchemy.Column("with_default", sqlalchemy.Integer, default=42),
sqlalchemy.Column(
"with_callable_default",
sqlalchemy.String(length=100),
default=lambda: "default_value",
),
sqlalchemy.Column("without_default", sqlalchemy.Integer),
)


@pytest.fixture(autouse=True, scope="module")
def create_test_database():
Expand Down Expand Up @@ -651,6 +665,84 @@ async def test_json_field(database_url):
assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1}


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_with_scalar_default(database_url):
"""
Test insert with scalar column default value
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
values = {"without_default": 1}
await database.execute(query, values)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_default"] == 42
assert result["without_default"] == values["without_default"]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_default_values_with_no_values_called(database_url):
"""
Test insert default values without calling ``values()`` on insert and
without passing ``values`` to ``execute()``.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
await database.execute(query)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_default"] == 42
assert result["without_default"] is None


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_default_values_with_overriden_default(database_url):
"""
Test if we provide value for a column having default value, the first one
should be set, not default one.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
values = {"with_default": 5}
await database.execute(query, values)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_default"] == values["with_default"]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_insert_callable_default(database_url):
"""
Test insert with column having callable default.
"""

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = default_values.insert()
await database.execute(query)

query = default_values.select().order_by(default_values.c.id.desc())
result = await database.fetch_one(query=query)

assert result["with_callable_default"] == "default_value"


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_custom_field(database_url):
Expand Down Expand Up @@ -917,6 +1009,7 @@ async def run_database_queries():
async with database:

async def db_lookup():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

left \n

await database.fetch_one("SELECT pg_sleep(1)")

await asyncio.gather(db_lookup(), db_lookup())
Expand Down