diff --git a/official/vision/ops/augment.py b/official/vision/ops/augment.py index 90c7266d4de..f632c2232fb 100644 --- a/official/vision/ops/augment.py +++ b/official/vision/ops/augment.py @@ -844,19 +844,7 @@ def color(image: tf.Tensor, factor: float) -> tf.Tensor: def contrast(image: tf.Tensor, factor: float) -> tf.Tensor: """Equivalent of PIL Contrast.""" - degenerate = tf.image.rgb_to_grayscale(image) - # Cast before calling tf.histogram. - degenerate = tf.cast(degenerate, tf.int32) - - # Compute the grayscale histogram, then compute the mean pixel value, - # and create a constant image size of that value. Use that as the - # blending degenerate target of the original image. - hist = tf.histogram_fixed_width(degenerate, [0, 255], nbins=256) - mean = tf.reduce_sum(tf.cast(hist, tf.float32)) / 256.0 - degenerate = tf.ones_like(degenerate, dtype=tf.float32) * mean - degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) - degenerate = tf.image.grayscale_to_rgb(tf.cast(degenerate, tf.uint8)) - return blend(degenerate, image, factor) + return tf.image.adjust_contrast(image, factor) def brightness(image: tf.Tensor, factor: float) -> tf.Tensor: