From fdbfc02a8065f7587efbc4e27a0af4ce01c8c1a2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 30 Jan 2026 12:14:38 -0800 Subject: [PATCH] Modify custom BatchNorm layer such that the mean and variance variables have the same rank whether train=False or train=True. PiperOrigin-RevId: 863341434 --- init2winit/model_lib/deepspeech.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/init2winit/model_lib/deepspeech.py b/init2winit/model_lib/deepspeech.py index fbe7b2e0..d0e32567 100644 --- a/init2winit/model_lib/deepspeech.py +++ b/init2winit/model_lib/deepspeech.py @@ -401,9 +401,9 @@ def __call__(self, inputs, input_paddings=None, train=False): if train: mask = 1.0 - padding - sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) + sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=False) count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=True) + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=False) count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v @@ -411,7 +411,7 @@ def __call__(self, inputs, input_paddings=None, train=False): sum_vv = jnp.sum( (inputs - mean) * (inputs - mean) * mask, axis=reduce_over_dims, - keepdims=True) + keepdims=False) var = sum_vv / count_v