Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions providers/sftp/src/airflow/providers/sftp/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constants for SFTP provider."""

from __future__ import annotations


class SFTPOperation:
    """Enumerate the file-transfer actions supported by the SFTP provider."""

    # Upload a local file or directory to the remote host.
    PUT = "put"
    # Download a remote file or directory to the local machine.
    GET = "get"
    # Remove a file or directory on the remote host.
    DELETE = "delete"
44 changes: 36 additions & 8 deletions providers/sftp/src/airflow/providers/sftp/operators/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,11 @@

import paramiko

from airflow.configuration import conf
from airflow.providers.common.compat.sdk import AirflowException, BaseOperator
from airflow.providers.sftp.constants import SFTPOperation
from airflow.providers.sftp.hooks.sftp import SFTPHook


class SFTPOperation:
"""Operation that can be used with SFTP."""

PUT = "put"
GET = "get"
DELETE = "delete"
from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger


class SFTPOperator(BaseOperator):
Expand Down Expand Up @@ -95,6 +90,7 @@ def __init__(
create_intermediate_dirs: bool = False,
concurrency: int = 1,
prefetch: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -108,8 +104,25 @@ def __init__(
self.remote_filepath = remote_filepath
self.concurrency = concurrency
self.prefetch = prefetch
self.deferrable = deferrable

def execute(self, context: Any) -> str | list[str] | None:
if self.deferrable:
self.defer(
trigger=SFTPOperatorTrigger(
ssh_conn_id=self.ssh_conn_id,
local_filepath=self.local_filepath,
remote_filepath=self.remote_filepath,
operation=self.operation,
confirm=self.confirm,
create_intermediate_dirs=self.create_intermediate_dirs,
remote_host=self.remote_host,
concurrency=self.concurrency,
prefetch=self.prefetch,
),
method_name="execute_complete",
)

if self.local_filepath is None:
local_filepath_array = []
elif isinstance(self.local_filepath, str):
Expand Down Expand Up @@ -227,6 +240,21 @@ def execute(self, context: Any) -> str | list[str] | None:

return self.local_filepath

def execute_complete(self, context: Any, event: dict) -> str | list[str] | None:
"""
Execute when the trigger fires in deferrable mode.

