Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.providers.common.compat.sdk import AirflowException, BaseHook, BaseOperator
from airflow.providers.common.compat.sdk import AirflowException, BaseHook, BaseOperator, conf
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.common.sql.triggers.sql import SQLExecuteQueryTrigger

Expand Down Expand Up @@ -55,6 +55,7 @@ class GenericTransfer(BaseOperator):
:param insert_args: extra params for `insert_rows` method.
:param page_size: number of records to be read in paginated mode (optional).
:param paginated_sql_statement_clause: SQL statement clause to be used for pagination (optional).
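:param deferrable: Run the paginated read in deferrable mode; defaults to the `operators.default_deferrable` configuration option, or `False` when that option is not set.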
Comment thread
dabla marked this conversation as resolved.
Outdated
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
insert_args: dict | None = None,
page_size: int | None = None,
paginated_sql_statement_clause: str | None = None,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
Comment thread
dabla marked this conversation as resolved.
super().__init__(**kwargs)
Expand All @@ -104,6 +106,7 @@ def __init__(
self.insert_args = insert_args or {}
self.page_size = page_size
self.paginated_sql_statement_clause = paginated_sql_statement_clause or "{} LIMIT {} OFFSET {}"
self.deferrable = deferrable

@classmethod
def get_hook(cls, conn_id: str, hook_params: dict | None = None) -> DbApiHook:
Expand Down Expand Up @@ -159,14 +162,29 @@ def execute(self, context: Context):
self.destination_hook.run(self.preoperator)

if self.page_size and isinstance(self.sql, str):
self.defer(
trigger=SQLExecuteQueryTrigger(
conn_id=self.source_conn_id,
hook_params=self.source_hook_params,
sql=self.get_paginated_sql(0),
),
method_name=self.execute_complete.__name__,
)
if self.deferrable:
self.defer(
trigger=SQLExecuteQueryTrigger(
conn_id=self.source_conn_id,
hook_params=self.source_hook_params,
sql=self.get_paginated_sql(0),
),
method_name=self.execute_complete.__name__,
)
else:
offset = 0
while True:
paginated_sql = self.get_paginated_sql(offset)
self.log.info("Executing: \n %s", paginated_sql)
if rows := self.source_hook.get_records(paginated_sql):
self._insert_rows(rows=rows, context=context)
Comment thread
dabla marked this conversation as resolved.
offset += self.page_size
Comment thread
dabla marked this conversation as resolved.
else:
self.log.info(
"No more rows to fetch into %s; ending transfer.",
self.destination_table,
)
break
else:
if isinstance(self.sql, str):
self.sql = [self.sql]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ class GenericTransfer(BaseOperator):
preoperator: Incomplete
insert_args: Incomplete
page_size: Incomplete
paginated_sql_statement_clause: Incomplete
deferrable: bool
def __init__(
self,
Comment thread
dabla marked this conversation as resolved.
*,
Expand All @@ -68,6 +70,8 @@ class GenericTransfer(BaseOperator):
preoperator: str | list[str] | None = None,
insert_args: dict | None = None,
page_size: int | None = None,
paginated_sql_statement_clause: str | None = None,
deferrable: bool = False,
Comment thread
dabla marked this conversation as resolved.
**kwargs,
) -> None: ...
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,50 @@ def test_non_paginated_read_for_multiple_sql_statements_with_rows_processor(self
"table": "NEW_HR.EMPLOYEES",
}

def test_paginated_read(self):
def test_non_deferred_paginated_read(self):
    """
    Check the synchronous (non-deferred) pagination path of GenericTransfer.

    The operator is built with ``page_size=1000`` and ``deferrable=False`` and executed directly.
    The source hook is expected to receive three paginated queries (offsets 0, 1000 and 2000)
    and the destination hook two ``insert_rows`` calls.
    """
    expected_queries = [
        "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0",
        "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 1000",
        "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 2000",
    ]
    expected_inserts = [
        {**INSERT_ARGS, "rows": [[1, 2], [11, 12], [3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
        {**INSERT_ARGS, "rows": [[3, 4], [13, 14]], "table": "NEW_HR.EMPLOYEES"},
    ]

    # Patchers are entered in the same order as before: connection lookup first, then hook lookup.
    patch_connection = mock.patch(f"{BASEHOOK_PATCH_PATH}.get_connection", side_effect=self.get_connection)
    patch_hook = mock.patch(f"{BASEHOOK_PATCH_PATH}.get_hook", side_effect=self.get_hook)

    with patch_connection, patch_hook:
        transfer = GenericTransfer(
            task_id="transfer_table",
            source_conn_id="my_source_conn_id",
            destination_conn_id="my_destination_conn_id",
            sql="SELECT * FROM HR.EMPLOYEES",
            destination_table="NEW_HR.EMPLOYEES",
            page_size=1000,  # turns pagination on
            insert_args=INSERT_ARGS,
            execution_timeout=timedelta(hours=1),
            deferrable=False,  # force the eager code path
        )

        transfer.execute(context=mock_context(task=transfer))

        # Comparing whole lists checks both the number of calls and each call's payload.
        issued_queries = [call.args[0] for call in self.mocked_source_hook.get_records.call_args_list]
        assert issued_queries == expected_queries

        performed_inserts = [call.kwargs for call in self.mocked_destination_hook.insert_rows.call_args_list]
        assert performed_inserts == expected_inserts

def test_deferred_paginated_read(self):
"""
This unit test is based on the example described in the medium article:
https://medium.com/apache-airflow/transfering-data-from-sap-hana-to-mssql-using-the-airflow-generictransfer-d29f147a9f1f
Expand All @@ -399,6 +442,7 @@ def test_paginated_read(self):
page_size=1000, # Fetch data in chunks of 1000 rows for pagination
insert_args=INSERT_ARGS,
execution_timeout=timedelta(hours=1),
deferrable=True,
)

results, events = execute_operator(operator)
Expand All @@ -414,6 +458,14 @@ def test_paginated_read(self):
self.mocked_source_hook.get_records.call_args_list[0].args[0]
== "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 0"
)
assert (
self.mocked_source_hook.get_records.call_args_list[1].args[0]
== "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 1000"
)
assert (
self.mocked_source_hook.get_records.call_args_list[2].args[0]
== "SELECT * FROM HR.EMPLOYEES LIMIT 1000 OFFSET 2000"
)
assert self.mocked_destination_hook.insert_rows.call_count == 2
assert self.mocked_destination_hook.insert_rows.call_args_list[0].kwargs == {
**INSERT_ARGS,
Expand Down
Loading