diff --git a/official/vision/ops/augment.py b/official/vision/ops/augment.py index 90c7266d4de..dc45324a034 100644 --- a/official/vision/ops/augment.py +++ b/official/vision/ops/augment.py @@ -2709,8 +2709,8 @@ def distort(self, images: tf.Tensor, @staticmethod def _sample_from_beta(alpha, beta, shape): - sample_alpha = tf.random.gamma(shape, 1., beta=alpha) - sample_beta = tf.random.gamma(shape, 1., beta=beta) + sample_alpha = tf.random.gamma(shape, alpha, beta=1.0) + sample_beta = tf.random.gamma(shape, beta, beta=1.0) return sample_alpha / (sample_alpha + sample_beta) def _cutmix(self, images: tf.Tensor,