Commit d40a1d9

implement spectral and instance norm
1 parent a9ef59b commit d40a1d9

3 files changed: +498, -1 lines changed


flax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -104,6 +104,8 @@
 from .nn.normalization import LayerNorm as LayerNorm
 from .nn.normalization import RMSNorm as RMSNorm
 from .nn.normalization import GroupNorm as GroupNorm
+from .nn.normalization import InstanceNorm as InstanceNorm
+from .nn.normalization import SpectralNorm as SpectralNorm
 from .nn.stochastic import Dropout as Dropout
 from .rnglib import Rngs as Rngs
 from .rnglib import RngStream as RngStream
flax/nnx/nn/normalization.py

Lines changed: 357 additions & 1 deletion
@@ -181,6 +181,24 @@ def _normalize(
   return jnp.asarray(y, dtype)
 
 
+def _l2_normalize(x, axis=None, eps=1e-12):
+  """Normalizes along dimension `axis` using an L2 norm.
+
+  This specialized function exists for numerical stability reasons.
+
+  Args:
+    x: An input ndarray.
+    axis: Dimension along which to normalize, e.g. `1` to separately normalize
+      vectors in a batch. Passing `None` views `t` as a flattened vector when
+      calculating the norm (equivalent to Frobenius norm).
+    eps: Epsilon to avoid dividing by zero.
+
+  Returns:
+    An array of the same shape as 'x' L2-normalized along 'axis'.
+  """
+  return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)
+
+
 class BatchNorm(Module):
   """BatchNorm Module.
 
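A note on the `_l2_normalize` helper added in this hunk: it rescales `x` by the reciprocal of its L2 norm over `axis`, and `eps` keeps the `rsqrt` finite when the norm is near zero. The standalone check below is not part of the commit; it copies the one-line definition so it runs without importing the private `flax.nnx.nn.normalization` module:

import jax
import jax.numpy as jnp
import numpy as np

def _l2_normalize(x, axis=None, eps=1e-12):
  # Same expression as the helper above: x * rsqrt(sum(x^2, axis) + eps).
  return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

x = jax.random.normal(jax.random.key(0), (4, 3))
y = _l2_normalize(x, axis=1)

# Every row of y should now have (approximately) unit L2 norm.
np.testing.assert_allclose(jnp.linalg.norm(y, axis=1), 1.0, rtol=1e-6)
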
@@ -835,4 +853,342 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
       (self.feature_axis,),
       self.dtype,
       self.epsilon,
-    )
+    )
+
+
+class InstanceNorm(Module):
+  """Instance normalization (https://arxiv.org/abs/1607.08022v3).
+
+  InstanceNorm normalizes the activations of the layer for each channel (rather
+  than across all channels like Layer Normalization), and for each given example
+  in a batch independently (rather than across an entire batch like Batch
+  Normalization). i.e. applies a transformation that maintains the mean activation
+  within each channel within each example close to 0 and the activation standard
+  deviation close to 1.
+  .. note::
+    This normalization operation is identical to LayerNorm and GroupNorm; the
+    difference is simply which axes are reduced and the shape of the feature axes
+    (i.e. the shape of the learnable scale and bias parameters).
+
+  Example usage::
+
+    >>> from flax import nnx
+    >>> import jax
+    >>> import numpy as np
+    >>> # dimensions: (batch, height, width, channel)
+    >>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
+    >>> layer = nnx.InstanceNorm(5, rngs=nnx.Rngs(0))
+    >>> nnx.state(layer, nnx.OfType(nnx.Param))
+    State({
+      'bias': VariableState( # 5 (20 B)
+        type=Param,
+        value=Array([0., 0., 0., 0., 0.], dtype=float32)
+      ),
+      'scale': VariableState( # 5 (20 B)
+        type=Param,
+        value=Array([1., 1., 1., 1., 1.], dtype=float32)
+      )
+    })
+    >>> y = layer(x)
+    >>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
+    >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
+    >>> y2 = nnx.LayerNorm(5, reduction_axes=[1, 2], feature_axes=-1, rngs=nnx.Rngs(0))(x)
+    >>> np.testing.assert_allclose(y, y2, atol=1e-7)
+    >>> y3 = nnx.GroupNorm(5, num_groups=x.shape[-1], rngs=nnx.Rngs(0))(x)
+    >>> np.testing.assert_allclose(y, y3, atol=1e-7)
+
+  Args:
+    num_features: the number of input features/channels.
+    epsilon: A small float added to variance to avoid dividing by zero.
+    dtype: the dtype of the result (default: infer from input and params).
+    param_dtype: the dtype passed to parameter initializers (default: float32).
+    use_bias: If True, bias (beta) is added.
+    use_scale: If True, multiply by scale (gamma). When the next layer is linear
+      (also e.g. nn.relu), this can be disabled since the scaling will be done
+      by the next layer.
+    bias_init: Initializer for bias, by default, zero.
+    scale_init: Initializer for scale, by default, one.
+    feature_axes: Axes for features. The learned bias and scaling parameters will
+      be in the shape defined by the feature axes. All other axes except the batch
+      axes (which is assumed to be the leading axis) will be reduced.
+    axis_name: the axis name used to combine batch statistics from multiple
+      devices. See ``jax.pmap`` for a description of axis names (default: None).
+      This is only needed if the model is subdivided across devices, i.e. the
+      array being normalized is sharded across devices within a pmap or shard
+      map. For SPMD jit, you do not need to manually synchronize. Just make sure
+      that the axes are correctly annotated and XLA:SPMD will insert the
+      necessary collectives.
+    axis_index_groups: groups of axis indices within that named axis
+      representing subsets of devices to reduce over (default: None). For
+      example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
+      examples on the first two and last two devices. See ``jax.lax.psum`` for
+      more details.
+    use_fast_variance: If true, use a faster, but less numerically stable,
+      calculation for the variance.
+    rngs: The rng key.
+  """
+
+  def __init__(
+    self,
+    num_features: int,
+    *,
+    epsilon: float = 1e-6,
+    dtype: tp.Optional[Dtype] = None,
+    param_dtype: Dtype = jnp.float32,
+    use_bias: bool = True,
+    use_scale: bool = True,
+    bias_init: Initializer = initializers.zeros,
+    scale_init: Initializer = initializers.ones,
+    feature_axes: Axes = -1,
+    axis_name: tp.Optional[str] = None,
+    axis_index_groups: tp.Any = None,
+    use_fast_variance: bool = True,
+    rngs: rnglib.Rngs,
+  ):
+    feature_shape = (num_features,)
+    self.scale: nnx.Param[jax.Array] | None
+    if use_scale:
+      key = rngs.params()
+      self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
+    else:
+      self.scale = None
+
+    self.bias: nnx.Param[jax.Array] | None
+    if use_bias:
+      key = rngs.params()
+      self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
+    else:
+      self.bias = None
+
+    self.num_features = num_features
+    self.epsilon = epsilon
+    self.dtype = dtype
+    self.param_dtype = param_dtype
+    self.use_bias = use_bias
+    self.use_scale = use_scale
+    self.bias_init = bias_init
+    self.scale_init = scale_init
+    self.feature_axes = feature_axes
+    self.axis_name = axis_name
+    self.axis_index_groups = axis_index_groups
+    self.use_fast_variance = use_fast_variance
+    self.rngs = rngs
+
+  def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
+    """Applies instance normalization on the input.
+
+    Args:
+      x: the inputs
+      mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
+        the positions for which the mean and variance should be computed.
+
+    Returns:
+      Normalized inputs (the same shape as inputs).
+    """
+    feature_axes = _canonicalize_axes(x.ndim, self.feature_axes)
+    if 0 in feature_axes:
+      raise ValueError('The channel axes cannot include the leading dimension '
+                       'as this is assumed to be the batch axis.')
+    reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes]
+
+    mean, var = _compute_stats(
+      x,
+      reduction_axes,
+      self.dtype,
+      self.axis_name,
+      self.axis_index_groups,
+      use_fast_variance=self.use_fast_variance,
+      mask=mask,
+    )
+
+    return _normalize(
+      x,
+      mean,
+      var,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
+      reduction_axes,
+      feature_axes,
+      self.dtype,
+      self.epsilon,
+    )
+
+
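Before the diff continues with SpectralNorm below, a quick illustration of the `reduction_axes` computed in `InstanceNorm.__call__`: with the default `feature_axes=-1` and an NHWC input, each (example, channel) slice is standardized over the spatial axes only. This sketch is not part of the commit; it assumes a Flax build containing this change, the default scale/bias initializers (ones and zeros), and the default `epsilon=1e-6`:

from flax import nnx
import jax
import numpy as np

x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))  # (batch, H, W, C)
y = nnx.InstanceNorm(5, rngs=nnx.Rngs(0))(x)

# Reduce over the non-batch, non-feature axes (H and W here), separately for
# each example and each channel, then standardize.
mean = x.mean(axis=(1, 2), keepdims=True)
var = x.var(axis=(1, 2), keepdims=True)
y_manual = (x - mean) * jax.lax.rsqrt(var + 1e-6)

np.testing.assert_allclose(y, y_manual, atol=1e-5)
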
+class SpectralNorm(Module):
+  """Spectral normalization.
+  See:
+  - https://arxiv.org/abs/1802.05957
+  - https://arxiv.org/abs/1805.08318
+  - https://arxiv.org/abs/1809.11096
+
+  Spectral normalization normalizes the weight params so that the spectral
+  norm of the matrix is equal to 1. This is implemented as a layer wrapper
+  where each wrapped layer will have its params spectral normalized before
+  computing its ``__call__`` output.
+  .. note::
+    The initialized variables dict will contain, in addition to a 'params'
+    collection, a separate 'batch_stats' collection that will contain a
+    ``u`` vector and ``sigma`` value, which are intermediate values used
+    when performing spectral normalization. During training, we pass in
+    ``update_stats=True`` so that ``u`` and ``sigma`` are updated with
+    the most recently computed values using power iteration. This will
+    help the power iteration method approximate the true singular value
+    more accurately over time. During eval, we pass in ``update_stats=False``
+    to ensure we get deterministic behavior from the model.
+
+  Example usage::
+
+    >>> from flax import nnx
+    >>> import jax
+    >>> rngs = nnx.Rngs(0)
+    >>> x = jax.random.normal(jax.random.key(0), (3, 4))
+    >>> layer = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs),
+    ...                          rngs=rngs)
+    >>> nnx.state(layer, nnx.OfType(nnx.Param))
+    State({
+      'layer_instance': {
+        'bias': VariableState( # 5 (20 B)
+          type=Param,
+          value=Array([0., 0., 0., 0., 0.], dtype=float32)
+        ),
+        'kernel': VariableState( # 20 (80 B)
+          type=Param,
+          value=Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626, -0.46665004],
+                 [ 0.31773907,  0.38944173, -0.54608804,  0.84378934, -0.93099   ],
+                 [-0.67658   ,  0.0724705 , -0.6101737 ,  0.12972134,  0.877074  ],
+                 [ 0.27292168,  0.32105306, -0.2556603 ,  0.4896752 ,  0.19558711]], dtype=float32)
+        )
+      }
+    })
+    >>> y = layer(x, update_stats=True)
+
+  Args:
+    layer_instance: Module instance that is wrapped with SpectralNorm
+    n_steps: How many steps of power iteration to perform to approximate the
+      singular value of the weight params.
+    epsilon: A small float added to l2-normalization to avoid dividing by zero.
+    dtype: the dtype of the result (default: infer from input and params).
+    param_dtype: the dtype passed to parameter initializers (default: float32).
+    error_on_non_matrix: Spectral normalization is only defined on matrices. By
+      default, this module will return scalars unchanged and flatten
+      higher-order tensors in their leading dimensions. Setting this flag to
+      True will instead throw an error if a weight tensor with dimension greater
+      than 2 is used by the layer.
+    collection_name: Name of the collection to store intermediate values used
+      when performing spectral normalization.
+    rngs: The rng key.
+  """
+
+  def __init__(
+    self,
+    layer_instance: Module,
+    *,
+    n_steps: int = 1,
+    epsilon: float = 1e-12,
+    dtype: tp.Optional[Dtype] = None,
+    param_dtype: Dtype = jnp.float32,
+    error_on_non_matrix: bool = False,
+    collection_name: str = 'batch_stats',
+    rngs: rnglib.Rngs,
+  ):
+    self.layer_instance = layer_instance
+    self.n_steps = n_steps
+    self.epsilon = epsilon
+    self.dtype = dtype
+    self.param_dtype = param_dtype
+    self.error_on_non_matrix = error_on_non_matrix
+    self.collection_name = collection_name
+    self.rngs = rngs
+
+  def __call__(self, x, *args, update_stats: bool, **kwargs):
+    """Compute the largest singular value of the weights in ``self.layer_instance``
+    using power iteration and normalize the weights using this value before
+    computing the ``__call__`` output.
+
+    Args:
+      x: the input array of the nested layer
+      *args: positional arguments to be passed into the call method of the
+        underlying layer instance in ``self.layer_instance``.
+      update_stats: if True, update the internal ``u`` vector and ``sigma``
+        value after computing their updated values using power iteration. This
+        will help the power iteration method approximate the true singular value
+        more accurately over time.
+      **kwargs: keyword arguments to be passed into the call method of the
+        underlying layer instance in ``self.layer_instance``.
+
+    Returns:
+      Output of the layer using spectral normalized weights.
+    """
+
+    state = nnx.state(self.layer_instance)
+
+    def spectral_normalize(path, vs):
+      value = jnp.asarray(vs.value)
+      value_shape = value.shape
+
+      # Skip and return value if input is scalar, vector or if number of power
+      # iterations is less than 1
+      if value.ndim <= 1 or self.n_steps < 1:
+        return value
+      # Handle higher-order tensors.
+      elif value.ndim > 2:
+        if self.error_on_non_matrix:
+          raise ValueError(
+            f'Input is {value.ndim}D but error_on_non_matrix is set to True'
+          )
+        else:
+          value = jnp.reshape(value, (-1, value.shape[-1]))
+
+      u_var_name = (
+        self.collection_name
+        + '/'
+        + '/'.join(str(k) for k in path)
+        + '/u'
+      )
+
+      try:
+        u = state[u_var_name].value
+      except KeyError:
+        u = jax.random.normal(
+          self.rngs.params(),
+          (1, value.shape[-1]),
+          self.param_dtype,
+        )
+
+      sigma_var_name = (
+        self.collection_name
+        + '/'
+        + '/'.join(str(k) for k in path)
+        + '/sigma'
+      )
+
+      try:
+        sigma = state[sigma_var_name].value
+      except KeyError:
+        sigma = jnp.ones((), self.param_dtype)
+
+      for _ in range(self.n_steps):
+        v = _l2_normalize(
+          jnp.matmul(u, value.transpose([1, 0])), eps=self.epsilon
+        )
+        u = _l2_normalize(jnp.matmul(v, value), eps=self.epsilon)
+
+      u = lax.stop_gradient(u)
+      v = lax.stop_gradient(v)
+
+      sigma = jnp.matmul(jnp.matmul(v, value), jnp.transpose(u))[0, 0]
+
+      value /= jnp.where(sigma != 0, sigma, 1)
+      value_bar = value.reshape(value_shape)
+
+      if update_stats:
+        state[u_var_name] = nnx.Param(u)
+        state[sigma_var_name] = nnx.Param(sigma)
+
+      dtype = dtypes.canonicalize_dtype(vs.value, u, v, sigma, dtype=self.dtype)
+      return nnx.Param(jnp.asarray(value_bar, dtype))
+
+    state = nnx.map_state(spectral_normalize, state)
+    nnx.update(self.layer_instance, state)
+
+    return self.layer_instance(x, *args, **kwargs)  # type: ignore
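The heart of `spectral_normalize` above is the power-iteration loop that estimates the largest singular value `sigma` of the (possibly flattened) kernel before dividing the kernel by it. The sketch below is not part of the commit; it replays the same update rule in isolation with more steps than the layer's default `n_steps=1` (the layer instead refines the stored `u` across calls through its 'batch_stats' state) and compares the estimate against a direct SVD:

import jax
import jax.numpy as jnp
import numpy as np

def _l2_normalize(x, axis=None, eps=1e-12):
  return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

kernel = jax.random.normal(jax.random.key(1), (4, 5))
u = jax.random.normal(jax.random.key(2), (1, kernel.shape[-1]))

# Same update as the for-loop in spectral_normalize, run long enough that a
# single cold start converges.
for _ in range(50):
  v = _l2_normalize(jnp.matmul(u, kernel.transpose([1, 0])))
  u = _l2_normalize(jnp.matmul(v, kernel))
sigma = jnp.matmul(jnp.matmul(v, kernel), jnp.transpose(u))[0, 0]

# The estimate should match the true largest singular value, and dividing the
# kernel by it should give a matrix with spectral norm close to 1.
sigma_true = jnp.linalg.svd(kernel, compute_uv=False)[0]
np.testing.assert_allclose(sigma, sigma_true, rtol=1e-4)
np.testing.assert_allclose(
  jnp.linalg.svd(kernel / sigma, compute_uv=False)[0], 1.0, rtol=1e-4
)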
