From d4406c0cefddf67ad0d811750813e64fe22c5f3a Mon Sep 17 00:00:00 2001 From: Stefanie Grunwald Date: Sun, 24 Sep 2017 17:19:14 +0200 Subject: [PATCH] [AIRFLOW-71] Add support for private Docker images Pulling images from private Docker registries requires authentication, so additional parameters are added in order to perform the login step. (cherry picked from commit f101ff0063ad5b22b6efc8d89b1b31b4c6abfa0d) Signed-off-by: Bolke de Bruin --- airflow/contrib/hooks/__init__.py | 1 + airflow/hooks/docker_hook.py | 79 ++++++++++++ airflow/models.py | 4 + airflow/operators/docker_operator.py | 58 +++++++-- airflow/www/static/connection_form.js | 9 +- docs/code.rst | 1 + tests/hooks/test_docker_hook.py | 176 ++++++++++++++++++++++++++ tests/operators/docker_operator.py | 82 +++++++++++- 8 files changed, 389 insertions(+), 21 deletions(-) create mode 100644 airflow/hooks/docker_hook.py create mode 100644 tests/hooks/test_docker_hook.py diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py index 6d45acea7e936..99a1746ee6551 100644 --- a/airflow/contrib/hooks/__init__.py +++ b/airflow/contrib/hooks/__init__.py @@ -32,6 +32,7 @@ # # ------------------------------------------------------------------------ _hooks = { + 'docker_hook': ['DockerHook'], 'ftp_hook': ['FTPHook'], 'ftps_hook': ['FTPSHook'], 'vertica_hook': ['VerticaHook'], diff --git a/airflow/hooks/docker_hook.py b/airflow/hooks/docker_hook.py new file mode 100644 index 0000000000000..a570292ead2b9 --- /dev/null +++ b/airflow/hooks/docker_hook.py @@ -0,0 +1,79 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from docker import Client
+from docker.errors import APIError
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class DockerHook(BaseHook, LoggingMixin):
+    """
+    Interact with a private Docker registry.
+
+    :param docker_conn_id: ID of the Airflow connection where
+        credentials and extra configuration are stored
+    :type docker_conn_id: str
+    :param base_url: URL of the Docker daemon the client talks to (required)
+    :param version: Docker API version for the client, e.g. ``auto`` (required)
+    :param tls: TLS configuration handed to the Docker client, or None
+    """
+    def __init__(self, docker_conn_id='docker_default', base_url=None,
+                 version=None, tls=None):
+        if not base_url:
+            raise AirflowException('No Docker base URL provided')
+        if not version:
+            raise AirflowException('No Docker API version provided')
+
+        conn = self.get_connection(docker_conn_id)
+        if not conn.host:
+            raise AirflowException('No Docker registry URL provided')
+        if not conn.login:
+            raise AirflowException('No username provided')
+        extra_options = conn.extra_dejson
+
+        self.__base_url = base_url
+        self.__version = version
+        self.__tls = tls
+        self.__registry = conn.host
+        self.__username = conn.login
+        self.__password = conn.password
+        self.__email = extra_options.get('email')
+        # Re-authenticate on every login unless the extra sets "reauth": "no".
+        self.__reauth = extra_options.get('reauth') != 'no'
+
+    def get_conn(self):
+        """Create a Docker client and log it into the configured registry."""
+        client = Client(base_url=self.__base_url, version=self.__version,
+                        tls=self.__tls)
+        self.__login(client)
+        return client
+
+    def __login(self, client):
+        """Log ``client`` into the registry; an APIError becomes AirflowException."""
+        self.log.debug('Logging into Docker registry')
+        try:
+            client.login(
+                username=self.__username,
+                password=self.__password,
+                registry=self.__registry,
+                email=self.__email,
+                reauth=self.__reauth
+            )
+            self.log.debug('Login successful')
+        except APIError as docker_error:
+            self.log.error('Docker registry login failed: %s', str(docker_error))
+            raise AirflowException('Docker registry login failed: %s' % str(docker_error))
diff --git a/airflow/models.py b/airflow/models.py
index a82fec64fc8ba..b4b2325834c1a 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -530,6 +530,7 @@ class Connection(Base, LoggingMixin):
     _extra = Column('extra', String(5000))
 
     _types = [
+        ('docker', 'Docker Registry',),
         ('fs', 'File (path)'),
         ('ftp', 'FTP',),
         ('google_cloud_platform', 'Google Cloud Platform'),
@@ -696,6 +697,9 @@ def get_hook(self):
             elif self.conn_type == 'wasb':
                 from airflow.contrib.hooks.wasb_hook import WasbHook
                 return WasbHook(wasb_conn_id=self.conn_id)
+            elif self.conn_type == 'docker':
+                from airflow.hooks.docker_hook import DockerHook
+                return DockerHook(docker_conn_id=self.conn_id)
         except:
             pass
 
diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py
index f319cf44a796a..38edc8b4d79ab 100644
--- a/airflow/operators/docker_operator.py
+++ b/airflow/operators/docker_operator.py
@@ -14,11 +14,12 @@
 
 import json
 
+from airflow.hooks.docker_hook import DockerHook
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
 from airflow.utils.file import TemporaryDirectory
-from docker import APIClient as Client, tls
+from docker import Client, tls
 import ast
 
 
@@ -30,9 +31,14 @@ class DockerOperator(BaseOperator):
     that together exceed the default disk size of 10GB in a container. The path to the mounted
     directory can be accessed via the environment variable ``AIRFLOW_TMP_DIR``.
 
+    If a login to a private registry is required prior to pulling the image, a
+    Docker connection needs to be configured in Airflow and the connection ID
+    be provided with the parameter ``docker_conn_id``.
+
     :param image: Docker image from which to create the container.
     :type image: str
-    :param api_version: Remote API version.
+    :param api_version: Remote API version. Set to ``auto`` to automatically
+        detect the server's version.
     :type api_version: str
     :param command: Command to be run in the container.
     :type command: str or list
@@ -41,10 +47,11 @@ class DockerOperator(BaseOperator):
         https://docs.docker.com/engine/reference/run/#cpu-share-constraint
     :type cpus: float
     :param docker_url: URL of the host running the docker daemon.
+        Default is unix://var/run/docker.sock
     :type docker_url: str
     :param environment: Environment variables to set in the container.
     :type environment: dict
-    :param force_pull: Pull the docker image on every run.
+    :param force_pull: Pull the docker image on every run. Default is false.
     :type force_pull: bool
     :param mem_limit: Maximum amount of memory the container can use. Either a float value, which
         represents the limit in bytes, or a string like ``128m`` or ``1g``.
@@ -78,6 +85,8 @@ class DockerOperator(BaseOperator):
     :type xcom_push: bool
     :param xcom_all: Push all the stdout or just the last line. The default is False (last line).
     :type xcom_all: bool
+    :param docker_conn_id: ID of the Airflow connection to use
+    :type docker_conn_id: str
     """
     template_fields = ('command',)
     template_ext = ('.sh', '.bash',)
@@ -105,6 +114,7 @@ def __init__(
             working_dir=None,
             xcom_push=False,
             xcom_all=False,
+            docker_conn_id=None,
             *args,
             **kwargs):
 
@@ -129,25 +139,32 @@ def __init__(
         self.working_dir = working_dir
         self.xcom_push_flag = xcom_push
         self.xcom_all = xcom_all
+        self.docker_conn_id = docker_conn_id
 
         self.cli = None
         self.container = None
 
+    def get_hook(self):
+        """Return a DockerHook that logs into the registry of ``docker_conn_id``."""
+        tls_config = self.__get_tls_config()  # may rewrite docker_url, so resolve it first
+        return DockerHook(
+            docker_conn_id=self.docker_conn_id, base_url=self.docker_url,
+            version=self.api_version, tls=tls_config
+        )
+
     def execute(self, context):
         self.log.info('Starting docker container from image %s', self.image)
 
-        tls_config = None
-        if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key:
-            tls_config = tls.TLSConfig(
-                ca_cert=self.tls_ca_cert,
-                client_cert=(self.tls_client_cert, self.tls_client_key),
-                verify=True,
-                ssl_version=self.tls_ssl_version,
-                assert_hostname=self.tls_hostname
-            )
-            self.docker_url = self.docker_url.replace('tcp://', 'https://')
+        tls_config = self.__get_tls_config()
 
-        self.cli = Client(base_url=self.docker_url, version=self.api_version, tls=tls_config)
+        if self.docker_conn_id:  # private registry: the hook logs in before returning the client
+            self.cli = self.get_hook().get_conn()
+        else:
+            self.cli = Client(
+                base_url=self.docker_url,
+                version=self.api_version,
+                tls=tls_config
+            )
 
         if ':' not in self.image:
             image = self.image + ':latest'
@@ -204,3 +221,16 @@ def on_kill(self):
         if self.cli is not None:
             self.log.info('Stopping docker container')
             self.cli.stop(self.container['Id'])
+
+    def __get_tls_config(self):
+        """Return a TLSConfig if CA cert, client cert and client key are all set, else None."""
+        tls_config = None
+        if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key:
+            tls_config = tls.TLSConfig(
+                ca_cert=self.tls_ca_cert,
+                client_cert=(self.tls_client_cert, self.tls_client_key),
+                verify=True, ssl_version=self.tls_ssl_version,
+                assert_hostname=self.tls_hostname
+            )
+            self.docker_url = self.docker_url.replace('tcp://', 'https://')  # side effect on self
+        return tls_config
diff --git a/airflow/www/static/connection_form.js b/airflow/www/static/connection_form.js
index 0324bcfffb2d2..c40bba7620b6f 100644
--- a/airflow/www/static/connection_form.js
+++ b/airflow/www/static/connection_form.js
@@ -38,7 +38,14 @@
         'login': 'Username (or API Key)',
         'schema': 'Database'
       }
-    }
+    },
+    docker: {
+      hidden_fields: ['port', 'schema'],
+      relabeling: {
+        'host': 'Registry URL',
+        'login': 'Username',
+      },
+    },
   }
   function connTypeChange(connectionType) {
     $("div.form-group").removeClass("hide");
diff --git a/docs/code.rst b/docs/code.rst
index 4a6718f556f07..1369b323f8409 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -216,6 +216,7 @@ Hooks
     :show-inheritance:
     :members:
         DbApiHook,
+        DockerHook,
         HiveCliHook,
         HiveMetastoreHook,
         HiveServer2Hook,
diff --git a/tests/hooks/test_docker_hook.py b/tests/hooks/test_docker_hook.py
new file mode 100644
index 0000000000000..4a77523db6714
--- /dev/null
+++ b/tests/hooks/test_docker_hook.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow import models
+from airflow.exceptions import AirflowException
+from airflow.utils import db
+
+try:
+    from airflow.hooks.docker_hook import DockerHook
+    from docker import Client
+except ImportError:
+    pass
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        raise unittest.SkipTest('mock package not present')  # @mock.patch below needs it
+
+
+@mock.patch('airflow.hooks.docker_hook.Client', autospec=True)  # passed to each test as last arg
+class DockerHookTest(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_default',
+                conn_type='docker',
+                host='some.docker.registry.com',
+                login='some_user',
+                password='some_p4$$w0rd'
+            )
+        )
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_with_extras',
+                conn_type='docker',
+                host='some.docker.registry.com',
+                login='some_user',
+                password='some_p4$$w0rd',
+                extra='{"email": "some@example.com", "reauth": "no"}'
+            )
+        )
+
+    def test_init_fails_when_no_base_url_given(self, _):
+        with self.assertRaises(AirflowException):
+            DockerHook(
+                docker_conn_id='docker_default',
+                version='auto',
+                tls=None
+            )
+
+    def test_init_fails_when_no_api_version_given(self, _):
+        with self.assertRaises(AirflowException):
+            DockerHook(
+                docker_conn_id='docker_default',
+                base_url='unix://var/run/docker.sock',
+                tls=None
+            )
+
+    def test_get_conn_override_defaults(self, docker_client_mock):
+        hook = DockerHook(
+            docker_conn_id='docker_default',
+            base_url='https://index.docker.io/v1/',
+            version='1.23',
+            tls='someconfig'
+        )
+        hook.get_conn()
+        docker_client_mock.assert_called_with(
+            base_url='https://index.docker.io/v1/',
+            version='1.23',
+            tls='someconfig'
+        )
+
+    def test_get_conn_with_standard_config(self, _):
+        try:
+            hook = DockerHook(
+                docker_conn_id='docker_default',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
+            client = hook.get_conn()
+            self.assertIsNotNone(client)
+        except Exception as err:  # keep the cause visible in the failure message
+            self.fail('Could not get connection from Airflow: %s' % err)
+
+    def test_get_conn_with_extra_config(self, _):
+        try:
+            hook = DockerHook(
+                docker_conn_id='docker_with_extras',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
+            client = hook.get_conn()
+            self.assertIsNotNone(client)
+        except Exception as err:  # keep the cause visible in the failure message
+            self.fail('Could not get connection from Airflow: %s' % err)
+
+    def test_conn_with_standard_config_passes_parameters(self, _):
+        hook = DockerHook(
+            docker_conn_id='docker_default',
+            base_url='unix://var/run/docker.sock',
+            version='auto'
+        )
+        client = hook.get_conn()
+        client.login.assert_called_with(
+            username='some_user',
+            password='some_p4$$w0rd',
+            registry='some.docker.registry.com',
+            reauth=True,
+            email=None
+        )
+
+    def test_conn_with_extra_config_passes_parameters(self, _):
+        hook = DockerHook(
+            docker_conn_id='docker_with_extras',
+            base_url='unix://var/run/docker.sock',
+            version='auto'
+        )
+        client = hook.get_conn()
+        client.login.assert_called_with(
+            username='some_user',
+            password='some_p4$$w0rd',
+            registry='some.docker.registry.com',
+            reauth=False,
+            email='some@example.com'
+        )
+
+    def test_conn_with_broken_config_missing_username_fails(self, _):
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_without_username',
+                conn_type='docker',
+                host='some.docker.registry.com',
+                password='some_p4$$w0rd',
+                extra='{"email": "some@example.com"}'
+            )
+        )
+        with self.assertRaises(AirflowException):
+            DockerHook(
+                docker_conn_id='docker_without_username',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
+
+    def test_conn_with_broken_config_missing_host_fails(self, _):
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_without_host',
+                conn_type='docker',
+                login='some_user',
+                password='some_p4$$w0rd'
+            )
+        )
+        with self.assertRaises(AirflowException):
+            DockerHook(
+                docker_conn_id='docker_without_host',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
diff --git a/tests/operators/docker_operator.py b/tests/operators/docker_operator.py
index 
84585806ca02a..a12b6f829f8a3 100644
--- a/tests/operators/docker_operator.py
+++ b/tests/operators/docker_operator.py
@@ -17,7 +17,8 @@
 
 try:
     from airflow.operators.docker_operator import DockerOperator
-    from docker import APIClient as Client
+    from airflow.hooks.docker_hook import DockerHook  # used as Mock spec in the hook tests
+    from docker import Client
 except ImportError:
     pass
 
@@ -33,7 +34,6 @@
 
 
 class DockerOperatorTestCase(unittest.TestCase):
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.utils.file.mkdtemp')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute(self, client_class_mock, mkdtemp_mock):
@@ -77,7 +77,6 @@ def test_execute(self, client_class_mock, mkdtemp_mock):
         client_mock.pull.assert_called_with('ubuntu:latest', stream=True)
         client_mock.wait.assert_called_with('some_id')
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.docker_operator.tls.TLSConfig')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute_tls(self, client_class_mock, tls_class_mock):
@@ -105,7 +104,6 @@ def test_execute_tls(self, client_class_mock, tls_class_mock):
         client_class_mock.assert_called_with(base_url='https://127.0.0.1:2376', tls=tls_mock,
                                              version=None)
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute_unicode_logs(self, client_class_mock):
         client_mock = mock.Mock(spec=Client)
@@ -128,7 +126,6 @@ def test_execute_unicode_logs(self, client_class_mock):
             logging.raiseExceptions = originalRaiseExceptions
             print_exception_mock.assert_not_called()
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute_container_fails(self, client_class_mock):
         client_mock = mock.Mock(spec=Client)
@@ -146,7 +143,6 @@ def test_execute_container_fails(self, client_class_mock):
         with self.assertRaises(AirflowException):
             operator.execute(None)
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     def test_on_kill(self):
         client_mock = mock.Mock(spec=Client)
 
@@ -158,6 +154,80 @@ def test_on_kill(self):
 
         client_mock.stop.assert_called_with('some_id')
 
+    @mock.patch('airflow.operators.docker_operator.Client')
+    def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
+        # Mock out a Docker client, so operations don't raise errors
+        client_mock = mock.Mock(name='DockerOperator.Client mock', spec=Client)
+        client_mock.images.return_value = []
+        client_mock.create_container.return_value = {'Id': 'some_id'}
+        client_mock.logs.return_value = []
+        client_mock.pull.return_value = []
+        client_mock.wait.return_value = 0
+        operator_client_mock.return_value = client_mock
+
+        # Create the DockerOperator
+        operator = DockerOperator(
+            image='publicregistry/someimage',
+            owner='unittest',
+            task_id='unittest'
+        )
+
+        # Mock out the DockerHook
+        hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
+        hook_mock.get_conn.return_value = client_mock
+        operator.get_hook = mock.Mock(  # replaces get_hook on this instance only
+            name='DockerOperator.get_hook mock',
+            spec=DockerOperator.get_hook,
+            return_value=hook_mock
+        )
+
+        operator.execute(None)
+        self.assertEqual(
+            operator.get_hook.call_count, 0,
+            'Hook called though no docker_conn_id configured'
+        )
+
+    @mock.patch('airflow.operators.docker_operator.Client')
+    def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock):
+        # Mock out a Docker client, so operations don't raise errors
+        client_mock = mock.Mock(name='DockerOperator.Client mock', spec=Client)
+        client_mock.images.return_value = []
+        client_mock.create_container.return_value = {'Id': 'some_id'}
+        client_mock.logs.return_value = []
+        client_mock.pull.return_value = []
+        client_mock.wait.return_value = 0
+        operator_client_mock.return_value = client_mock
+
+        # Create the DockerOperator
+        operator = DockerOperator(
+            image='publicregistry/someimage',
+            owner='unittest',
+            task_id='unittest',
+            docker_conn_id='some_conn_id'
+        )
+
+        # Mock out the DockerHook
+        hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
+        hook_mock.get_conn.return_value = client_mock
+        operator.get_hook = mock.Mock(  # the real get_hook is not exercised by this test
+            name='DockerOperator.get_hook mock',
+            spec=DockerOperator.get_hook,
+            return_value=hook_mock
+        )
+
+        operator.execute(None)
+        self.assertEqual(
+            operator_client_mock.call_count, 0,
+            'Client was called on the operator instead of the hook'
+        )
+        self.assertEqual(
+            operator.get_hook.call_count, 1,
+            'Hook was not called although docker_conn_id configured'
+        )
+        self.assertEqual(
+            client_mock.pull.call_count, 1,
+            'Image was not pulled using operator client'
+        )
 
 if __name__ == "__main__":
     unittest.main()