Implement spectral and instance norm in NNX #4623

Open · wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions flax/nnx/__init__.py
@@ -104,6 +104,8 @@
from .nn.normalization import LayerNorm as LayerNorm
from .nn.normalization import RMSNorm as RMSNorm
from .nn.normalization import GroupNorm as GroupNorm
from .nn.normalization import InstanceNorm as InstanceNorm
from .nn.normalization import SpectralNorm as SpectralNorm
from .nn.stochastic import Dropout as Dropout
from .rnglib import Rngs as Rngs
from .rnglib import RngStream as RngStream
358 changes: 357 additions & 1 deletion flax/nnx/nn/normalization.py
@@ -181,6 +181,24 @@ def _normalize(
return jnp.asarray(y, dtype)


def _l2_normalize(x, axis=None, eps=1e-12):
"""Normalizes along dimension `axis` using an L2 norm.

This specialized function exists for numerical stability reasons.

Args:
x: An input ndarray.
axis: Dimension along which to normalize, e.g. `1` to separately normalize
vectors in a batch. Passing `None` views `x` as a flattened vector when
calculating the norm (equivalent to Frobenius norm).
eps: Epsilon to avoid dividing by zero.

Returns:
An array of the same shape as `x` L2-normalized along `axis`.
"""
return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)
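# Illustrative sketch (comment only, assuming the helper above): row-wise
# L2 normalization, where `eps` keeps an all-zero row at zero instead of
# producing NaNs:
#
#   x = jnp.array([[3.0, 4.0], [0.0, 0.0]])
#   _l2_normalize(x, axis=-1)  # ~[[0.6, 0.8], [0.0, 0.0]]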


class BatchNorm(Module):
"""BatchNorm Module.

@@ -835,4 +853,342 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
(self.feature_axis,),
self.dtype,
self.epsilon,
)
)


class InstanceNorm(Module):
"""Instance normalization (https://arxiv.org/abs/1607.08022v3).

InstanceNorm normalizes the activations of the layer for each channel (rather
than across all channels like Layer Normalization), and for each given example
in a batch independently (rather than across an entire batch like Batch
Normalization). i.e. applies a transformation that maintains the mean activation
within each channel within each example close to 0 and the activation standard
deviation close to 1.

.. note::
This normalization operation is identical to LayerNorm and GroupNorm; the
difference is simply which axes are reduced and the shape of the feature axes
(i.e. the shape of the learnable scale and bias parameters).

Example usage::

>>> from flax import nnx
>>> import jax
>>> import numpy as np
>>> # dimensions: (batch, height, width, channel)
>>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
>>> layer = nnx.InstanceNorm(5, rngs=nnx.Rngs(0))
>>> nnx.state(layer, nnx.OfType(nnx.Param))
State({
'bias': VariableState( # 5 (20 B)
type=Param,
value=Array([0., 0., 0., 0., 0.], dtype=float32)
),
'scale': VariableState( # 5 (20 B)
type=Param,
value=Array([1., 1., 1., 1., 1.], dtype=float32)
)
})
>>> y = layer(x)
>>> # with the default feature_axes of -1, InstanceNorm is identical to a LayerNorm
>>> # that reduces over all non-batch, non-channel axes and uses the channel axis as its feature_axes
>>> y2 = nnx.LayerNorm(5, reduction_axes=[1, 2], feature_axes=-1, rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y2, atol=1e-7)
>>> y3 = nnx.GroupNorm(5, num_groups=x.shape[-1], rngs=nnx.Rngs(0))(x)
>>> np.testing.assert_allclose(y, y3, atol=1e-7)

Args:
num_features: the number of input features/channels.
epsilon: A small float added to variance to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
use_bias: If True, bias (beta) is added.
use_scale: If True, multiply by scale (gamma). When the next layer is linear
(also e.g. nn.relu), this can be disabled since the scaling will be done
by the next layer.
bias_init: Initializer for bias, by default, zero.
scale_init: Initializer for scale, by default, one.
feature_axes: Axes for features. The learned bias and scaling parameters will
be in the shape defined by the feature axes. All other axes except the batch
axis (which is assumed to be the leading axis) will be reduced.
axis_name: the axis name used to combine batch statistics from multiple
devices. See ``jax.pmap`` for a description of axis names (default: None).
This is only needed if the model is subdivided across devices, i.e. the
array being normalized is sharded across devices within a pmap or shard
map. For SPMD jit, you do not need to manually synchronize. Just make sure
that the axes are correctly annotated and XLA:SPMD will insert the
necessary collectives.
axis_index_groups: groups of axis indices within that named axis
representing subsets of devices to reduce over (default: None). For
example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
examples on the first two and last two devices. See ``jax.lax.psum`` for
more details.
use_fast_variance: If true, use a faster, but less numerically stable,
calculation for the variance.
rngs: The rng key.
"""

def __init__(
self,
num_features: int,
*,
epsilon: float = 1e-6,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
use_bias: bool = True,
use_scale: bool = True,
bias_init: Initializer = initializers.zeros,
scale_init: Initializer = initializers.ones,
feature_axes: Axes = -1,
axis_name: tp.Optional[str] = None,
axis_index_groups: tp.Any = None,
use_fast_variance: bool = True,
rngs: rnglib.Rngs,
):
feature_shape = (num_features,)
self.scale: nnx.Param[jax.Array] | None
if use_scale:
key = rngs.params()
self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
else:
self.scale = None

self.bias: nnx.Param[jax.Array] | None
if use_bias:
key = rngs.params()
self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
else:
self.bias = None

self.num_features = num_features
self.epsilon = epsilon
self.dtype = dtype
self.param_dtype = param_dtype
self.use_bias = use_bias
self.use_scale = use_scale
self.bias_init = bias_init
self.scale_init = scale_init
self.feature_axes = feature_axes
self.axis_name = axis_name
self.axis_index_groups = axis_index_groups
self.use_fast_variance = use_fast_variance
self.rngs = rngs

def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
"""Applies instance normalization on the input.

Args:
x: the inputs
mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
the positions for which the mean and variance should be computed.

Returns:
Normalized inputs (the same shape as inputs).
"""
feature_axes = _canonicalize_axes(x.ndim, self.feature_axes)
if 0 in feature_axes:
raise ValueError('The channel axes cannot include the leading dimension '
'as this is assumed to be the batch axis.')
reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes]

mean, var = _compute_stats(
x,
reduction_axes,
self.dtype,
self.axis_name,
self.axis_index_groups,
use_fast_variance=self.use_fast_variance,
mask=mask,
)

return _normalize(
x,
mean,
var,
self.scale.value if self.scale else None,
self.bias.value if self.bias else None,
reduction_axes,
feature_axes,
self.dtype,
self.epsilon,
)
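# Illustrative note (comment only): with the default feature_axes=-1 and an
# input of shape (batch, height, width, channels), the call above
# canonicalizes feature_axes to (3,) and reduces over axes [1, 2], so
# statistics are computed per example and per channel:
#
#   x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
#   y = nnx.InstanceNorm(5, rngs=nnx.Rngs(0))(x)
#   # y[n, :, :, c] has ~zero mean and ~unit std for every (n, c) pair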


class SpectralNorm(Module):
"""Spectral normalization.
See:
- https://arxiv.org/abs/1802.05957
- https://arxiv.org/abs/1805.08318
- https://arxiv.org/abs/1809.11096

Spectral normalization normalizes the weight params so that the spectral
norm of the matrix is equal to 1. This is implemented as a layer wrapper
where each wrapped layer will have its params spectral normalized before
computing its ``__call__`` output.

.. note::
The initialized variables dict will contain, in addition to a 'params'
collection, a separate 'batch_stats' collection that will contain a
``u`` vector and ``sigma`` value, which are intermediate values used
when performing spectral normalization. During training, we pass in
``update_stats=True`` so that ``u`` and ``sigma`` are updated with
the most recently computed values using power iteration. This will
help the power iteration method approximate the true singular value
more accurately over time. During eval, we pass in ``update_stats=False``
to ensure we get deterministic behavior from the model.

Example usage::

>>> from flax import nnx
>>> import jax
>>> rngs = nnx.Rngs(0)
>>> x = jax.random.normal(jax.random.key(0), (3, 4))
>>> layer = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs),
... rngs=rngs)
>>> nnx.state(layer, nnx.OfType(nnx.Param))
State({
'layer_instance': {
'bias': VariableState( # 5 (20 B)
type=Param,
value=Array([0., 0., 0., 0., 0.], dtype=float32)
),
'kernel': VariableState( # 20 (80 B)
type=Param,
value=Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626, -0.46665004],
[ 0.31773907, 0.38944173, -0.54608804, 0.84378934, -0.93099 ],
[-0.67658 , 0.0724705 , -0.6101737 , 0.12972134, 0.877074 ],
[ 0.27292168, 0.32105306, -0.2556603 , 0.4896752 , 0.19558711]], dtype=float32)
)
}
})
>>> y = layer(x, update_stats=True)
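>>> # For deterministic evaluation (see the note above), pass
>>> # ``update_stats=False`` so the stored ``u`` and ``sigma`` are left unchanged:
>>> y_eval = layer(x, update_stats=False)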

Args:
layer_instance: Module instance that is wrapped with SpectralNorm
n_steps: How many steps of power iteration to perform to approximate the
singular value of the weight params.
epsilon: A small float added to l2-normalization to avoid dividing by zero.
dtype: the dtype of the result (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
error_on_non_matrix: Spectral normalization is only defined on matrices. By
default, this module will return scalars unchanged and flatten
higher-order tensors in their leading dimensions. Setting this flag to
True will instead throw an error if a weight tensor with dimension greater
than 2 is used by the layer.
collection_name: Name of the collection to store intermediate values used
when performing spectral normalization.
rngs: The rng key.
"""

def __init__(
self,
layer_instance: Module,
*,
n_steps: int = 1,
epsilon: float = 1e-12,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
error_on_non_matrix: bool = False,
collection_name: str = 'batch_stats',
rngs: rnglib.Rngs,
):
self.layer_instance = layer_instance
self.n_steps = n_steps
self.epsilon = epsilon
self.dtype = dtype
self.param_dtype = param_dtype
self.error_on_non_matrix = error_on_non_matrix
self.collection_name = collection_name
self.rngs = rngs

def __call__(self, x, *args, update_stats: bool, **kwargs):
"""Compute the largest singular value of the weights in ``self.layer_instance``
using power iteration and normalize the weights using this value before
computing the ``__call__`` output.

Args:
x: the input array of the nested layer
*args: positional arguments to be passed into the call method of the
underlying layer instance in ``self.layer_instance``.
update_stats: if True, update the internal ``u`` vector and ``sigma``
value after computing their updated values using power iteration. This
will help the power iteration method approximate the true singular value
more accurately over time.
**kwargs: keyword arguments to be passed into the call method of the
underlying layer instance in ``self.layer_instance``.

Returns:
Output of the layer using spectral normalized weights.
"""

state = nnx.state(self.layer_instance)

def spectral_normalize(path, vs):
value = jnp.asarray(vs.value)
value_shape = value.shape

# Skip and return value if input is scalar, vector or if number of power
# iterations is less than 1
if value.ndim <= 1 or self.n_steps < 1:
return value
# Handle higher-order tensors.
elif value.ndim > 2:
if self.error_on_non_matrix:
raise ValueError(
f'Input is {value.ndim}D but error_on_non_matrix is set to True'
)
else:
value = jnp.reshape(value, (-1, value.shape[-1]))

u_var_name = (
self.collection_name
+ '/'
+ '/'.join(str(k) for k in path)
+ '/u'
)

try:
u = state[u_var_name].value
except KeyError:
u = jax.random.normal(
self.rngs.params(),
(1, value.shape[-1]),
self.param_dtype,
)

sigma_var_name = (
self.collection_name
+ '/'
+ '/'.join(str(k) for k in path)
+ '/sigma'
)

try:
sigma = state[sigma_var_name].value
except KeyError:
sigma = jnp.ones((), self.param_dtype)

for _ in range(self.n_steps):
v = _l2_normalize(
jnp.matmul(u, value.transpose([1, 0])), eps=self.epsilon
)
u = _l2_normalize(jnp.matmul(v, value), eps=self.epsilon)

u = lax.stop_gradient(u)
v = lax.stop_gradient(v)

sigma = jnp.matmul(jnp.matmul(v, value), jnp.transpose(u))[0, 0]

value /= jnp.where(sigma != 0, sigma, 1)
value_bar = value.reshape(value_shape)

if update_stats:
state[u_var_name] = nnx.Param(u)
state[sigma_var_name] = nnx.Param(sigma)

dtype = dtypes.canonicalize_dtype(vs.value, u, v, sigma, dtype=self.dtype)
return vs.replace(jnp.asarray(value_bar, dtype))

state = nnx.map_state(spectral_normalize, state)
nnx.update(self.layer_instance, state)

return self.layer_instance(x, *args, **kwargs) # type: ignore
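
For reference, a minimal standalone sketch of the power-iteration estimate that ``spectral_normalize`` performs above, assuming a plain 2D kernel; the helper and variable names here are illustrative and not part of this PR:

import jax
import jax.numpy as jnp

def _l2_normalize(x, axis=None, eps=1e-12):
  # Same helper as added in normalization.py above.
  return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

def estimate_sigma(kernel, u, n_steps=1, eps=1e-12):
  # Approximate the largest singular value of a 2D `kernel` of shape
  # (in_features, out_features). `u` is a (1, out_features) running estimate,
  # mirroring the state stored under the `batch_stats` collection above.
  for _ in range(n_steps):
    v = _l2_normalize(jnp.matmul(u, kernel.T), eps=eps)  # (1, in_features)
    u = _l2_normalize(jnp.matmul(v, kernel), eps=eps)    # (1, out_features)
  sigma = jnp.matmul(jnp.matmul(v, kernel), u.T)[0, 0]
  return sigma, u

kernel = jax.random.normal(jax.random.key(0), (4, 5))
u0 = jax.random.normal(jax.random.key(1), (1, 5))
sigma, _ = estimate_sigma(kernel, u0, n_steps=50)
# With enough steps, sigma approaches jnp.linalg.svd(kernel, compute_uv=False)[0],
# and dividing the kernel by sigma gives it unit spectral norm.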