diff --git a/databases/backends/common.py b/databases/backends/common.py new file mode 100644 index 00000000..ca4c3fc4 --- /dev/null +++ b/databases/backends/common.py @@ -0,0 +1,33 @@ +import typing + +from sqlalchemy import ColumnDefault +from sqlalchemy.engine.default import DefaultDialect + + +class ConstructDefaultParamsMixin: + """ + A mixin to support column default values for insert queries for asyncpg, + aiomysql and aiosqlite + """ + + prefetch: typing.List + dialect: DefaultDialect + + def construct_params( + self, + params: typing.Optional[typing.Mapping] = None, + _group_number: typing.Any = None, + _check: bool = True, + ) -> typing.Dict: + pd = super().construct_params(params, _group_number, _check) # type: ignore + + for column in self.prefetch: + pd[column.key] = self._exec_default(column.default) + + return pd + + def _exec_default(self, default: ColumnDefault) -> typing.Any: + if default.is_callable: + return default.arg(self.dialect) + else: + return default.arg diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index b6476add..fc10bfcb 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -9,24 +9,35 @@ from sqlalchemy.engine.result import ResultMetaData, RowProxy from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.types import TypeEngine +from databases.backends.common import ConstructDefaultParamsMixin from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") +class MySQLCompiler(ConstructDefaultParamsMixin, pymysql.dialect.statement_compiler): + pass + + class MySQLBackend(DatabaseBackend): def __init__( self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any ) -> None: self._database_url = DatabaseURL(database_url) self._options = options - self._dialect = pymysql.dialect(paramstyle="pyformat") - self._dialect.supports_native_decimal = True + self._dialect = self._get_dialect() self._pool = None + def _get_dialect(self) -> Dialect: + dialect = pymysql.dialect(paramstyle="pyformat") + + dialect.statement_compiler = MySQLCompiler + dialect.supports_native_decimal = True + + return dialect + def _get_connection_kwargs(self) -> dict: url_options = self._database_url.options diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 8c1d75b1..61732015 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -10,11 +10,20 @@ from sqlalchemy.sql.schema import Column from sqlalchemy.types import TypeEngine +from databases.backends.common import ConstructDefaultParamsMixin from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") +_result_processors = {} # type: dict + + +class APGCompiler_psycopg2( + ConstructDefaultParamsMixin, pypostgresql.dialect.statement_compiler +): + pass + class PostgresBackend(DatabaseBackend): def __init__( @@ -28,6 +37,7 @@ def __init__( def _get_dialect(self) -> Dialect: dialect = pypostgresql.dialect(paramstyle="pyformat") + dialect.statement_compiler = APGCompiler_psycopg2 dialect.implicit_returning = True dialect.supports_native_enum = True dialect.supports_smallserial = True # 9.2+ diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 28ceb6fb..0339efce 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -8,25 +8,36 @@ from sqlalchemy.engine.result import ResultMetaData, RowProxy from sqlalchemy.sql import ClauseElement from sqlalchemy.sql.ddl import DDLElement -from sqlalchemy.types import TypeEngine +from databases.backends.common import ConstructDefaultParamsMixin from databases.core import LOG_EXTRA, DatabaseURL from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend logger = logging.getLogger("databases") +class SQLiteCompiler(ConstructDefaultParamsMixin, pysqlite.dialect.statement_compiler): + pass + + class SQLiteBackend(DatabaseBackend): def __init__( self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any ) -> None: self._database_url = DatabaseURL(database_url) self._options = options - self._dialect = pysqlite.dialect(paramstyle="qmark") - # aiosqlite does not support decimals - self._dialect.supports_native_decimal = False + self._dialect = self._get_dialect() self._pool = SQLitePool(self._database_url, **self._options) + def _get_dialect(self) -> Dialect: + dialect = pysqlite.dialect(paramstyle="qmark") + + # aiosqlite does not support decimals + dialect.supports_native_decimal = False + dialect.statement_compiler = SQLiteCompiler + + return dialect + async def connect(self) -> None: pass # assert self._pool is None, "DatabaseBackend is already running" diff --git a/databases/core.py b/databases/core.py index 2bab6735..cec261e7 100644 --- a/databases/core.py +++ b/databases/core.py @@ -301,6 +301,11 @@ def _build_query( elif values: return query.values(**values) + # for case when `table.insert()` called without `.values()` it has to be + # called to produce `insert_prefetch` for compiled query + if query.__visit_name__ == "insert": + return query.values() + return query diff --git a/tests/test_databases.py b/tests/test_databases.py index c7317688..5f078a2f 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -70,6 +70,20 @@ def process_result_value(self, value, dialect): sqlalchemy.Column("price", sqlalchemy.Numeric(precision=30, scale=20)), ) +# Used to test column default values +default_values = sqlalchemy.Table( + "default_values", + metadata, + sqlalchemy.Column("id", sqlalchemy.Integer, primary_key=True), + sqlalchemy.Column("with_default", sqlalchemy.Integer, default=42), + sqlalchemy.Column( + "with_callable_default", + sqlalchemy.String(length=100), + default=lambda: "default_value", + ), + sqlalchemy.Column("without_default", sqlalchemy.Integer), +) + @pytest.fixture(autouse=True, scope="module") def create_test_database(): @@ -651,6 +665,84 @@ async def test_json_field(database_url): assert results[0]["data"] == {"text": "hello", "boolean": True, "int": 1} +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_with_scalar_default(database_url): + """ + Test insert with scalar column default value + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + query = default_values.insert() + values = {"without_default": 1} + await database.execute(query, values) + + query = default_values.select().order_by(default_values.c.id.desc()) + result = await database.fetch_one(query=query) + + assert result["with_default"] == 42 + assert result["without_default"] == values["without_default"] + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_default_values_with_no_values_called(database_url): + """ + Test insert default values without calling ``values()`` on insert and + without passing ``values`` to ``execute()``. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + query = default_values.insert() + await database.execute(query) + + query = default_values.select().order_by(default_values.c.id.desc()) + result = await database.fetch_one(query=query) + + assert result["with_default"] == 42 + assert result["without_default"] is None + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_default_values_with_overriden_default(database_url): + """ + Test if we provide value for a column having default value, the first one + should be set, not default one. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + query = default_values.insert() + values = {"with_default": 5} + await database.execute(query, values) + + query = default_values.select().order_by(default_values.c.id.desc()) + result = await database.fetch_one(query=query) + + assert result["with_default"] == values["with_default"] + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_insert_callable_default(database_url): + """ + Test insert with column having callable default. + """ + + async with Database(database_url) as database: + async with database.transaction(force_rollback=True): + query = default_values.insert() + await database.execute(query) + + query = default_values.select().order_by(default_values.c.id.desc()) + result = await database.fetch_one(query=query) + + assert result["with_callable_default"] == "default_value" + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_custom_field(database_url): @@ -917,6 +1009,7 @@ async def run_database_queries(): async with database: async def db_lookup(): + await database.fetch_one("SELECT pg_sleep(1)") await asyncio.gather(db_lookup(), db_lookup())