:param context: The task context.
:param event: The event yielded by SFTPOperatorTrigger.
:return: The local filepath(s).
"""
if event.get("status") == "error":
raise AirflowException(
f"Error during deferrable SFTP {self.operation.upper()} operation: {event.get('message')}"
)
self.log.info("File transfer completed successfully via deferrable mode.")
return event.get("local_filepath")

@staticmethod
def _is_missing_path_error(exc: Exception) -> bool:
if isinstance(exc, FileNotFoundError):
Expand Down
132 changes: 132 additions & 0 deletions providers/sftp/src/airflow/providers/sftp/triggers/sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,135 @@ async def run(self) -> AsyncIterator[TriggerEvent]:

def _get_async_hook(self) -> SFTPHookAsync:
return SFTPHookAsync(sftp_conn_id=self.sftp_conn_id)


class SFTPOperatorTrigger(BaseTrigger):
    """
    Trigger for SFTPOperator deferrable mode.

    Fires when a file transfer (PUT, GET, or DELETE) completes
    on the SFTP server, freeing the worker slot during the transfer.

    :param ssh_conn_id: The SSH connection ID to use.
    :param local_filepath: Local file path(s) to transfer.
    :param remote_filepath: Remote file path(s) on the SFTP server.
    :param operation: The SFTP operation - put, get, or delete.
    :param confirm: Whether to confirm the file transfer.
    :param create_intermediate_dirs: Whether to create intermediate dirs.
    :param remote_host: Remote host to connect to (overrides connection).
    :param concurrency: Number of threads for directory transfers.
    :param prefetch: Whether to prefetch during file retrieval.
    """

    def __init__(
        self,
        ssh_conn_id: str | None = None,
        local_filepath: str | list[str] | None = None,
        remote_filepath: str | list[str] = "",
        operation: str = "put",
        confirm: bool = True,
        create_intermediate_dirs: bool = False,
        remote_host: str | None = None,
        concurrency: int = 1,
        prefetch: bool = True,
    ) -> None:
        super().__init__()
        self.ssh_conn_id = ssh_conn_id
        self.local_filepath = local_filepath
        self.remote_filepath = remote_filepath
        self.operation = operation
        self.confirm = confirm
        self.create_intermediate_dirs = create_intermediate_dirs
        self.remote_host = remote_host
        self.concurrency = concurrency
        self.prefetch = prefetch

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """Serialize the trigger for storage in the database."""
        return (
            "airflow.providers.sftp.triggers.sftp.SFTPOperatorTrigger",
            {
                "ssh_conn_id": self.ssh_conn_id,
                "local_filepath": self.local_filepath,
                "remote_filepath": self.remote_filepath,
                "operation": self.operation,
                "confirm": self.confirm,
                "create_intermediate_dirs": self.create_intermediate_dirs,
                "remote_host": self.remote_host,
                "concurrency": self.concurrency,
                "prefetch": self.prefetch,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Run the file transfer asynchronously and yield a TriggerEvent when done."""
        try:
            loop = asyncio.get_running_loop()
            # The paramiko-based SFTPHook is synchronous; run the blocking
            # transfer on the default thread-pool executor so the triggerer
            # event loop stays responsive.
            await loop.run_in_executor(
                None,
                self._do_transfer,
            )
            yield TriggerEvent(
                {
                    "status": "success",
                    "local_filepath": self.local_filepath,
                }
            )
        except Exception as e:
            yield TriggerEvent({"status": "error", "message": str(e)})

    def _do_transfer(self) -> None:
        """
        Run the actual synchronous SFTP transfer in a thread executor.

        :raises ValueError: If the operation is not one of put/get/delete,
            or if the number of local and remote paths differ for put/get.
        """
        import os
        from pathlib import Path

        from airflow.providers.sftp.constants import SFTPOperation
        from airflow.providers.sftp.hooks.sftp import SFTPHook

        sftp_hook = SFTPHook(
            ssh_conn_id=self.ssh_conn_id,
            remote_host=self.remote_host or "",
        )

        if isinstance(self.local_filepath, str):
            local_filepath_array = [self.local_filepath] if self.local_filepath else []
        else:
            local_filepath_array = self.local_filepath or []

        if isinstance(self.remote_filepath, str):
            remote_filepath_array = [self.remote_filepath]
        else:
            remote_filepath_array = list(self.remote_filepath)

        operation = self.operation.lower()

        # put/get pair local and remote paths one-to-one; a bare zip() would
        # silently drop transfers on a length mismatch, so fail loudly instead.
        if operation in (SFTPOperation.GET, SFTPOperation.PUT) and len(local_filepath_array) != len(
            remote_filepath_array
        ):
            raise ValueError(
                f"{len(local_filepath_array)} local_filepath paths and "
                f"{len(remote_filepath_array)} remote_filepath paths are not equal in length"
            )

        if operation == SFTPOperation.GET:
            for local, remote in zip(local_filepath_array, remote_filepath_array):
                if self.create_intermediate_dirs:
                    Path(os.path.dirname(local)).mkdir(parents=True, exist_ok=True)
                if sftp_hook.isdir(remote):
                    if self.concurrency > 1:
                        sftp_hook.retrieve_directory_concurrently(
                            remote, local, workers=self.concurrency, prefetch=self.prefetch
                        )
                    else:
                        sftp_hook.retrieve_directory(remote, local)
                else:
                    sftp_hook.retrieve_file(remote, local, prefetch=self.prefetch)
        elif operation == SFTPOperation.PUT:
            for local, remote in zip(local_filepath_array, remote_filepath_array):
                if self.create_intermediate_dirs:
                    sftp_hook.create_directory(os.path.dirname(remote))
                if os.path.isdir(local):
                    if self.concurrency > 1:
                        sftp_hook.store_directory_concurrently(
                            remote, local, confirm=self.confirm, workers=self.concurrency
                        )
                    else:
                        sftp_hook.store_directory(remote, local, confirm=self.confirm)
                else:
                    sftp_hook.store_file(remote, local, confirm=self.confirm)
        elif operation == SFTPOperation.DELETE:
            for remote in remote_filepath_array:
                if sftp_hook.isdir(remote):
                    sftp_hook.delete_directory(remote, include_files=True)
                else:
                    sftp_hook.delete_file(remote)
        else:
            # Previously an unknown operation fell through and reported
            # success without transferring anything.
            raise ValueError(f"Unsupported operation value operation = {self.operation}")
116 changes: 116 additions & 0 deletions providers/sftp/tests/unit/sftp/operators/test_sftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
import paramiko
import pytest

