From 8724ad732bb2917ce07611c76774243f077ee08f Mon Sep 17 00:00:00 2001 From: Micheal Gendy Date: Sun, 6 Mar 2022 23:52:52 +0200 Subject: [PATCH] allow iterate over custom num of records --- databases/backends/aiopg.py | 6 +++++- databases/backends/asyncmy.py | 6 +++++- databases/backends/mysql.py | 6 +++++- databases/backends/postgres.py | 7 +++++-- databases/backends/sqlite.py | 6 +++++- databases/core.py | 14 ++++++++++---- databases/interfaces.py | 2 +- tests/test_databases.py | 11 +++++++++++ 8 files changed, 47 insertions(+), 11 deletions(-) diff --git a/databases/backends/aiopg.py b/databases/backends/aiopg.py index 9ad12f63..fbc4aef9 100644 --- a/databases/backends/aiopg.py +++ b/databases/backends/aiopg.py @@ -179,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: cursor.close() async def iterate( - self, query: ClauseElement + self, query: ClauseElement, n: int = None ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) @@ -195,6 +195,10 @@ async def iterate( Row._default_key_style, row, ) + if n is not None: + n -= 1 + if n == 0: + break finally: cursor.close() diff --git a/databases/backends/asyncmy.py b/databases/backends/asyncmy.py index e15dfa45..9e5b4570 100644 --- a/databases/backends/asyncmy.py +++ b/databases/backends/asyncmy.py @@ -169,7 +169,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: await cursor.close() async def iterate( - self, query: ClauseElement + self, query: ClauseElement, n: int = None ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) @@ -185,6 +185,10 @@ async def iterate( Row._default_key_style, row, ) + if n is not None: + n -= 1 + if n == 0: + break finally: await cursor.close() diff --git a/databases/backends/mysql.py b/databases/backends/mysql.py index 2a0a8425..03fd18ba 100644 --- a/databases/backends/mysql.py +++ b/databases/backends/mysql.py @@ -169,7 +169,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: await cursor.close() async def iterate( - self, query: ClauseElement + self, query: ClauseElement, n: int = None ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) @@ -185,6 +185,10 @@ async def iterate( Row._default_key_style, row, ) + if n is not None: + n -= 1 + if n == 0: + break finally: await cursor.close() diff --git a/databases/backends/postgres.py b/databases/backends/postgres.py index 3d0a36f2..ebd3a61c 100644 --- a/databases/backends/postgres.py +++ b/databases/backends/postgres.py @@ -1,6 +1,5 @@ import logging import typing -from collections.abc import Sequence import asyncpg from sqlalchemy.dialects.postgresql import pypostgresql @@ -227,13 +226,17 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: await self._connection.execute(single_query, *args) async def iterate( - self, query: ClauseElement + self, query: ClauseElement, n: int = None ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, result_columns = self._compile(query) column_maps = self._create_column_maps(result_columns) async for row in self._connection.cursor(query_str, *args): yield Record(row, result_columns, self._dialect, column_maps) + if n is not None: + n -= 1 + if n == 0: + break def transaction(self) -> TransactionBackend: return PostgresTransaction(connection=self) diff --git a/databases/backends/sqlite.py b/databases/backends/sqlite.py index 9626dcf8..5f55a21d 100644 --- a/databases/backends/sqlite.py +++ b/databases/backends/sqlite.py @@ -141,7 +141,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: await self.execute(single_query) async def iterate( - self, query: ClauseElement + self, query: ClauseElement, n: int = None ) -> typing.AsyncGenerator[typing.Any, None]: assert self._connection is not None, "Connection is not acquired" query_str, args, context = self._compile(query) @@ -155,6 +155,10 @@ async def iterate( Row._default_key_style, row, ) + if n is not None: + n -= 1 + if n == 0: + break def transaction(self) -> TransactionBackend: return SQLiteTransaction(self) diff --git a/databases/core.py b/databases/core.py index 7005281c..e584dea3 100644 --- a/databases/core.py +++ b/databases/core.py @@ -181,10 +181,13 @@ async def execute_many( return await connection.execute_many(query, values) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: dict = None, + n: int = None, ) -> typing.AsyncGenerator[typing.Mapping, None]: async with self.connection() as connection: - async for record in connection.iterate(query, values): + async for record in connection.iterate(query, values, n): yield record def _new_connection(self) -> "Connection": @@ -307,12 +310,15 @@ async def execute_many( await self._connection.execute_many(queries) async def iterate( - self, query: typing.Union[ClauseElement, str], values: dict = None + self, + query: typing.Union[ClauseElement, str], + values: dict = None, + n: int = None, ) -> typing.AsyncGenerator[typing.Any, None]: built_query = self._build_query(query, values) async with self.transaction(): async with self._query_lock: - async for record in self._connection.iterate(built_query): + async for record in self._connection.iterate(built_query, n): yield record def transaction( diff --git a/databases/interfaces.py b/databases/interfaces.py index c2109a23..7c44a5d8 100644 --- a/databases/interfaces.py +++ b/databases/interfaces.py @@ -41,7 +41,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None: raise NotImplementedError() # pragma: no cover async def iterate( - self, query: ClauseElement + self, query: ClauseElement, n: int = None ) -> typing.AsyncGenerator[typing.Mapping, None]: raise NotImplementedError() # pragma: no cover # mypy needs async iterators to contain a `yield` diff --git a/tests/test_databases.py b/tests/test_databases.py index 7a0b84fd..215432b4 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -221,6 +221,17 @@ async def test_queries(database_url): assert iterate_results[2]["text"] == "example3" assert iterate_results[2]["completed"] == True + # iterate() with custom number of records + query = notes.select() + iterate_results = [] + async for result in database.iterate(query=query, n=2): + iterate_results.append(result) + assert len(iterate_results) == 2 + assert iterate_results[0]["text"] == "example1" + assert iterate_results[0]["completed"] == True + assert iterate_results[1]["text"] == "example2" + assert iterate_results[1]["completed"] == False + @pytest.mark.parametrize("database_url", DATABASE_URLS) @mysql_versions