diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py
index 6d45acea7e936..99a1746ee6551 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -32,6 +32,7 @@
 #
 # ------------------------------------------------------------------------
 _hooks = {
+    'docker_hook': ['DockerHook'],
     'ftp_hook': ['FTPHook'],
     'ftps_hook': ['FTPSHook'],
     'vertica_hook': ['VerticaHook'],
diff --git a/airflow/hooks/docker_hook.py b/airflow/hooks/docker_hook.py
new file mode 100644
index 0000000000000..a570292ead2b9
--- /dev/null
+++ b/airflow/hooks/docker_hook.py
@@ -0,0 +1,79 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from docker import APIClient as Client
+from docker.errors import APIError
+
+from airflow.exceptions import AirflowException
+from airflow.hooks.base_hook import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+class DockerHook(BaseHook, LoggingMixin):
+    """
+    Interact with a private Docker registry.
+
+    :param docker_conn_id: ID of the Airflow connection where
+        credentials and extra configuration are stored
+    :type docker_conn_id: str
+    """
+    def __init__(self,
+                 docker_conn_id='docker_default',
+                 base_url=None,
+                 version=None,
+                 tls=None
+                 ):
+        if not base_url:
+            raise AirflowException('No Docker base URL provided')
+        if not version:
+            raise AirflowException('No Docker API version provided')
+
+        conn = self.get_connection(docker_conn_id)
+        if not conn.host:
+            raise AirflowException('No Docker registry URL provided')
+        if not conn.login:
+            raise AirflowException('No username provided')
+        extra_options = conn.extra_dejson
+
+        self.__base_url = base_url
+        self.__version = version
+        self.__tls = tls
+        self.__registry = conn.host
+        self.__username = conn.login
+        self.__password = conn.password
+        self.__email = extra_options.get('email')
+        self.__reauth = False if extra_options.get('reauth') == 'no' else True
+
+    def get_conn(self):
+        client = Client(
+            base_url=self.__base_url,
+            version=self.__version,
+            tls=self.__tls
+        )
+        self.__login(client)
+        return client
+
+    def __login(self, client):
+        self.log.debug('Logging into Docker registry')
+        try:
+            client.login(
+                username=self.__username,
+                password=self.__password,
+                registry=self.__registry,
+                email=self.__email,
+                reauth=self.__reauth
+            )
+            self.log.debug('Login successful')
+        except APIError as docker_error:
+            self.log.error('Docker registry login failed: %s', str(docker_error))
+            raise AirflowException('Docker registry login failed: %s' % docker_error)
diff --git a/airflow/models.py b/airflow/models.py
index a82fec64fc8ba..b4b2325834c1a 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -530,6 +530,7 @@ class Connection(Base, LoggingMixin):
     _extra = Column('extra', String(5000))
 
     _types = [
+        ('docker', 'Docker Registry',),
         ('fs', 'File (path)'),
         ('ftp', 'FTP',),
         ('google_cloud_platform', 'Google Cloud Platform'),
@@ -696,6 +697,9 @@ def get_hook(self):
             elif self.conn_type == 'wasb':
                 from airflow.contrib.hooks.wasb_hook import WasbHook
                 return WasbHook(wasb_conn_id=self.conn_id)
+            elif self.conn_type == 'docker':
+                from airflow.hooks.docker_hook import DockerHook
+                return DockerHook(docker_conn_id=self.conn_id)
         except:
             pass
 
diff --git a/airflow/operators/docker_operator.py b/airflow/operators/docker_operator.py
index f319cf44a796a..38edc8b4d79ab 100644
--- a/airflow/operators/docker_operator.py
+++ b/airflow/operators/docker_operator.py
@@ -14,11 +14,12 @@
 
 import json
 
+from airflow.hooks.docker_hook import DockerHook
 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
 from airflow.utils.decorators import apply_defaults
 from airflow.utils.file import TemporaryDirectory
 from docker import APIClient as Client, tls
 
 import ast
 
@@ -30,9 +31,14 @@ class DockerOperator(BaseOperator):
     that together exceed the default disk size of 10GB in a container. The
     path to the mounted directory can be accessed via the environment variable
     ``AIRFLOW_TMP_DIR``.
+    If a login to a private registry is required prior to pulling the image, a
+    Docker connection needs to be configured in Airflow and the connection ID
+    provided with the parameter ``docker_conn_id``.
+
     :param image: Docker image from which to create the container.
     :type image: str
-    :param api_version: Remote API version.
+    :param api_version: Remote API version. Set to ``auto`` to automatically
+        detect the server's version.
     :type api_version: str
     :param command: Command to be run in the container.
     :type command: str or list
@@ -41,10 +47,11 @@ class DockerOperator(BaseOperator):
         https://docs.docker.com/engine/reference/run/#cpu-share-constraint
     :type cpus: float
     :param docker_url: URL of the host running the docker daemon.
+        Default is unix://var/run/docker.sock
     :type docker_url: str
     :param environment: Environment variables to set in the container.
     :type environment: dict
-    :param force_pull: Pull the docker image on every run.
+    :param force_pull: Pull the docker image on every run. Default is False.
     :type force_pull: bool
     :param mem_limit: Maximum amount of memory the container can use. Either
         a float value, which represents the limit in bytes, or a string like ``128m`` or ``1g``.
@@ -78,6 +85,8 @@ class DockerOperator(BaseOperator):
     :type xcom_push: bool
     :param xcom_all: Push all the stdout or just the last line. The default is False (last line).
     :type xcom_all: bool
+    :param docker_conn_id: ID of the Airflow connection to use
+    :type docker_conn_id: str
     """
     template_fields = ('command',)
     template_ext = ('.sh', '.bash',)
@@ -105,6 +114,7 @@ def __init__(
             working_dir=None,
             xcom_push=False,
             xcom_all=False,
+            docker_conn_id=None,
             *args,
             **kwargs):
 
@@ -129,25 +139,32 @@ def __init__(
         self.working_dir = working_dir
         self.xcom_push_flag = xcom_push
         self.xcom_all = xcom_all
+        self.docker_conn_id = docker_conn_id
         self.cli = None
         self.container = None
 
+    def get_hook(self):
+        return DockerHook(
+            docker_conn_id=self.docker_conn_id,
+            base_url=self.docker_url,
+            version=self.api_version,
+            tls=self.__get_tls_config()
+        )
+
     def execute(self, context):
         self.log.info('Starting docker container from image %s', self.image)
 
-        tls_config = None
-        if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key:
-            tls_config = tls.TLSConfig(
-                ca_cert=self.tls_ca_cert,
-                client_cert=(self.tls_client_cert, self.tls_client_key),
-                verify=True,
-                ssl_version=self.tls_ssl_version,
-                assert_hostname=self.tls_hostname
-            )
-            self.docker_url = self.docker_url.replace('tcp://', 'https://')
+        tls_config = self.__get_tls_config()
 
-        self.cli = Client(base_url=self.docker_url, version=self.api_version, tls=tls_config)
+        if self.docker_conn_id:
+            self.cli = self.get_hook().get_conn()
+        else:
+            self.cli = Client(
+                base_url=self.docker_url,
+                version=self.api_version,
+                tls=tls_config
+            )
 
         if ':' not in self.image:
             image = self.image + ':latest'
@@ -204,3 +221,16 @@ def on_kill(self):
         if self.cli is not None:
             self.log.info('Stopping docker container')
             self.cli.stop(self.container['Id'])
+
+    def __get_tls_config(self):
+        tls_config = None
+        if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key:
+            tls_config = tls.TLSConfig(
+                ca_cert=self.tls_ca_cert,
+                client_cert=(self.tls_client_cert, self.tls_client_key),
+                verify=True,
+                ssl_version=self.tls_ssl_version,
+                assert_hostname=self.tls_hostname
+            )
+            self.docker_url = self.docker_url.replace('tcp://', 'https://')
+        return tls_config
diff --git a/airflow/www/static/connection_form.js b/airflow/www/static/connection_form.js
index 0324bcfffb2d2..c40bba7620b6f 100644
--- a/airflow/www/static/connection_form.js
+++ b/airflow/www/static/connection_form.js
@@ -38,7 +38,14 @@
       'login': 'Username (or API Key)',
       'schema': 'Database'
     }
-  }
+  },
+  docker: {
+    hidden_fields: ['port', 'schema'],
+    relabeling: {
+      'host': 'Registry URL',
+      'login': 'Username',
+    },
+  },
 }
 function connTypeChange(connectionType) {
   $("div.form-group").removeClass("hide");
diff --git a/docs/code.rst b/docs/code.rst
index 4a6718f556f07..1369b323f8409 100644
--- a/docs/code.rst
+++ b/docs/code.rst
@@ -216,6 +216,7 @@ Hooks
     :show-inheritance:
     :members:
         DbApiHook,
+        DockerHook,
         HiveCliHook,
         HiveMetastoreHook,
         HiveServer2Hook,
diff --git a/tests/hooks/test_docker_hook.py b/tests/hooks/test_docker_hook.py
new file mode 100644
index 0000000000000..4a77523db6714
--- /dev/null
+++ b/tests/hooks/test_docker_hook.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from airflow import configuration
+from airflow import models
+from airflow.exceptions import AirflowException
+from airflow.utils import db
+
+try:
+    from airflow.hooks.docker_hook import DockerHook
+    from docker import APIClient as Client
+except ImportError:
+    pass
+
+try:
+    from unittest import mock
+except ImportError:
+    try:
+        import mock
+    except ImportError:
+        mock = None
+
+
+@mock.patch('airflow.hooks.docker_hook.Client', autospec=True)
+class DockerHookTest(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_default',
+                conn_type='docker',
+                host='some.docker.registry.com',
+                login='some_user',
+                password='some_p4$$w0rd'
+            )
+        )
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_with_extras',
+                conn_type='docker',
+                host='some.docker.registry.com',
+                login='some_user',
+                password='some_p4$$w0rd',
+                extra='{"email": "some@example.com", "reauth": "no"}'
+            )
+        )
+
+    def test_init_fails_when_no_base_url_given(self, _):
+        with self.assertRaises(AirflowException):
+            DockerHook(
+                docker_conn_id='docker_default',
+                version='auto',
+                tls=None
+            )
+
+    def test_init_fails_when_no_api_version_given(self, _):
+        with self.assertRaises(AirflowException):
+            DockerHook(
+                docker_conn_id='docker_default',
+                base_url='unix://var/run/docker.sock',
+                tls=None
+            )
+
+    def test_get_conn_override_defaults(self, docker_client_mock):
+        hook = DockerHook(
+            docker_conn_id='docker_default',
+            base_url='https://index.docker.io/v1/',
+            version='1.23',
+            tls='someconfig'
+        )
+        hook.get_conn()
+        docker_client_mock.assert_called_with(
+            base_url='https://index.docker.io/v1/',
+            version='1.23',
+            tls='someconfig'
+        )
+
+    def test_get_conn_with_standard_config(self, _):
+        try:
+            hook = DockerHook(
+                docker_conn_id='docker_default',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
+            client = hook.get_conn()
+            self.assertIsNotNone(client)
+        except Exception:
+            self.fail('Could not get connection from Airflow')
+
+    def test_get_conn_with_extra_config(self, _):
+        try:
+            hook = DockerHook(
+                docker_conn_id='docker_with_extras',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
+            client = hook.get_conn()
+            self.assertIsNotNone(client)
+        except Exception:
+            self.fail('Could not get connection from Airflow')
+
+    def test_conn_with_standard_config_passes_parameters(self, _):
+        hook = DockerHook(
+            docker_conn_id='docker_default',
+            base_url='unix://var/run/docker.sock',
+            version='auto'
+        )
+        client = hook.get_conn()
+        client.login.assert_called_with(
+            username='some_user',
+            password='some_p4$$w0rd',
+            registry='some.docker.registry.com',
+            reauth=True,
+            email=None
+        )
+
+    def test_conn_with_extra_config_passes_parameters(self, _):
+        hook = DockerHook(
+            docker_conn_id='docker_with_extras',
+            base_url='unix://var/run/docker.sock',
+            version='auto'
+        )
+        client = hook.get_conn()
+        client.login.assert_called_with(
+            username='some_user',
+            password='some_p4$$w0rd',
+            registry='some.docker.registry.com',
+            reauth=False,
+            email='some@example.com'
+        )
+
+    def test_conn_with_broken_config_missing_username_fails(self, _):
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_without_username',
+                conn_type='docker',
+                host='some.docker.registry.com',
+                password='some_p4$$w0rd',
+                extra='{"email": "some@example.com"}'
+            )
+        )
+        with self.assertRaises(AirflowException):
+            hook = DockerHook(
+                docker_conn_id='docker_without_username',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
+
+    def test_conn_with_broken_config_missing_host_fails(self, _):
+        db.merge_conn(
+            models.Connection(
+                conn_id='docker_without_host',
+                conn_type='docker',
+                login='some_user',
+                password='some_p4$$w0rd'
+            )
+        )
+        with self.assertRaises(AirflowException):
+            hook = DockerHook(
+                docker_conn_id='docker_without_host',
+                base_url='unix://var/run/docker.sock',
+                version='auto'
+            )
diff --git a/tests/operators/docker_operator.py b/tests/operators/docker_operator.py
index 84585806ca02a..a12b6f829f8a3 100644
--- a/tests/operators/docker_operator.py
+++ b/tests/operators/docker_operator.py
@@ -17,7 +17,8 @@
 
 try:
     from airflow.operators.docker_operator import DockerOperator
+    from airflow.hooks.docker_hook import DockerHook
     from docker import APIClient as Client
 except ImportError:
     pass
 
@@ -33,7 +34,6 @@
 
 
 class DockerOperatorTestCase(unittest.TestCase):
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.utils.file.mkdtemp')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute(self, client_class_mock, mkdtemp_mock):
@@ -77,7 +77,6 @@ def test_execute(self, client_class_mock, mkdtemp_mock):
         client_mock.pull.assert_called_with('ubuntu:latest', stream=True)
         client_mock.wait.assert_called_with('some_id')
 
-    @unittest.skipIf(mock is None, 'mock package not present')
    @mock.patch('airflow.operators.docker_operator.tls.TLSConfig')
    @mock.patch('airflow.operators.docker_operator.Client')
    def test_execute_tls(self, client_class_mock, tls_class_mock):
@@ -105,7 +104,6 @@ def test_execute_tls(self, client_class_mock, tls_class_mock):
         client_class_mock.assert_called_with(base_url='https://127.0.0.1:2376',
                                              tls=tls_mock, version=None)
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute_unicode_logs(self, client_class_mock):
         client_mock = mock.Mock(spec=Client)
@@ -128,7 +126,6 @@ def test_execute_unicode_logs(self, client_class_mock):
         logging.raiseExceptions = originalRaiseExceptions
         print_exception_mock.assert_not_called()
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute_container_fails(self, client_class_mock):
         client_mock = mock.Mock(spec=Client)
@@ -146,7 +143,6 @@ def test_execute_container_fails(self, client_class_mock):
         with self.assertRaises(AirflowException):
             operator.execute(None)
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     def test_on_kill(self):
         client_mock = mock.Mock(spec=Client)
 
@@ -158,6 +154,80 @@ def test_on_kill(self):
 
         client_mock.stop.assert_called_with('some_id')
 
+    @mock.patch('airflow.operators.docker_operator.Client')
+    def test_execute_no_docker_conn_id_no_hook(self, operator_client_mock):
+        # Mock out a Docker client, so operations don't raise errors
+        client_mock = mock.Mock(name='DockerOperator.Client mock', spec=Client)
+        client_mock.images.return_value = []
+        client_mock.create_container.return_value = {'Id': 'some_id'}
+        client_mock.logs.return_value = []
+        client_mock.pull.return_value = []
+        client_mock.wait.return_value = 0
+        operator_client_mock.return_value = client_mock
+
+        # Create the DockerOperator
+        operator = DockerOperator(
+            image='publicregistry/someimage',
+            owner='unittest',
+            task_id='unittest'
+        )
+
+        # Mock out the DockerHook
+        hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
+        hook_mock.get_conn.return_value = client_mock
+        operator.get_hook = mock.Mock(
+            name='DockerOperator.get_hook mock',
+            spec=DockerOperator.get_hook,
+            return_value=hook_mock
+        )
+
+        operator.execute(None)
+        self.assertEqual(
+            operator.get_hook.call_count, 0,
+            'Hook called though no docker_conn_id configured'
+        )
+
+    @mock.patch('airflow.operators.docker_operator.Client')
+    def test_execute_with_docker_conn_id_use_hook(self, operator_client_mock):
+        # Mock out a Docker client, so operations don't raise errors
+        client_mock = mock.Mock(name='DockerOperator.Client mock', spec=Client)
+        client_mock.images.return_value = []
+        client_mock.create_container.return_value = {'Id': 'some_id'}
+        client_mock.logs.return_value = []
+        client_mock.pull.return_value = []
+        client_mock.wait.return_value = 0
+        operator_client_mock.return_value = client_mock
+
+        # Create the DockerOperator
+        operator = DockerOperator(
+            image='publicregistry/someimage',
+            owner='unittest',
+            task_id='unittest',
+            docker_conn_id='some_conn_id'
+        )
+
+        # Mock out the DockerHook
+        hook_mock = mock.Mock(name='DockerHook mock', spec=DockerHook)
+        hook_mock.get_conn.return_value = client_mock
+        operator.get_hook = mock.Mock(
+            name='DockerOperator.get_hook mock',
+            spec=DockerOperator.get_hook,
+            return_value=hook_mock
+        )
+
+        operator.execute(None)
+        self.assertEqual(
+            operator_client_mock.call_count, 0,
+            'Client was called on the operator instead of the hook'
+        )
+        self.assertEqual(
+            operator.get_hook.call_count, 1,
+            'Hook was not called although docker_conn_id configured'
+        )
+        self.assertEqual(
+            client_mock.pull.call_count, 1,
+            'Image was not pulled using operator client'
+        )
 
 if __name__ == "__main__":
     unittest.main()
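
Usage sketch (illustrative, not part of the patch): with this change applied, a DAG points the operator at a private registry by creating a connection of type "Docker Registry" in Airflow (host holds the registry URL, login and password the credentials, and the extra JSON may carry "email" and "reauth", as read by DockerHook above) and passing its ID via docker_conn_id. The connection ID 'docker_private', the registry host, and the image name below are placeholders, not values defined anywhere in the patch.

# Hypothetical example: wiring the new docker_conn_id parameter into a DAG.
# 'docker_private' must match an existing Docker connection in Airflow.
from datetime import datetime

from airflow import DAG
from airflow.operators.docker_operator import DockerOperator

dag = DAG(
    dag_id='docker_private_registry_example',
    start_date=datetime(2017, 1, 1),
    schedule_interval=None,
)

run_image = DockerOperator(
    task_id='run_from_private_registry',
    image='registry.example.com/team/someimage:latest',
    api_version='auto',               # let the client detect the daemon's API version
    docker_conn_id='docker_private',  # makes execute() obtain its client via DockerHook,
                                      # which logs in to the registry before the pull
    force_pull=True,                  # always pull, so the registry login actually matters
    command='/bin/sleep 30',
    dag=dag,
)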