diff --git a/providers/common/sql/docs/changelog.rst b/providers/common/sql/docs/changelog.rst index 9b7fdec015362..8ced2c43ad869 100644 --- a/providers/common/sql/docs/changelog.rst +++ b/providers/common/sql/docs/changelog.rst @@ -25,6 +25,19 @@ Changelog --------- +.. warning:: + **Breaking Change:** The default execution mode for paginated (``page_size`` + string SQL) GenericTransfer tasks has changed. Previously, these tasks always ran in deferred mode, regardless of configuration. Starting with this release, they now run synchronously by default unless you explicitly opt in to deferrable mode. + + This is a silent behavior change for any existing DAG using paginated GenericTransfer. If you want to restore the old behavior (always defer execution): + + 1. Pass ``deferrable=True`` to each affected GenericTransfer task, **or** + 2. Set the global config option ``[operators] default_deferrable = true`` to make all operators deferrable by default. + + Review your DAGs and configuration if you rely on deferred execution for paginated GenericTransfer tasks. + +.. Below changes are excluded from the changelog. Move them to + appropriate section above if needed. Do not delete the lines(!): + 1.35.0 ...... 
diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py index 043ec22a5c8c7..3fd7b330c44b7 100644 --- a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.py @@ -21,7 +21,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from airflow.providers.common.compat.sdk import AirflowException, BaseHook, BaseOperator +from airflow.providers.common.compat.sdk import AirflowException, BaseHook, BaseOperator, conf from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger @@ -55,6 +55,8 @@ class GenericTransfer(BaseOperator): :param insert_args: extra params for `insert_rows` method. :param page_size: number of records to be read in paginated mode (optional). :param paginated_sql_statement_clause: SQL statement clause to be used for pagination (optional). + :param deferrable: Run operator in deferrable mode (only effective in paginated mode, i.e. + when `page_size` is set and `sql` is a string). 
""" template_fields: Sequence[str] = ( @@ -90,6 +92,7 @@ def __init__( insert_args: dict | None = None, page_size: int | None = None, paginated_sql_statement_clause: str | None = None, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ) -> None: super().__init__(**kwargs) @@ -104,6 +107,7 @@ def __init__( self.insert_args = insert_args or {} self.page_size = page_size self.paginated_sql_statement_clause = paginated_sql_statement_clause or "{} LIMIT {} OFFSET {}" + self.deferrable = deferrable @classmethod def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook: @@ -159,14 +163,32 @@ def execute(self, context: Context): self.destination_hook.run(self.preoperator) if self.page_size and isinstance(self.sql, str): - self.defer( - trigger=SQLExecuteQueryTrigger( - conn_id=self.source_conn_id, - hook_params=self.source_hook_params, - sql=self.get_paginated_sql(0), - ), - method_name=self.execute_complete.__name__, - ) + if self.deferrable: + self.defer( + trigger=SQLExecuteQueryTrigger( + conn_id=self.source_conn_id, + hook_params=self.source_hook_params, + sql=self.get_paginated_sql(0), + ), + method_name=self.execute_complete.__name__, + ) + else: + offset = 0 + while True: + paginated_sql = self.get_paginated_sql(offset) + self.log.info("Executing: \n %s", paginated_sql) + if rows := self.source_hook.get_records(paginated_sql): + self._insert_rows(rows=rows, context=context) + if len(rows) < self.page_size: + break + offset += self.page_size + self.log.info("Offset increased to %d", offset) + else: + self.log.info( + "No more rows to fetch into %s; ending transfer.", + self.destination_table, + ) + break else: if isinstance(self.sql, str): self.sql = [self.sql] diff --git a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi index 3194e3877fdc2..a7a69bb452b65 100644 --- 
a/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi +++ b/providers/common/sql/src/airflow/providers/common/sql/operators/generic_transfer.pyi @@ -55,6 +55,8 @@ class GenericTransfer(BaseOperator): preoperator: Incomplete insert_args: Incomplete page_size: Incomplete + paginated_sql_statement_clause: Incomplete + deferrable: bool def __init__( self, *, @@ -68,6 +70,8 @@ class GenericTransfer(BaseOperator): preoperator: str | list[str] | None = None, insert_args: dict | None = None, page_size: int | None = None, + paginated_sql_statement_clause: str | None = None, + deferrable: bool = False, **kwargs, ) -> None: ... @classmethod diff --git a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py index 1406542df4ea5..a268d232701e4 100644 --- a/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py +++ b/providers/common/sql/tests/unit/common/sql/operators/test_generic_transfer.py @@ -382,7 +382,50 @@ def test_non_paginated_read_for_multiple_sql_statements_with_rows_processor(self "table": "NEW_HR.EMPLOYEES", } - def test_paginated_read(self): + def test_non_deferred_paginated_read(self): + """ + Test that GenericTransfer paginates eagerly (non-deferred) when page_size is set and deferrable is False. + It stops early when fewer rows than page_size are returned (no need for an extra empty-page fetch). 
+ """ + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection): + with mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook): + operator = GenericTransfer( + task_id="transfer_table", + source_conn_id="my_source_conn_id", + destination_conn_id="my_destination_conn_id", + sql="SELECT * FROM HR.EMPLOYEES", + destination_table="NEW_HR.EMPLOYEES", + page_size=2, + insert_args=INSERT_ARGS, + execution_timeout=timedelta(hours=1), + ) + + operator.execute(context=mock_context(task=operator)) + + assert self.mocked_source_hook.get_records.call_count == 3 + assert ( + self.mocked_source_hook.get_records.call_args_list[0].args[0] + == "SELECT * FROM HR.EMPLOYEES LIMIT 2 OFFSET 0" + ) + assert ( + self.mocked_source_hook.get_records.call_args_list[1].args[0] + == "SELECT * FROM HR.EMPLOYEES LIMIT 2 OFFSET 2" + ) + assert ( + self.mocked_source_hook.get_records.call_args_list[2].args[0] + == "SELECT * FROM HR.EMPLOYEES LIMIT 2 OFFSET 4" + ) + assert self.mocked_destination_hook.insert_rows.call_count == 2 + assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == { + **INSERT_ARGS, + **{"rows": [[1, 2], [11, 12], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"}, + } + assert self.mocked_destination_hook.insert_rows.call_args_list[1].kwargs == { + **INSERT_ARGS, + **{"rows": [[3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"}, + } + + def test_deferred_paginated_read(self): """ This unit test is based on the example described in the medium article: https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f @@ -399,6 +442,7 @@ def test_paginated_read(self): page_size=1000, # Fetch data in chunks of 1000 rows for pagination insert_args=INSERT_ARGS, execution_timeout=timedelta(hours=1), + deferrable=True, ) results, events = execute_operator(operator) @@ -414,6 +458,14 @@ def test_paginated_read(self): 
self.mocked_source_hook.get_records.call_args_list[0].args[0] == "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0" ) + assert ( + self.mocked_source_hook.get_records.call_args_list[1].args[0] + == "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 1000" + ) + assert ( + self.mocked_source_hook.get_records.call_args_list[2].args[0] + == "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 2000" + ) assert self.mocked_destination_hook.insert_rows.call_count == 2 assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == { **INSERT_ARGS,