from airflow.exceptions import TaskDeferred
from airflow.models import DAG, Connection
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.common.compat.sdk import AirflowException
from airflow.providers.sftp.hooks.sftp import SFTPHook
from airflow.providers.sftp.operators.sftp import SFTPOperation, SFTPOperator
from airflow.providers.sftp.triggers.sftp import SFTPOperatorTrigger
from airflow.providers.ssh.hooks.ssh import SSHHook
from airflow.providers.ssh.operators.ssh import SSHOperator
from airflow.utils import timezone
Expand Down Expand Up @@ -675,3 +677,117 @@ def test_extract_sftp_hook(self, get_connection, get_conn, operation, expected):

assert lineage.inputs == expected[0]
assert lineage.outputs == expected[1]


class TestSFTPOperatorDeferrable:
    """Tests for SFTPOperator deferrable mode."""

    @staticmethod
    def _make_operator(task_id):
        """Build the deferrable PUT operator each test below exercises."""
        return SFTPOperator(
            task_id=task_id,
            ssh_conn_id="ssh_default",
            local_filepath="/tmp/test.txt",
            remote_filepath="/remote/test.txt",
            operation=SFTPOperation.PUT,
            deferrable=True,
        )

    def test_sftp_operator_defers_when_deferrable_true(self):
        """Test that SFTPOperator defers when deferrable=True."""
        operator = self._make_operator("test_sftp_defer")
        with pytest.raises(TaskDeferred) as exc:
            operator.execute(context={})
        assert isinstance(exc.value.trigger, SFTPOperatorTrigger)
        assert exc.value.method_name == "execute_complete"

    def test_sftp_operator_execute_complete_success(self):
        """Test execute_complete returns local_filepath on success."""
        operator = self._make_operator("test_sftp_complete")
        result = operator.execute_complete(
            context={},
            event={"status": "success", "local_filepath": "/tmp/test.txt"},
        )
        assert result == "/tmp/test.txt"

    def test_sftp_operator_execute_complete_raises_on_error(self):
        """Test execute_complete raises AirflowException on error."""
        operator = self._make_operator("test_sftp_error")
        with pytest.raises(AirflowException, match="Connection refused"):
            operator.execute_complete(
                context={},
                event={"status": "error", "message": "Connection refused"},
            )


class TestSFTPOperatorTrigger:
    """Tests for SFTPOperatorTrigger."""

    @staticmethod
    def _collect_events(trigger):
        """Drain ``trigger.run()`` to completion and return the emitted events."""
        import asyncio

        async def collect():
            return [event async for event in trigger.run()]

        return asyncio.run(collect())

    def test_serialize_roundtrip(self):
        """Test that serialize() produces correct output for reconstruction."""
        trigger = SFTPOperatorTrigger(
            ssh_conn_id="ssh_default",
            local_filepath="/tmp/test.txt",
            remote_filepath="/remote/test.txt",
            operation="put",
            confirm=True,
            create_intermediate_dirs=False,
            remote_host=None,
            concurrency=1,
            prefetch=True,
        )
        classpath, kwargs = trigger.serialize()
        assert classpath == "airflow.providers.sftp.triggers.sftp.SFTPOperatorTrigger"
        assert kwargs["ssh_conn_id"] == "ssh_default"
        assert kwargs["local_filepath"] == "/tmp/test.txt"
        assert kwargs["remote_filepath"] == "/remote/test.txt"
        assert kwargs["operation"] == "put"
        assert kwargs["confirm"] is True
        assert kwargs["remote_host"] is None
        assert kwargs["concurrency"] == 1
        assert kwargs["prefetch"] is True

    def test_run_success(self):
        """Test run() yields TriggerEvent with status success."""
        from unittest.mock import patch

        trigger = SFTPOperatorTrigger(
            ssh_conn_id="ssh_default",
            local_filepath="/tmp/test.txt",
            remote_filepath="/remote/test.txt",
            operation="put",
        )
        with patch.object(trigger, "_do_transfer", return_value=None):
            events = self._collect_events(trigger)
        assert len(events) == 1
        assert events[0].payload["status"] == "success"

    def test_run_error(self):
        """Test run() yields TriggerEvent with status error on exception."""
        from unittest.mock import patch

        trigger = SFTPOperatorTrigger(
            ssh_conn_id="ssh_default",
            local_filepath="/tmp/test.txt",
            remote_filepath="/remote/test.txt",
            operation="put",
        )
        with patch.object(trigger, "_do_transfer", side_effect=Exception("Connection failed")):
            events = self._collect_events(trigger)
        assert len(events) == 1
        assert events[0].payload["status"] == "error"
        assert "Connection failed" in events[0].payload["message"]
Loading