From d53c9fe287a87b26c7c62d359b1f63f0ede5751a Mon Sep 17 00:00:00 2001 From: Matt Bahr Date: Tue, 11 Mar 2025 22:02:54 +0000 Subject: [PATCH] implement spectral and instance norm --- flax/nnx/__init__.py | 2 + flax/nnx/nn/normalization.py | 358 ++++++++++++++++++++++++++++- tests/nnx/nn/normalization_test.py | 139 +++++++++++ 3 files changed, 498 insertions(+), 1 deletion(-) diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py index 1d75d06b9..67cd3fb2b 100644 --- a/flax/nnx/__init__.py +++ b/flax/nnx/__init__.py @@ -104,6 +104,8 @@ from .nn.normalization import LayerNorm as LayerNorm from .nn.normalization import RMSNorm as RMSNorm from .nn.normalization import GroupNorm as GroupNorm +from .nn.normalization import InstanceNorm as InstanceNorm +from .nn.normalization import SpectralNorm as SpectralNorm from .nn.stochastic import Dropout as Dropout from .rnglib import Rngs as Rngs from .rnglib import RngStream as RngStream diff --git a/flax/nnx/nn/normalization.py b/flax/nnx/nn/normalization.py index 72c6450cf..9c5623a3d 100644 --- a/flax/nnx/nn/normalization.py +++ b/flax/nnx/nn/normalization.py @@ -181,6 +181,24 @@ def _normalize( return jnp.asarray(y, dtype) +def _l2_normalize(x, axis=None, eps=1e-12): + """Normalizes along dimension `axis` using an L2 norm. + + This specialized function exists for numerical stability reasons. + + Args: + x: An input ndarray. + axis: Dimension along which to normalize, e.g. `1` to separately normalize + vectors in a batch. Passing `None` views `t` as a flattened vector when + calculating the norm (equivalent to Frobenius norm). + eps: Epsilon to avoid dividing by zero. + + Returns: + An array of the same shape as 'x' L2-normalized along 'axis'. + """ + return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps) + + class BatchNorm(Module): """BatchNorm Module. @@ -835,4 +853,342 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): (self.feature_axis,), self.dtype, self.epsilon, - ) \ No newline at end of file + ) + + +class InstanceNorm(Module): + """Instance normalization (https://arxiv.org/abs/1607.08022v3). + + InstanceNorm normalizes the activations of the layer for each channel (rather + than across all channels like Layer Normalization), and for each given example + in a batch independently (rather than across an entire batch like Batch + Normalization). i.e. applies a transformation that maintains the mean activation + within each channel within each example close to 0 and the activation standard + deviation close to 1. + .. note:: + This normalization operation is identical to LayerNorm and GroupNorm; the + difference is simply which axes are reduced and the shape of the feature axes + (i.e. the shape of the learnable scale and bias parameters). 
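+    For NHWC image inputs with the default ``feature_axes=-1``, for example,
+    the height and width axes of each example are reduced separately for every
+    channel, and the learned scale and bias have shape ``(num_features,)``.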
+ + Example usage:: + + >>> from flax import nnx + >>> import jax + >>> import numpy as np + >>> # dimensions: (batch, height, width, channel) + >>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5)) + >>> layer = nnx.InstanceNorm(5, rngs=nnx.Rngs(0)) + >>> nnx.state(layer, nnx.OfType(nnx.Param)) + State({ + 'bias': VariableState( # 5 (20 B) + type=Param, + value=Array([0., 0., 0., 0., 0.], dtype=float32) + ), + 'scale': VariableState( # 5 (20 B) + type=Param, + value=Array([1., 1., 1., 1., 1.], dtype=float32) + ) + }) + >>> y = layer(x) + >>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch, + >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm + >>> y2 = nnx.LayerNorm(5, reduction_axes=[1, 2], feature_axes=-1, rngs=nnx.Rngs(0))(x) + >>> np.testing.assert_allclose(y, y2, atol=1e-7) + >>> y3 = nnx.GroupNorm(5, num_groups=x.shape[-1], rngs=nnx.Rngs(0))(x) + >>> np.testing.assert_allclose(y, y3, atol=1e-7) + + Args: + num_features: the number of input features/channels. + epsilon: A small float added to variance to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + use_bias: If True, bias (beta) is added. + use_scale: If True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. + bias_init: Initializer for bias, by default, zero. + scale_init: Initializer for scale, by default, one. + feature_axes: Axes for features. The learned bias and scaling parameters will + be in the shape defined by the feature axes. All other axes except the batch + axes (which is assumed to be the leading axis) will be reduced. + axis_name: the axis name used to combine batch statistics from multiple + devices. See ``jax.pmap`` for a description of axis names (default: None). + This is only needed if the model is subdivided across devices, i.e. the + array being normalized is sharded across devices within a pmap or shard + map. For SPMD jit, you do not need to manually synchronize. Just make sure + that the axes are correctly annotated and XLA:SPMD will insert the + necessary collectives. + axis_index_groups: groups of axis indices within that named axis + representing subsets of devices to reduce over (default: None). For + example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the + examples on the first two and last two devices. See ``jax.lax.psum`` for + more details. + use_fast_variance: If true, use a faster, but less numerically stable, + calculation for the variance. + rngs: The rng key. 
+ """ + + def __init__( + self, + num_features: int, + *, + epsilon: float = 1e-6, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + use_bias: bool = True, + use_scale: bool = True, + bias_init: Initializer = initializers.zeros, + scale_init: Initializer = initializers.ones, + feature_axes: Axes = -1, + axis_name: tp.Optional[str] = None, + axis_index_groups: tp.Any = None, + use_fast_variance: bool = True, + rngs: rnglib.Rngs, + ): + feature_shape = (num_features,) + self.scale: nnx.Param[jax.Array] | None + if use_scale: + key = rngs.params() + self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype)) + else: + self.scale = None + + self.bias: nnx.Param[jax.Array] | None + if use_bias: + key = rngs.params() + self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype)) + else: + self.bias = None + + self.num_features = num_features + self.epsilon = epsilon + self.dtype = dtype + self.param_dtype = param_dtype + self.use_bias = use_bias + self.use_scale = use_scale + self.bias_init = bias_init + self.scale_init = scale_init + self.feature_axes = feature_axes + self.axis_name = axis_name + self.axis_index_groups = axis_index_groups + self.use_fast_variance = use_fast_variance + self.rngs = rngs + + def __call__(self, x, *, mask: tp.Optional[jax.Array] = None): + """Applies instance normalization on the input. + + Args: + x: the inputs + mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating + the positions for which the mean and variance should be computed. + + Returns: + Normalized inputs (the same shape as inputs). + """ + feature_axes = _canonicalize_axes(x.ndim, self.feature_axes) + if 0 in feature_axes: + raise ValueError('The channel axes cannot include the leading dimension ' + 'as this is assumed to be the batch axis.') + reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes] + + mean, var = _compute_stats( + x, + reduction_axes, + self.dtype, + self.axis_name, + self.axis_index_groups, + use_fast_variance=self.use_fast_variance, + mask=mask, + ) + + return _normalize( + x, + mean, + var, + self.scale.value if self.scale else None, + self.bias.value if self.bias else None, + reduction_axes, + feature_axes, + self.dtype, + self.epsilon, + ) + + +class SpectralNorm(Module): + """Spectral normalization. + See: + - https://arxiv.org/abs/1802.05957 + - https://arxiv.org/abs/1805.08318 + - https://arxiv.org/abs/1809.11096 + + Spectral normalization normalizes the weight params so that the spectral + norm of the matrix is equal to 1. This is implemented as a layer wrapper + where each wrapped layer will have its params spectral normalized before + computing its ``__call__`` output. + .. note:: + The initialized variables dict will contain, in addition to a 'params' + collection, a separate 'batch_stats' collection that will contain a + ``u`` vector and ``sigma`` value, which are intermediate values used + when performing spectral normalization. During training, we pass in + ``update_stats=True`` so that ``u`` and ``sigma`` are updated with + the most recently computed values using power iteration. This will + help the power iteration method approximate the true singular value + more accurately over time. During eval, we pass in ``update_stats=False`` + to ensure we get deterministic behavior from the model. 
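+    In this NNX port, ``u`` and ``sigma`` are stored as variables on the
+    wrapped layer itself, keyed by ``collection_name`` (``'batch_stats'`` by
+    default), e.g. ``batch_stats/kernel/u`` and ``batch_stats/kernel/sigma``.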
+ + Example usage:: + + >>> from flax import nnx + >>> import jax + >>> rngs = nnx.Rngs(0) + >>> x = jax.random.normal(jax.random.key(0), (3, 4)) + >>> layer = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs), + ... rngs=rngs) + >>> nnx.state(layer, nnx.OfType(nnx.Param)) + State({ + 'layer_instance': { + 'bias': VariableState( # 5 (20 B) + type=Param, + value=Array([0., 0., 0., 0., 0.], dtype=float32) + ), + 'kernel': VariableState( # 20 (80 B) + type=Param, + value=Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626, -0.46665004], + [ 0.31773907, 0.38944173, -0.54608804, 0.84378934, -0.93099 ], + [-0.67658 , 0.0724705 , -0.6101737 , 0.12972134, 0.877074 ], + [ 0.27292168, 0.32105306, -0.2556603 , 0.4896752 , 0.19558711]], dtype=float32) + ) + } + }) + >>> y = layer(x, update_stats=True) + + Args: + layer_instance: Module instance that is wrapped with SpectralNorm + n_steps: How many steps of power iteration to perform to approximate the + singular value of the weight params. + epsilon: A small float added to l2-normalization to avoid dividing by zero. + dtype: the dtype of the result (default: infer from input and params). + param_dtype: the dtype passed to parameter initializers (default: float32). + error_on_non_matrix: Spectral normalization is only defined on matrices. By + default, this module will return scalars unchanged and flatten + higher-order tensors in their leading dimensions. Setting this flag to + True will instead throw an error if a weight tensor with dimension greater + than 2 is used by the layer. + collection_name: Name of the collection to store intermediate values used + when performing spectral normalization. + rngs: The rng key. + """ + + def __init__( + self, + layer_instance: Module, + *, + n_steps: int = 1, + epsilon: float = 1e-12, + dtype: tp.Optional[Dtype] = None, + param_dtype: Dtype = jnp.float32, + error_on_non_matrix: bool = False, + collection_name: str = 'batch_stats', + rngs: rnglib.Rngs, + ): + self.layer_instance = layer_instance + self.n_steps = n_steps + self.epsilon = epsilon + self.dtype = dtype + self.param_dtype = param_dtype + self.error_on_non_matrix = error_on_non_matrix + self.collection_name = collection_name + self.rngs = rngs + + def __call__(self, x, *args, update_stats: bool, **kwargs): + """Compute the largest singular value of the weights in ``self.layer_instance`` + using power iteration and normalize the weights using this value before + computing the ``__call__`` output. + + Args: + x: the input array of the nested layer + *args: positional arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. + update_stats: if True, update the internal ``u`` vector and ``sigma`` + value after computing their updated values using power iteration. This + will help the power iteration method approximate the true singular value + more accurately over time. + **kwargs: keyword arguments to be passed into the call method of the + underlying layer instance in ``self.layer_instance``. + + Returns: + Output of the layer using spectral normalized weights. + """ + + state = nnx.state(self.layer_instance) + + def spectral_normalize(path, vs): + value = jnp.asarray(vs.value) + value_shape = value.shape + + # Skip and return value if input is scalar, vector or if number of power + # iterations is less than 1 + if value.ndim <= 1 or self.n_steps < 1: + return value + # Handle higher-order tensors. 
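+      # Tensors with more than two dimensions (e.g. a conv kernel of shape
+      # (H, W, C_in, C_out)) are flattened over their leading dimensions into
+      # a 2D matrix below, so the spectral norm computed is that of the
+      # flattened (H * W * C_in, C_out) matrix rather than of the raw tensor.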
+ elif value.ndim > 2: + if self.error_on_non_matrix: + raise ValueError( + f'Input is {value.ndim}D but error_on_non_matrix is set to True' + ) + else: + value = jnp.reshape(value, (-1, value.shape[-1])) + + u_var_name = ( + self.collection_name + + '/' + + '/'.join(str(k) for k in path) + + '/u' + ) + + try: + u = state[u_var_name].value + except KeyError: + u = jax.random.normal( + self.rngs.params(), + (1, value.shape[-1]), + self.param_dtype, + ) + + sigma_var_name = ( + self.collection_name + + '/' + + '/'.join(str(k) for k in path) + + '/sigma' + ) + + try: + sigma = state[sigma_var_name].value + except KeyError: + sigma = jnp.ones((), self.param_dtype) + + for _ in range(self.n_steps): + v = _l2_normalize( + jnp.matmul(u, value.transpose([1, 0])), eps=self.epsilon + ) + u = _l2_normalize(jnp.matmul(v, value), eps=self.epsilon) + + u = lax.stop_gradient(u) + v = lax.stop_gradient(v) + + sigma = jnp.matmul(jnp.matmul(v, value), jnp.transpose(u))[0, 0] + + value /= jnp.where(sigma != 0, sigma, 1) + value_bar = value.reshape(value_shape) + + if update_stats: + state[u_var_name] = nnx.Param(u) + state[sigma_var_name] = nnx.Param(sigma) + + dtype = dtypes.canonicalize_dtype(vs.value, u, v, sigma, dtype=self.dtype) + return vs.replace(jnp.asarray(value_bar, dtype)) + + state = nnx.map_state(spectral_normalize, state) + nnx.update(self.layer_instance, state) + + return self.layer_instance(x, *args, **kwargs) # type: ignore \ No newline at end of file diff --git a/tests/nnx/nn/normalization_test.py b/tests/nnx/nn/normalization_test.py index d6a399196..7aa88953e 100644 --- a/tests/nnx/nn/normalization_test.py +++ b/tests/nnx/nn/normalization_test.py @@ -323,6 +323,145 @@ def __call__(self, x, *, mask=None): assert isinstance(linen_out, jax.Array) np.testing.assert_array_equal(linen_out, nnx_out) + @parameterized.product( + dtype=[jnp.float32, jnp.float16], + param_dtype=[jnp.float32, jnp.float16], + use_fast_variance=[True, False], + mask=[None, np.array([True, False, True, False, True, False])], + ) + def test_nnx_linen_instancenorm_equivalence( + self, + dtype: tp.Optional[Dtype], + param_dtype: Dtype, + use_fast_variance: bool, + mask: tp.Optional[np.ndarray], + ): + class NNXModel(nnx.Module): + def __init__(self, dtype, param_dtype, use_fast_variance, rngs): + self.norm_layer = nnx.InstanceNorm( + 6, + dtype=dtype, + param_dtype=param_dtype, + use_fast_variance=use_fast_variance, + rngs=rngs, + ) + self.linear = nnx.Linear( + 6, 4, dtype=dtype, param_dtype=param_dtype, rngs=rngs + ) + + def __call__(self, x, *, mask=None): + x = self.norm_layer(x, mask=mask) + x = self.linear(x) + return x + + class LinenModel(linen.Module): + dtype: tp.Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + use_fast_variance: bool = True + + def setup(self): + self.norm_layer = linen.InstanceNorm( + dtype=self.dtype, + param_dtype=self.param_dtype, + use_fast_variance=self.use_fast_variance, + ) + self.linear = linen.Dense( + 4, dtype=self.dtype, param_dtype=self.param_dtype + ) + + def __call__(self, x, *, mask=None): + x = self.norm_layer(x, mask=mask) + x = self.linear(x) + return x + + rngs = nnx.Rngs(42) + x = jax.random.normal(jax.random.key(0), (10, 6)) + + linen_model = LinenModel( + dtype=dtype, param_dtype=param_dtype, use_fast_variance=use_fast_variance + ) + variables = linen_model.init(jax.random.key(1), x) + linen_out = linen_model.apply(variables, x, mask=mask) + + nnx_model = NNXModel( + dtype=dtype, + param_dtype=param_dtype, + use_fast_variance=use_fast_variance, + rngs=rngs, + ) 
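+    # Copy the linen-initialized kernel and bias into the nnx model so both
+    # models run with identical parameters for the equivalence check.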
+ nnx_model.linear.kernel.value = variables['params']['linear']['kernel'] + nnx_model.linear.bias.value = variables['params']['linear']['bias'] + + nnx_out = nnx_model(x, mask=mask) + assert isinstance(linen_out, jax.Array) + np.testing.assert_array_equal(linen_out, nnx_out) + + @parameterized.product( + dtype=[jnp.float32, jnp.float16], + param_dtype=[jnp.float32, jnp.float16], + n_steps=[1, 10], + update_stats=[True, False], + ) + def test_nnx_linen_spectralnorm_equivalence( + self, + dtype: tp.Optional[Dtype], + param_dtype: Dtype, + n_steps: int, + update_stats: bool, + ): + class NNXModel(nnx.Module): + def __init__(self, dtype, param_dtype, rngs): + self.linear = nnx.Linear(5, 4, dtype=dtype, + param_dtype=param_dtype, rngs=rngs) + self.norm_layer = nnx.SpectralNorm( + self.linear, + n_steps=n_steps, + dtype=dtype, + param_dtype=param_dtype, + rngs=rngs, + ) + + def __call__(self, x): + return self.norm_layer(x, update_stats=update_stats) + + class LinenModel(linen.Module): + dtype: tp.Optional[Dtype] = None + param_dtype: Dtype = jnp.float32 + + def setup(self): + self.dense = linen.Dense( + 4, dtype=self.dtype, param_dtype=self.param_dtype + ) + self.norm_layer = linen.SpectralNorm(self.dense, n_steps=n_steps) + + def __call__(self, x): + return self.norm_layer(x, update_stats=update_stats) + + rngs = nnx.Rngs(42) + x = jax.random.normal(jax.random.key(0), (10, 5)) + + linen_model = LinenModel(dtype=dtype, param_dtype=param_dtype) + variables = linen_model.init(jax.random.key(1), x) + + nnx_model = NNXModel( + dtype=dtype, param_dtype=param_dtype, rngs=rngs + ) + nnx_model.linear.kernel.value = variables['params']['dense']['kernel'] + nnx_model.linear.bias.value = variables['params']['dense']['bias'] + + linear_state = nnx.state(nnx_model.linear) + linear_state['batch_stats/kernel/u'] = nnx.Param( + variables['batch_stats']['norm_layer']['dense/kernel/u'] + ) + linear_state['batch_stats/kernel/sigma'] = nnx.Param( + variables['batch_stats']['norm_layer']['dense/kernel/sigma'] + ) + nnx.update(nnx_model.linear, linear_state) + + linen_out = linen_model.apply(variables, x, mutable=['batch_stats']) + nnx_out = nnx_model(x) + np.testing.assert_array_equal(linen_out[0], nnx_out) + if __name__ == '__main__': absltest.main()
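For reference, a minimal usage sketch of the two layers added by this patch, assuming the patch is applied; the input shapes and layer sizes below are illustrative only:

    from flax import nnx
    import jax

    rngs = nnx.Rngs(0)

    # (batch, height, width, channels) input: InstanceNorm reduces the spatial
    # axes separately per example and per channel, with learnable per-channel
    # scale and bias.
    x = jax.random.normal(jax.random.key(1), (2, 8, 8, 3))
    h = nnx.InstanceNorm(3, rngs=rngs)(x)

    # SpectralNorm wraps a layer and rescales its kernel by its largest
    # singular value (estimated via power iteration) before each call.
    sn_linear = nnx.SpectralNorm(nnx.Linear(3, 16, rngs=rngs), rngs=rngs)
    y_train = sn_linear(h, update_stats=True)   # training: refine u / sigma
    y_eval = sn_linear(h, update_stats=False)   # eval: deterministic output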