diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 1b73fa62..9ad12f63 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -14,7 +14,12 @@ from sqlalchemy.sql.ddl import DDLElement from databases.core import DatabaseURL -from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record, + TransactionBackend, +) logger = logging.getLogger("databases") @@ -112,7 +117,7 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) cursor = await self._connection.cursor() @@ -133,7 +138,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: finally: cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Sequence]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) cursor = await self._connection.cursor() diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index c9b6611f..e15dfa45 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -12,7 +12,12 @@ from sqlalchemy.sql.ddl import DDLElement from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record, + TransactionBackend, +) logger = logging.getLogger("databases") @@ -100,7 +105,7 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) async with self._connection.cursor() as cursor: @@ -121,7 +126,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Sequence]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) async with self._connection.cursor() as cursor: diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 4c490d71..2a0a8425 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -12,7 +12,12 @@ from sqlalchemy.sql.ddl import DDLElement from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record, + TransactionBackend, +) logger = logging.getLogger("databases") @@ -100,7 +105,7 @@ async def release(self) -> None: await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) cursor = await self._connection.cursor() @@ -121,7 +126,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: finally: await cursor.close() - async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Sequence]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) cursor = await self._connection.cursor() diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 472da07a..3d0a36f2 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -11,7 +11,12 @@ from sqlalchemy.types import TypeEngine from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record as RecordInterface, + TransactionBackend, +) logger = logging.getLogger("databases") @@ -78,7 +83,7 @@ def connection(self) -> "PostgresConnection": return PostgresConnection(self, self._dialect) -class Record(Sequence): +class Record(RecordInterface): __slots__ = ( "_row", "_result_columns", @@ -105,7 +110,7 @@ def __init__( self._column_map, self._column_map_int, self._column_map_full = column_maps @property - def _mapping(self) -> asyncpg.Record: + def _mapping(self) -> typing.Mapping: return self._row def keys(self) -> typing.KeysView: @@ -171,7 +176,7 @@ async def release(self) -> None: self._connection = await self._database._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: + async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) rows = await self._connection.fetch(query_str, *args) @@ -179,7 +184,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: column_maps = self._create_column_maps(result_columns) return [Record(row, result_columns, dialect, column_maps) for row in rows] - async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Sequence]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) row = await self._connection.fetchrow(query_str, *args) diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 46a39519..9626dcf8 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -11,7 +11,12 @@ from sqlalchemy.sql.ddl import DDLElement from databases.core import LOG_EXTRA, DatabaseURL -from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record, + TransactionBackend, +) logger = logging.getLogger("databases") @@ -86,7 +91,7 @@ async def release(self) -> None: await self._pool.release(self._connection) self._connection = None - async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: + async def fetch_all(self, query: ClauseElement) -> typing.List[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) @@ -104,7 +109,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: for row in rows ] - async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Sequence]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional[Record]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) diff --git a/databases/core.py b/databases/core.py index b3c2b440..893eb37e 100644 --- a/databases/core.py +++ b/databases/core.py @@ -11,7 +11,12 @@ from sqlalchemy.sql import ClauseElement from databases.importer import import_from_string -from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend +from databases.interfaces import ( + ConnectionBackend, + DatabaseBackend, + Record, + TransactionBackend, +) if sys.version_info >= (3, 7): # pragma: no cover import contextvars as contextvars @@ -144,13 +149,13 @@ async def __aexit__( async def fetch_all( self, query: typing.Union[ClauseElement, str], values: dict = None - ) -> typing.List[typing.Sequence]: + ) -> typing.List[Record]: async with self.connection() as connection: return await connection.fetch_all(query, values) async def fetch_one( self, query: typing.Union[ClauseElement, str], values: dict = None - ) -> typing.Optional[typing.Sequence]: + ) -> typing.Optional[Record]: async with self.connection() as connection: return await connection.fetch_one(query, values) @@ -265,14 +270,14 @@ async def __aexit__( async def fetch_all( self, query: typing.Union[ClauseElement, str], values: dict = None - ) -> typing.List[typing.Sequence]: + ) -> typing.List[Record]: built_query = self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_all(built_query) async def fetch_one( self, query: typing.Union[ClauseElement, str], values: dict = None - ) -> typing.Optional[typing.Sequence]: + ) -> typing.Optional[Record]: built_query = self._build_query(query, values) async with self._query_lock: return await self._connection.fetch_one(built_query) diff --git a/databases/interfaces.py b/databases/interfaces.py index 9bf24435..c2109a23 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -1,4 +1,5 @@ import typing +from collections.abc import Sequence from sqlalchemy.sql import ClauseElement @@ -21,10 +22,10 @@ async def acquire(self) -> None: async def release(self) -> None: raise NotImplementedError() # pragma: no cover - async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Sequence]: + async def fetch_all(self, query: ClauseElement) -> typing.List["Record"]: raise NotImplementedError() # pragma: no cover - async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Sequence]: + async def fetch_one(self, query: ClauseElement) -> typing.Optional["Record"]: raise NotImplementedError() # pragma: no cover async def fetch_val( @@ -66,3 +67,9 @@ async def commit(self) -> None: async def rollback(self) -> None: raise NotImplementedError() # pragma: no cover + + +class Record(Sequence): + @property + def _mapping(self) -> typing.Mapping: + raise NotImplementedError() # pragma: no cover diff --git a/docs/database_queries.md b/docs/database_queries.md index 11721237..898e7343 100644 --- a/docs/database_queries.md +++ b/docs/database_queries.md @@ -108,3 +108,20 @@ Note that query arguments should follow the `:query_arg` style. [sqlalchemy-core]: https://docs.sqlalchemy.org/en/latest/core/ [sqlalchemy-core-tutorial]: https://docs.sqlalchemy.org/en/latest/core/tutorial.html + +## Query result + +To keep in line with [SQLAlchemy 1.4 changes][sqlalchemy-mapping-changes] +query result object no longer implements a mapping interface. +To access query result as a mapping you should use the `_mapping` property. +That way you can process both SQLAlchemy Rows and databases Records from raw queries +with the same function without any instance checks. + +```python +query = "SELECT * FROM notes WHERE id = :id" +result = await database.fetch_one(query=query, values={"id": 1}) +result.id # access field via attribute +result._mapping['id'] # access field via mapping +``` + +[sqlalchemy-mapping-changes]: https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#rowproxy-is-no-longer-a-proxy-is-now-called-row-and-behaves-like-an-enhanced-named-tuple diff --git a/tests/test_databases.py b/tests/test_databases.py index 5f91aa37..ad1fb542 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1222,3 +1222,25 @@ async def test_result_named_access(database_url): assert result.text == "example1" assert result.completed is True + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@mysql_versions +@async_adapter +async def test_mapping_property_interface(database_url): + """ + Test that all connections implement interface with `_mapping` property + """ + async with Database(database_url) as database: + query = notes.insert() + values = {"text": "example1", "completed": True} + await database.execute(query, values) + + query = notes.select() + single_result = await database.fetch_one(query=query) + assert single_result._mapping["text"] == "example1" + assert single_result._mapping["completed"] is True + + list_result = await database.fetch_all(query=query) + assert list_result[0]._mapping["text"] == "example1" + assert list_result[0]._mapping["completed"] is True