diff --git a/providers/sftp/src/airflow/providers/sftp/constants.py b/providers/sftp/src/airflow/providers/sftp/constants.py
new file mode 100644
index 0000000000000..ec94d5c22ef59
--- /dev/null
+++ b/providers/sftp/src/airflow/providers/sftp/constants.py
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Constants for SFTP provider."""
+
+from __future__ import annotations
+
+
+class SFTPOperation:
+    """Operation that can be used with SFTP."""
+
+    PUT = "put"
+    GET = "get"
+    DELETE = "delete"
diff --git a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py
index 0b47b5b7d5ddb..a227355af897e 100644
--- a/providers/sftp/src/airflow/providers/sftp/operators/sftp.py
+++ b/providers/sftp/src/airflow/providers/sftp/operators/sftp.py
@@ -28,16 +28,11 @@
 
 import paramiko
 
+from airflow.configuration import conf
 from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
+from airflow.providers.sftp.constants import SFTPOperation
 from airflow.providers.sftp.hooks.sftp import SFTPHook
-
-
-class SFTPOperation:
-    """Operation that can be used with SFTP."""
-
-    PUT = "put"
-    GET = "get"
-    DELETE = "delete"
+from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger
 
 
 class SFTPOperator(BaseOperator):
@@ -95,6 +90,7 @@ def __init__(
         create_intermediate_dirs: bool = False,
         concurrency: int = 1,
         prefetch: bool = True,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -108,8 +104,25 @@ def __init__(
         self.remote_filepath = remote_filepath
         self.concurrency = concurrency
         self.prefetch = prefetch
+        self.deferrable = deferrable
 
     def execute(self, context: Any) -> str | list[str] | None:
+        if self.deferrable:
+            self.defer(
+                trigger=SFTPOperatorTrigger(
+                    ssh_conn_id=self.ssh_conn_id,
+                    local_filepath=self.local_filepath,
+                    remote_filepath=self.remote_filepath,
+                    operation=self.operation,
+                    confirm=self.confirm,
+                    create_intermediate_dirs=self.create_intermediate_dirs,
+                    remote_host=self.remote_host,
+                    concurrency=self.concurrency,
+                    prefetch=self.prefetch,
+                ),
+                method_name="execute_complete",
+            )
+
         if self.local_filepath is None:
             local_filepath_array = []
         elif isinstance(self.local_filepath, str):
@@ -227,6 +240,21 @@ def execute(self, context: Any) -> str | list[str] | None:
 
         return self.local_filepath
 
+    def execute_complete(self, context: Any, event: dict) -> str | list[str] | None:
+        """
+        Execute when the trigger fires in deferrable mode.
+
+        :param context: The task context.
+        :param event: The event yielded by SFTPOperatorTrigger.
+        :return: The local filepath(s).
+        """
+        if event.get("status") == "error":
+            raise AirflowException(
+                f"Error during deferrable SFTP {self.operation.upper()} operation: {event.get('message')}"
+            )
+        self.log.info("File transfer completed successfully via deferrable mode.")
+        return event.get("local_filepath")
+
     @staticmethod
     def _is_missing_path_error(exc: Exception) -> bool:
         if isinstance(exc, FileNotFoundError):
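Note on usage: with the `deferrable` flag added above, the operator defers immediately and the transfer runs on the triggerer instead of holding a worker slot. A minimal sketch of a DAG task opting in (the connection ID and file paths here are hypothetical, not part of this change):

    from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator

    # Hypothetical task: SFTPOperatorTrigger performs the PUT while the
    # worker slot is released for other tasks.
    upload_report = SFTPOperator(
        task_id="upload_report",
        ssh_conn_id="sftp_default",  # assumed connection ID
        local_filepath="/tmp/report.csv",  # hypothetical local path
        remote_filepath="/incoming/report.csv",  # hypothetical remote path
        operation=SFTPOperation.PUT,
        deferrable=True,
    )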
+ """ + if event.get("status") == "error": + raise AirflowException( + f"Error during deferrable SFTP {self.operation.upper()} operation: {event.get('message')}" + ) + self.log.info("File transfer completed successfully via deferrable mode.") + return event.get("local_filepath") + @staticmethod def _is_missing_path_error(exc: Exception) -> bool: if isinstance(exc, FileNotFoundError): diff --git a/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py b/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py index 6ecdbaa158940..683716f606c08 100644 --- a/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py +++ b/providers/sftp/src/airflow/providers/sftp/triggers/sftp.py @@ -139,3 +139,135 @@ async def run(self) -> AsyncIterator[TriggerEvent]: def _get_async_hook(self) -> SFTPHookAsync: return SFTPHookAsync(sftp_conn_id=self.sftp_conn_id) + + +class SFTPOperatorTrigger(BaseTrigger): + """ + Trigger for SFTPOperator deferrable mode. + + Fires when a file transfer (PUT, GET, or DELETE) completes + on the SFTP server, freeing the worker slot during the transfer. + + :param ssh_conn_id: The SSH connection ID to use. + :param local_filepath: Local file path(s) to transfer. + :param remote_filepath: Remote file path(s) on the SFTP server. + :param operation: The SFTP operation - put, get, or delete. + :param confirm: Whether to confirm the file transfer. + :param create_intermediate_dirs: Whether to create intermediate dirs. + :param remote_host: Remote host to connect to (overrides connection). + :param concurrency: Number of threads for directory transfers. + :param prefetch: Whether to prefetch during file retrieval. + """ + + def __init__( + self, + ssh_conn_id: str | None = None, + local_filepath: str | list[str] | None = None, + remote_filepath: str | list[str] = "", + operation: str = "put", + confirm: bool = True, + create_intermediate_dirs: bool = False, + remote_host: str | None = None, + concurrency: int = 1, + prefetch: bool = True, + ) -> None: + super().__init__() + self.ssh_conn_id = ssh_conn_id + self.local_filepath = local_filepath + self.remote_filepath = remote_filepath + self.operation = operation + self.confirm = confirm + self.create_intermediate_dirs = create_intermediate_dirs + self.remote_host = remote_host + self.concurrency = concurrency + self.prefetch = prefetch + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serialize the trigger for storage in the database.""" + return ( + "airflow.providers.sftp.triggers.sftp.SFTPOperatorTrigger", + { + "ssh_conn_id": self.ssh_conn_id, + "local_filepath": self.local_filepath, + "remote_filepath": self.remote_filepath, + "operation": self.operation, + "confirm": self.confirm, + "create_intermediate_dirs": self.create_intermediate_dirs, + "remote_host": self.remote_host, + "concurrency": self.concurrency, + "prefetch": self.prefetch, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Run the file transfer asynchronously and yield a TriggerEvent when done.""" + try: + loop = asyncio.get_running_loop() + await loop.run_in_executor( + None, + self._do_transfer, + ) + yield TriggerEvent( + { + "status": "success", + "local_filepath": self.local_filepath, + } + ) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) + + def _do_transfer(self) -> None: + """Run the actual synchronous SFTP transfer in a thread executor.""" + import os + from pathlib import Path + + from airflow.providers.sftp.constants import SFTPOperation + from airflow.providers.sftp.hooks.sftp 
diff --git a/providers/sftp/tests/unit/sftp/operators/test_sftp.py b/providers/sftp/tests/unit/sftp/operators/test_sftp.py
index 815d981320107..4810cfbbb7cda 100644
--- a/providers/sftp/tests/unit/sftp/operators/test_sftp.py
+++ b/providers/sftp/tests/unit/sftp/operators/test_sftp.py
@@ -28,11 +28,13 @@
 import paramiko
 import pytest
 
+from airflow.exceptions import TaskDeferred
 from airflow.models import DAG, Connection
 from airflow.providers.common.compat.openlineage.facet import Dataset
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.providers.sftp.hooks.sftp import SFTPHook
 from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
+from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger
 from airflow.providers.ssh.hooks.ssh import SSHHook
 from airflow.providers.ssh.operators.ssh import SSHOperator
 from airflow.utils import timezone
@@ -675,3 +677,117 @@ def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected):
 
         assert lineage.inputs == expected[0]
         assert lineage.outputs == expected[1]
+
+
+class TestSFTPOperatorDeferrable:
+    """Tests for SFTPOperator deferrable mode."""
+
+    def test_sftp_operator_defers_when_deferrable_true(self):
+        """Test that SFTPOperator defers when deferrable=True."""
+        operator = SFTPOperator(
+            task_id="test_sftp_defer",
+            ssh_conn_id="ssh_default",
+            local_filepath="/tmp/test.txt",
+            remote_filepath="/remote/test.txt",
+            operation=SFTPOperation.PUT,
+            deferrable=True,
+        )
+        with pytest.raises(TaskDeferred) as exc:
+            operator.execute(context={})
+        assert isinstance(exc.value.trigger, SFTPOperatorTrigger)
+        assert exc.value.method_name == "execute_complete"
+
+    def test_sftp_operator_execute_complete_success(self):
+        """Test execute_complete returns local_filepath on success."""
+        operator = SFTPOperator(
+            task_id="test_sftp_complete",
+            ssh_conn_id="ssh_default",
+            local_filepath="/tmp/test.txt",
+            remote_filepath="/remote/test.txt",
+            operation=SFTPOperation.PUT,
+            deferrable=True,
+        )
+        event = {"status": "success", "local_filepath": "/tmp/test.txt"}
+        result = operator.execute_complete(context={}, event=event)
+        assert result == "/tmp/test.txt"
+
+    def test_sftp_operator_execute_complete_raises_on_error(self):
+        """Test execute_complete raises AirflowException on error."""
+        operator = SFTPOperator(
+            task_id="test_sftp_error",
+            ssh_conn_id="ssh_default",
+            local_filepath="/tmp/test.txt",
+            remote_filepath="/remote/test.txt",
+            operation=SFTPOperation.PUT,
+            deferrable=True,
+        )
+        event = {"status": "error", "message": "Connection refused"}
+        with pytest.raises(AirflowException, match="Connection refused"):
+            operator.execute_complete(context={}, event=event)
+
+
+class TestSFTPOperatorTrigger:
+    """Tests for SFTPOperatorTrigger."""
+
+    def test_serialize_roundtrip(self):
+        """Test that serialize() produces correct output for reconstruction."""
+        trigger = SFTPOperatorTrigger(
+            ssh_conn_id="ssh_default",
+            local_filepath="/tmp/test.txt",
+            remote_filepath="/remote/test.txt",
+            operation="put",
+            confirm=True,
+            create_intermediate_dirs=False,
+            remote_host=None,
+            concurrency=1,
+            prefetch=True,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.sftp.triggers.sftp.SFTPOperatorTrigger"
+        assert kwargs["ssh_conn_id"] == "ssh_default"
+        assert kwargs["local_filepath"] == "/tmp/test.txt"
+        assert kwargs["remote_filepath"] == "/remote/test.txt"
+        assert kwargs["operation"] == "put"
+        assert kwargs["confirm"] is True
+        assert kwargs["remote_host"] is None
+        assert kwargs["concurrency"] == 1
+        assert kwargs["prefetch"] is True
+
+    def test_run_success(self):
+        """Test run() yields TriggerEvent with status success."""
+        import asyncio
+        from unittest.mock import patch
+        trigger = SFTPOperatorTrigger(
+            ssh_conn_id="ssh_default",
+            local_filepath="/tmp/test.txt",
+            remote_filepath="/remote/test.txt",
+            operation="put",
+        )
+        with patch.object(trigger, "_do_transfer", return_value=None):
+            events = []
+            async def collect():
+                async for event in trigger.run():
+                    events.append(event)
+            asyncio.run(collect())
+        assert len(events) == 1
+        assert events[0].payload["status"] == "success"
+
+    def test_run_error(self):
+        """Test run() yields TriggerEvent with status error on exception."""
+        import asyncio
+        from unittest.mock import patch
+        trigger = SFTPOperatorTrigger(
+            ssh_conn_id="ssh_default",
+            local_filepath="/tmp/test.txt",
+            remote_filepath="/remote/test.txt",
+            operation="put",
+        )
+        with patch.object(trigger, "_do_transfer", side_effect=Exception("Connection failed")):
+            events = []
+            async def collect():
+                async for event in trigger.run():
+                    events.append(event)
+            asyncio.run(collect())
+        assert len(events) == 1
+        assert events[0].payload["status"] == "error"
+        assert "Connection failed" in events[0].payload["message"]
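Note on the serialization contract: `test_serialize_roundtrip` above verifies that `serialize()` returns a classpath plus kwargs from which the triggerer can reconstruct an equivalent trigger after a restart. A hedged sketch of that round trip (values are illustrative only):

    from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger

    trigger = SFTPOperatorTrigger(
        ssh_conn_id="sftp_default",  # assumed connection ID
        remote_filepath="/incoming/report.csv",  # hypothetical path
    )
    classpath, kwargs = trigger.serialize()
    # The triggerer imports the class named by `classpath` and re-instantiates
    # it with `kwargs`, so the round trip must be lossless.
    rebuilt = SFTPOperatorTrigger(**kwargs)
    assert rebuilt.remote_filepath == trigger.remote_filepath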