From 3dd0c4d1a9aad613465d637b59989661f0bc7b95 Mon Sep 17 00:00:00 2001 From: Abhineet Choudhary Date: Fri, 27 Mar 2020 13:50:40 +0530 Subject: [PATCH 1/4] init sharpness --- tensorflow_addons/image/sharpness_op.py | 52 ++++++++++++++++++++ tensorflow_addons/image/sharpness_op_test.py | 42 ++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 tensorflow_addons/image/sharpness_op.py create mode 100644 tensorflow_addons/image/sharpness_op_test.py diff --git a/tensorflow_addons/image/sharpness_op.py b/tensorflow_addons/image/sharpness_op.py new file mode 100644 index 0000000000..cc1ec532c8 --- /dev/null +++ b/tensorflow_addons/image/sharpness_op.py @@ -0,0 +1,52 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Sharpness Op for color ops""" + +import tensorflow as tf + +from tensorflow_addons.image.compose_ops import blend + + +def sharpness(image, factor): + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image = tf.cast(image, tf.float32) + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + kernel = ( + tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1] + ) + / 13.0 + ) + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding="VALID", dilations=[1, 1] + ) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + result = tf.cast(result, tf.uint8) + # Blend the final result. + return blend(result, orig_image, factor) diff --git a/tensorflow_addons/image/sharpness_op_test.py b/tensorflow_addons/image/sharpness_op_test.py new file mode 100644 index 0000000000..f2e7fe543c --- /dev/null +++ b/tensorflow_addons/image/sharpness_op_test.py @@ -0,0 +1,42 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for shapness""" + +import pytest +import tensorflow as tf +import numpy as np + +from tensorflow_addons.image import sharpness_op +from PIL import Image, ImageEnhance + +_DTYPES = { + np.uint8, + np.int32, + np.int64, + np.float16, + np.float32, + np.float64, +} + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_equalize_with_PIL(): + # np.random.seed(0) + image = np.random.randint(low=0, high=255, size=(5, 5, 3), dtype=np.uint8) + enhancer = ImageEnhance.Sharpness(Image.fromarray(image)) + sharpened = enhancer.enhance(0.5) + np.testing.assert_allclose( + sharpness_op.sharpness(tf.constant(image), 0.5).numpy(), sharpened, atol=1 + ) From 1058a0a62971ae789fe4f22e82d94a573fed9665 Mon Sep 17 00:00:00 2001 From: Abhineet Choudhary Date: Wed, 1 Apr 2020 19:52:37 +0530 Subject: [PATCH 2/4] add batching --- tensorflow_addons/image/color_ops.py | 58 +++++++++++++++++++- tensorflow_addons/image/color_ops_test.py | 23 +++++++- tensorflow_addons/image/sharpness_op.py | 52 ------------------ tensorflow_addons/image/sharpness_op_test.py | 42 -------------- 4 files changed, 78 insertions(+), 97 deletions(-) delete mode 100644 tensorflow_addons/image/sharpness_op.py delete mode 100644 tensorflow_addons/image/sharpness_op_test.py diff --git a/tensorflow_addons/image/color_ops.py b/tensorflow_addons/image/color_ops.py index dc40f66611..1ac93c78d3 100644 --- a/tensorflow_addons/image/color_ops.py +++ b/tensorflow_addons/image/color_ops.py @@ -14,12 +14,14 @@ # ============================================================================== """Color operations. equalize: Equalizes image histogram + sharpness: Sharpen image """ import tensorflow as tf -from tensorflow_addons.utils.types import TensorLike +from tensorflow_addons.utils.types import TensorLike, Number from tensorflow_addons.image.utils import to_4D_image, from_4D_image +from tensorflow_addons.image.compose_ops import blend from typing import Optional from functools import partial @@ -84,7 +86,7 @@ def equalize( (num_images, num_rows, num_columns, num_channels) (NHWC), or (num_images, num_channels, num_rows, num_columns) (NCHW), or (num_rows, num_columns, num_channels) (HWC), or - (num_channels, num_rows, num_columns) (HWC), or + (num_channels, num_rows, num_columns) (CHW), or (num_rows, num_columns) (HW). The rank must be statically known (the shape is not `TensorShape(None)`). data_format: Either 'channels_first' or 'channels_last' @@ -98,3 +100,55 @@ def equalize( fn = partial(equalize_image, data_format=data_format) image = tf.map_fn(fn, image) return from_4D_image(image, image_dims) + + +def sharpness_image(image: TensorLike, factor: Number) -> tf.Tensor: + """Implements Sharpness function from PIL using TF ops.""" + orig_image = image + image_dtype = image.dtype + # Make image 4D for conv operation. + image = tf.expand_dims(image, 0) + # SMOOTH PIL Kernel. + image = tf.cast(image, tf.float32) + kernel = ( + tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1] + ) + / 13.0 + ) + # Tile across channel dimension. + kernel = tf.tile(kernel, [1, 1, 3, 1]) + strides = [1, 1, 1, 1] + degenerate = tf.nn.depthwise_conv2d( + image, kernel, strides, padding="VALID", dilations=[1, 1] + ) + degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) + degenerate = tf.squeeze(tf.cast(degenerate, image_dtype), [0]) + + # For the borders of the resulting image, fill in the values of the + # original image. + mask = tf.ones_like(degenerate) + padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) + padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) + result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) + # Blend the final result. + blended = blend(result, orig_image, factor) + return tf.cast(blended, image_dtype) + + +def sharpness(image: TensorLike, factor: Number) -> tf.Tensor: + """Change sharpness of image(s) + + Args: + images: A tensor of shape + (num_images, num_rows, num_columns, num_channels) (NHWC), or + (num_rows, num_columns, num_channels) (HWC) + factor: A floating point value or Tensor above 0.0. + Returns: + Image(s) with the same type and shape as `images`. + """ + image_dims = tf.rank(image) + image = to_4D_image(image) + fn = partial(sharpness_image, factor=factor) + image = tf.map_fn(fn, image) + return from_4D_image(image, image_dims) diff --git a/tensorflow_addons/image/color_ops_test.py b/tensorflow_addons/image/color_ops_test.py index 780066c364..e4e22cbd20 100644 --- a/tensorflow_addons/image/color_ops_test.py +++ b/tensorflow_addons/image/color_ops_test.py @@ -19,7 +19,7 @@ import numpy as np from tensorflow_addons.image import color_ops -from PIL import Image, ImageOps +from PIL import Image, ImageOps, ImageEnhance _DTYPES = { np.uint8, @@ -53,3 +53,24 @@ def test_equalize_channel_first(shape): image = tf.ones(shape=shape, dtype=tf.uint8) equalized = color_ops.equalize(image, "channels_first") np.testing.assert_equal(equalized.numpy(), image.numpy()) + + +@pytest.mark.parametrize("dtype", _DTYPES) +@pytest.mark.parametrize("shape", [(5, 5, 3), (10, 5, 5, 3)]) +def test_sharpness_dtype_shape(dtype, shape): + image = np.ones(shape=shape, dtype=dtype) + sharp = color_ops.sharpness(tf.constant(image), 0).numpy() + np.testing.assert_equal(sharp, image) + assert sharp.dtype == image.dtype + + +@pytest.mark.usefixtures("maybe_run_functions_eagerly") +def test_sharpness_with_PIL(): + np.random.seed(0) + image = np.random.randint(low=0, high=255, size=(10, 5, 5, 3), dtype=np.uint8) + sharpened = np.stack( + [ImageEnhance.Sharpness(Image.fromarray(i)).enhance(0.5) for i in image] + ) + np.testing.assert_allclose( + color_ops.sharpness(tf.constant(image), 0.5).numpy(), sharpened, atol=1 + ) diff --git a/tensorflow_addons/image/sharpness_op.py b/tensorflow_addons/image/sharpness_op.py deleted file mode 100644 index cc1ec532c8..0000000000 --- a/tensorflow_addons/image/sharpness_op.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Sharpness Op for color ops""" - -import tensorflow as tf - -from tensorflow_addons.image.compose_ops import blend - - -def sharpness(image, factor): - """Implements Sharpness function from PIL using TF ops.""" - orig_image = image - image = tf.cast(image, tf.float32) - # Make image 4D for conv operation. - image = tf.expand_dims(image, 0) - # SMOOTH PIL Kernel. - kernel = ( - tf.constant( - [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1] - ) - / 13.0 - ) - # Tile across channel dimension. - kernel = tf.tile(kernel, [1, 1, 3, 1]) - strides = [1, 1, 1, 1] - degenerate = tf.nn.depthwise_conv2d( - image, kernel, strides, padding="VALID", dilations=[1, 1] - ) - degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) - degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) - - # For the borders of the resulting image, fill in the values of the - # original image. - mask = tf.ones_like(degenerate) - padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]]) - padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]]) - result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image) - result = tf.cast(result, tf.uint8) - # Blend the final result. - return blend(result, orig_image, factor) diff --git a/tensorflow_addons/image/sharpness_op_test.py b/tensorflow_addons/image/sharpness_op_test.py deleted file mode 100644 index f2e7fe543c..0000000000 --- a/tensorflow_addons/image/sharpness_op_test.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2020 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for shapness""" - -import pytest -import tensorflow as tf -import numpy as np - -from tensorflow_addons.image import sharpness_op -from PIL import Image, ImageEnhance - -_DTYPES = { - np.uint8, - np.int32, - np.int64, - np.float16, - np.float32, - np.float64, -} - - -@pytest.mark.usefixtures("maybe_run_functions_eagerly") -def test_equalize_with_PIL(): - # np.random.seed(0) - image = np.random.randint(low=0, high=255, size=(5, 5, 3), dtype=np.uint8) - enhancer = ImageEnhance.Sharpness(Image.fromarray(image)) - sharpened = enhancer.enhance(0.5) - np.testing.assert_allclose( - sharpness_op.sharpness(tf.constant(image), 0.5).numpy(), sharpened, atol=1 - ) From 4dee2d56c61a8f13e8af67ef8e6cb80e17c0e443 Mon Sep 17 00:00:00 2001 From: Abhineet Choudhary Date: Sat, 4 Apr 2020 23:26:53 +0530 Subject: [PATCH 3/4] test commit --- tensorflow_addons/image/color_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/image/color_ops.py b/tensorflow_addons/image/color_ops.py index 1ac93c78d3..dd9d4b47a6 100644 --- a/tensorflow_addons/image/color_ops.py +++ b/tensorflow_addons/image/color_ops.py @@ -145,7 +145,7 @@ def sharpness(image: TensorLike, factor: Number) -> tf.Tensor: (num_rows, num_columns, num_channels) (HWC) factor: A floating point value or Tensor above 0.0. Returns: - Image(s) with the same type and shape as `images`. + Image(s) with the same type and shape as `images`, sharper. """ image_dims = tf.rank(image) image = to_4D_image(image) From 428232208f6db707c03d5bb4b0e6c5ad2ac1c7f8 Mon Sep 17 00:00:00 2001 From: Abhineet Choudhary Date: Tue, 7 Apr 2020 20:25:14 +0530 Subject: [PATCH 4/4] test factors --- tensorflow_addons/image/tests/color_ops_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/image/tests/color_ops_test.py b/tensorflow_addons/image/tests/color_ops_test.py index e4e22cbd20..8484f97548 100644 --- a/tensorflow_addons/image/tests/color_ops_test.py +++ b/tensorflow_addons/image/tests/color_ops_test.py @@ -64,13 +64,13 @@ def test_sharpness_dtype_shape(dtype, shape): assert sharp.dtype == image.dtype -@pytest.mark.usefixtures("maybe_run_functions_eagerly") -def test_sharpness_with_PIL(): +@pytest.mark.parametrize("factor", [0, 0.25, 0.5, 0.75, 1]) +def test_sharpness_with_PIL(factor): np.random.seed(0) image = np.random.randint(low=0, high=255, size=(10, 5, 5, 3), dtype=np.uint8) sharpened = np.stack( - [ImageEnhance.Sharpness(Image.fromarray(i)).enhance(0.5) for i in image] + [ImageEnhance.Sharpness(Image.fromarray(i)).enhance(factor) for i in image] ) np.testing.assert_allclose( - color_ops.sharpness(tf.constant(image), 0.5).numpy(), sharpened, atol=1 + color_ops.sharpness(tf.constant(image), factor).numpy(), sharpened, atol=1 )