diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 47d3447970..3ce0a2aa6a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -58,8 +58,8 @@ /tensorflow_addons/image/tests/distance_transform_test.py @mels630 /tensorflow_addons/image/distort_image_ops.py @windqaq /tensorflow_addons/image/tests/distort_image_ops_test.py @windqaq -/tensorflow_addons/image/filters.py @mainak431 -/tensorflow_addons/image/tests/filters_test.py @mainak431 +/tensorflow_addons/image/filters.py @mainak431 @ghosalsattam +/tensorflow_addons/image/tests/filters_test.py @mainak431 @ghosalsattam /tensorflow_addons/image/interpolate_spline.py /tensorflow_addons/image/tests/interpolate_spline_test.py /tensorflow_addons/image/resampler_ops.py @autoih diff --git a/tensorflow_addons/image/__init__.py b/tensorflow_addons/image/__init__.py index f285289279..15b92a5380 100644 --- a/tensorflow_addons/image/__init__.py +++ b/tensorflow_addons/image/__init__.py @@ -23,6 +23,7 @@ from tensorflow_addons.image.distance_transform import euclidean_dist_transform from tensorflow_addons.image.dense_image_warp import interpolate_bilinear from tensorflow_addons.image.interpolate_spline import interpolate_spline +from tensorflow_addons.image.filters import gaussian_filter2d from tensorflow_addons.image.filters import mean_filter2d from tensorflow_addons.image.filters import median_filter2d from tensorflow_addons.image.cutout_ops import random_cutout diff --git a/tensorflow_addons/image/filters.py b/tensorflow_addons/image/filters.py index cecf85383a..ead32d4840 100644 --- a/tensorflow_addons/image/filters.py +++ b/tensorflow_addons/image/filters.py @@ -204,3 +204,93 @@ def median_filter2d( output = tf.cast(median, image.dtype) output = img_utils.from_4D_image(output, original_ndims) return output + + +def _get_gaussian_kernel(sigma, filter_shape_1d): + "This function creates a kernel of size [filter_shape]" + x = tf.range(-filter_shape_1d // 2 + 1, filter_shape_1d // 2 + 1) + x = tf.math.square(x) + a = tf.exp(-(x) / (2 * (sigma ** 2))) + a = a / tf.math.reduce_sum(a) + return a + + +def _get_gaussian_kernel_2d(gaussian_filter_x, gaussian_filter_y): + "Compute 2D Gaussian Kernel" + gaussian_kernel = tf.matmul(gaussian_filter_x, gaussian_filter_y) + return gaussian_kernel + + +@tf.function +def gaussian_filter2d( + image: FloatTensorLike, + filter_shape: Union[List[int], Tuple[int]] = [3, 3], + sigma: FloatTensorLike = 1, + padding: str = "REFLECT", + constant_values: TensorLike = 0, + name: Optional[str] = None, +) -> FloatTensorLike: + """Perform Gaussian Blur. + + Args: + image: Either a 2-D `Tensor` of shape `[height, width]`, + a 3-D `Tensor` of shape `[height, width, channels]`, + or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`. + filter_shape: An `integer` or `tuple`/`list` of 2 integers, specifying + the height and width of the 2-D median filter. Can be a single integer + to specify the same value for all spatial dimensions. + sigma: Standard deviation of Gaussian. + padding: A `string`, one of "REFLECT", "CONSTANT", or "SYMMETRIC". + The type of padding algorithm to use, which is compatible with + `mode` argument in `tf.pad`. For more details, please refer to + https://www.tensorflow.org/api_docs/python/tf/pad. + constant_values: A `scalar`, the pad value to use in "CONSTANT" + padding mode. + name: A name for this operation (optional). + Returns: + 3-D or 4-D `Tensor` of the same dtype as input. + Raises: + ValueError: If `image` is not 2, 3 or 4-dimensional, + if `padding` is other than "REFLECT", "CONSTANT" or "SYMMETRIC", + or if `filter_shape` is invalid or sigma<=0. + """ + with tf.name_scope(name or "gaussian_filter2d"): + if sigma <= 0: + raise ValueError("Sigma should not be zero") + if padding not in ["REFLECT", "CONSTANT", "SYMMETRIC"]: + raise ValueError("Padding should be REFLECT, CONSTANT, OR SYMMETRIC") + + image = tf.cast(image, tf.float32) + original_ndims = img_utils.get_ndims(image) + image = img_utils.to_4D_image(image) + channels = tf.shape(image)[3] + filter_shape = keras_utils.normalize_tuple(filter_shape, 2, "filter_shape") + + gaussian_filter_x = _get_gaussian_kernel(sigma, filter_shape[1]) + gaussian_filter_x = tf.cast(gaussian_filter_x, tf.float32) + gaussian_filter_x = tf.reshape(gaussian_filter_x, [1, filter_shape[1]]) + + gaussian_filter_y = _get_gaussian_kernel(sigma, filter_shape[0]) + gaussian_filter_y = tf.reshape(gaussian_filter_y, [filter_shape[0], 1]) + gaussian_filter_y = tf.cast(gaussian_filter_y, tf.float32) + + gaussian_filter_2d = _get_gaussian_kernel_2d( + gaussian_filter_y, gaussian_filter_x + ) + gaussian_filter_2d = tf.repeat(gaussian_filter_2d, channels) + gaussian_filter_2d = tf.reshape( + gaussian_filter_2d, [filter_shape[0], filter_shape[1], channels, 1] + ) + + image = _pad( + image, filter_shape, mode=padding, constant_values=constant_values, + ) + + output = tf.nn.depthwise_conv2d( + input=image, + filter=gaussian_filter_2d, + strides=(1, 1, 1, 1), + padding="VALID", + ) + output = img_utils.from_4D_image(output, original_ndims) + return output diff --git a/tensorflow_addons/image/tests/filters_test.py b/tensorflow_addons/image/tests/filters_test.py index d3d75ecb61..64723f59b7 100644 --- a/tensorflow_addons/image/tests/filters_test.py +++ b/tensorflow_addons/image/tests/filters_test.py @@ -18,6 +18,8 @@ import tensorflow as tf from tensorflow_addons.image import mean_filter2d from tensorflow_addons.image import median_filter2d +from tensorflow_addons.image import gaussian_filter2d +from skimage.filters import gaussian _dtypes_to_test = { tf.dtypes.uint8, @@ -359,3 +361,113 @@ def test_symmetric_padding_with_3x3_filter_median(image_shape): constant_values=0, expected_plane=expected_plane, ) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_gaussian_filter2d_constant(): + test_image_tf = tf.random.uniform( + [1, 40, 40, 1], minval=0, maxval=255, dtype=tf.float64 + ) + gb = gaussian_filter2d(test_image_tf, 5, 1, padding="CONSTANT") + gb = gb.numpy() + gb1 = np.resize(gb, (40, 40)) + test_image_np = test_image_tf.numpy() + test_image_np = np.resize(test_image_np, [40, 40]) + gb2 = gaussian(test_image_np, 1, mode="constant") + np.testing.assert_allclose(gb2, gb1, 0.06) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_gaussian_filter2d_reflect(): + test_image_tf = tf.random.uniform( + [1, 40, 40, 1], minval=0, maxval=255, dtype=tf.float32 + ) + gb = gaussian_filter2d(test_image_tf, 5, 1, padding="REFLECT") + gb = gb.numpy() + gb1 = np.resize(gb, (40, 40)) + test_image_np = test_image_tf.numpy() + test_image_np = np.resize(test_image_np, [40, 40]) + gb2 = gaussian(test_image_np, 1, mode="mirror") + np.testing.assert_allclose(gb2, gb1, 0.06) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_gaussian_filter2d_symmetric(): + test_image_tf = tf.random.uniform( + [1, 40, 40, 1], minval=0, maxval=255, dtype=tf.float64 + ) + gb = gaussian_filter2d(test_image_tf, (5, 5), 1, padding="SYMMETRIC") + gb = gb.numpy() + gb1 = np.resize(gb, (40, 40)) + test_image_np = test_image_tf.numpy() + test_image_np = np.resize(test_image_np, [40, 40]) + gb2 = gaussian(test_image_np, 1, mode="reflect") + np.testing.assert_allclose(gb2, gb1, 0.06) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +@pytest.mark.parametrize("image_shape", [[2, 5, 5, 3]]) +def test_gaussian_filter2d_batch(image_shape): + test_image_tf = tf.random.uniform( + [1, 40, 40, 1], minval=0, maxval=255, dtype=tf.float32 + ) + gb = gaussian_filter2d(test_image_tf, 5, 1, padding="SYMMETRIC") + gb = gb.numpy() + gb1 = np.resize(gb, (40, 40)) + test_image_np = test_image_tf.numpy() + test_image_np = np.resize(test_image_np, [40, 40]) + gb2 = gaussian(test_image_np, 1, mode="reflect") + np.testing.assert_allclose(gb2, gb1, 0.06) + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_gaussian_filter2d_channels(): + test_image_tf = tf.constant( + [ + [ + [ + [0.0, 0.0, 0.0], + [2.0, 2.0, 0.0], + [4.0, 4.0, 0.0], + [6.0, 6.0, 0.0], + [8.0, 8.0, 0.0], + ], + [ + [10.0, 10.0, 0.0], + [12.0, 12.0, 0.0], + [14.0, 14.0, 0.0], + [16.0, 16.0, 0.0], + [18.0, 18.0, 0.0], + ], + [ + [20.0, 20.0, 0.0], + [22.0, 22.0, 0.0], + [24.0, 24.0, 0.0], + [26.0, 26.0, 0.0], + [28.0, 28.0, 0.0], + ], + [ + [30.0, 30.0, 0.0], + [32.0, 32.0, 0.0], + [34.0, 34.0, 0.0], + [36.0, 36.0, 0.0], + [38.0, 38.0, 0.0], + ], + [ + [40.0, 40.0, 0.0], + [42.0, 42.0, 0.0], + [44.0, 44.0, 0.0], + [46.0, 46.0, 0.0], + [48.0, 48.0, 0.0], + ], + ] + ], + dtype=tf.float32, + ) + gb = gaussian_filter2d(test_image_tf, 5, 1, padding="SYMMETRIC", name="gaussian") + gb = gb.numpy() + gb1 = np.resize(gb, (5, 5, 3)) + test_image_np = test_image_tf.numpy() + test_image_np = np.resize(test_image_np, [5, 5, 3]) + gb2 = gaussian(test_image_np, sigma=1, mode="reflect", multichannel=True) + np.testing.assert_allclose(gb2, gb1, 0.06)