@@ -181,6 +181,24 @@ def _normalize(
  return jnp.asarray(y, dtype)


+def _l2_normalize(x, axis=None, eps=1e-12):
+  """Normalizes along dimension `axis` using an L2 norm.
+
+  This specialized function exists for numerical stability reasons.
+
+  Args:
+    x: An input ndarray.
+    axis: Dimension along which to normalize, e.g. `1` to separately normalize
+      vectors in a batch. Passing `None` views `x` as a flattened vector when
+      calculating the norm (equivalent to Frobenius norm).
+    eps: Epsilon to avoid dividing by zero.
+
+  Returns:
+    An array of the same shape as `x`, L2-normalized along `axis`.
+  """
+  return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)
+
+
class BatchNorm(Module):
  """BatchNorm Module.

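For intuition, `_l2_normalize` is just division by the L2 norm, computed via `rsqrt`, with `eps` added so that all-zero inputs stay finite instead of producing NaNs. A minimal standalone sketch of that behavior (illustrative only, not part of the patch)::

    import jax
    import jax.numpy as jnp

    def _l2_normalize(x, axis=None, eps=1e-12):
      # Same expression as the helper above: scale by the reciprocal square
      # root of the sum of squares, with eps guarding the zero-norm case.
      return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

    x = jnp.array([[3.0, 4.0],
                   [0.0, 0.0]])
    y = _l2_normalize(x, axis=1)
    # First row becomes [0.6, 0.8] (unit L2 norm); the all-zero row stays
    # [0., 0.] rather than NaN because eps keeps the rsqrt argument positive.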
@@ -835,4 +853,342 @@ def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
      (self.feature_axis,),
      self.dtype,
      self.epsilon,
-    )
+    )
+
+
+class InstanceNorm(Module):
+  """Instance normalization (https://arxiv.org/abs/1607.08022v3).
+
+  InstanceNorm normalizes the activations of the layer for each channel (rather
+  than across all channels like Layer Normalization), and for each given example
+  in a batch independently (rather than across an entire batch like Batch
+  Normalization), i.e. it applies a transformation that maintains the mean
+  activation within each channel within each example close to 0 and the
+  activation standard deviation close to 1.
+
+  .. note::
+    This normalization operation is identical to LayerNorm and GroupNorm; the
+    difference is simply which axes are reduced and the shape of the feature axes
+    (i.e. the shape of the learnable scale and bias parameters).
+
+  Example usage::
+
+    >>> from flax import nnx
+    >>> import jax
+    >>> import numpy as np
+    >>> # dimensions: (batch, height, width, channel)
+    >>> x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))
+    >>> layer = nnx.InstanceNorm(5, rngs=nnx.Rngs(0))
+    >>> nnx.state(layer, nnx.OfType(nnx.Param))
+    State({
+      'bias': VariableState( # 5 (20 B)
+        type=Param,
+        value=Array([0., 0., 0., 0., 0.], dtype=float32)
+      ),
+      'scale': VariableState( # 5 (20 B)
+        type=Param,
+        value=Array([1., 1., 1., 1., 1.], dtype=float32)
+      )
+    })
+    >>> y = layer(x)
+    >>> # having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,
+    >>> # non-channel axes and using the feature_axes as the feature_axes in LayerNorm
+    >>> y2 = nnx.LayerNorm(5, reduction_axes=[1, 2], feature_axes=-1, rngs=nnx.Rngs(0))(x)
+    >>> np.testing.assert_allclose(y, y2, atol=1e-7)
+    >>> y3 = nnx.GroupNorm(5, num_groups=x.shape[-1], rngs=nnx.Rngs(0))(x)
+    >>> np.testing.assert_allclose(y, y3, atol=1e-7)
+
+  Args:
+    num_features: the number of input features/channels.
+    epsilon: A small float added to variance to avoid dividing by zero.
+    dtype: the dtype of the result (default: infer from input and params).
+    param_dtype: the dtype passed to parameter initializers (default: float32).
+    use_bias: If True, bias (beta) is added.
+    use_scale: If True, multiply by scale (gamma). When the next layer is linear
+      (also e.g. nn.relu), this can be disabled since the scaling will be done
+      by the next layer.
+    bias_init: Initializer for bias, by default, zero.
+    scale_init: Initializer for scale, by default, one.
+    feature_axes: Axes for features. The learned bias and scaling parameters will
+      be in the shape defined by the feature axes. All other axes except the batch
+      axis (which is assumed to be the leading axis) will be reduced.
+    axis_name: the axis name used to combine batch statistics from multiple
+      devices. See ``jax.pmap`` for a description of axis names (default: None).
+      This is only needed if the model is subdivided across devices, i.e. the
+      array being normalized is sharded across devices within a pmap or shard
+      map. For SPMD jit, you do not need to manually synchronize. Just make sure
+      that the axes are correctly annotated and XLA:SPMD will insert the
+      necessary collectives.
+    axis_index_groups: groups of axis indices within that named axis
+      representing subsets of devices to reduce over (default: None). For
+      example, ``[[0, 1], [2, 3]]`` would independently batch-normalize over the
+      examples on the first two and last two devices. See ``jax.lax.psum`` for
+      more details.
+    use_fast_variance: If true, use a faster, but less numerically stable,
+      calculation for the variance.
+    rngs: The rng key.
+  """
+
+  def __init__(
+    self,
+    num_features: int,
+    *,
+    epsilon: float = 1e-6,
+    dtype: tp.Optional[Dtype] = None,
+    param_dtype: Dtype = jnp.float32,
+    use_bias: bool = True,
+    use_scale: bool = True,
+    bias_init: Initializer = initializers.zeros,
+    scale_init: Initializer = initializers.ones,
+    feature_axes: Axes = -1,
+    axis_name: tp.Optional[str] = None,
+    axis_index_groups: tp.Any = None,
+    use_fast_variance: bool = True,
+    rngs: rnglib.Rngs,
+  ):
+    feature_shape = (num_features,)
+    self.scale: nnx.Param[jax.Array] | None
+    if use_scale:
+      key = rngs.params()
+      self.scale = nnx.Param(scale_init(key, feature_shape, param_dtype))
+    else:
+      self.scale = None
+
+    self.bias: nnx.Param[jax.Array] | None
+    if use_bias:
+      key = rngs.params()
+      self.bias = nnx.Param(bias_init(key, feature_shape, param_dtype))
+    else:
+      self.bias = None
+
+    self.num_features = num_features
+    self.epsilon = epsilon
+    self.dtype = dtype
+    self.param_dtype = param_dtype
+    self.use_bias = use_bias
+    self.use_scale = use_scale
+    self.bias_init = bias_init
+    self.scale_init = scale_init
+    self.feature_axes = feature_axes
+    self.axis_name = axis_name
+    self.axis_index_groups = axis_index_groups
+    self.use_fast_variance = use_fast_variance
+    self.rngs = rngs
+
+  def __call__(self, x, *, mask: tp.Optional[jax.Array] = None):
+    """Applies instance normalization on the input.
+
+    Args:
+      x: the inputs
+      mask: Binary array of shape broadcastable to ``inputs`` tensor, indicating
+        the positions for which the mean and variance should be computed.
+
+    Returns:
+      Normalized inputs (the same shape as inputs).
+    """
+    feature_axes = _canonicalize_axes(x.ndim, self.feature_axes)
+    if 0 in feature_axes:
+      raise ValueError('The channel axes cannot include the leading dimension '
+                       'as this is assumed to be the batch axis.')
+    reduction_axes = [i for i in range(1, x.ndim) if i not in feature_axes]
+
+    mean, var = _compute_stats(
+      x,
+      reduction_axes,
+      self.dtype,
+      self.axis_name,
+      self.axis_index_groups,
+      use_fast_variance=self.use_fast_variance,
+      mask=mask,
+    )
+
+    return _normalize(
+      x,
+      mean,
+      var,
+      self.scale.value if self.scale else None,
+      self.bias.value if self.bias else None,
+      reduction_axes,
+      feature_axes,
+      self.dtype,
+      self.epsilon,
+    )
+
+
+class SpectralNorm(Module):
+  """Spectral normalization.
+
+  See:
+
+  - https://arxiv.org/abs/1802.05957
+  - https://arxiv.org/abs/1805.08318
+  - https://arxiv.org/abs/1809.11096
+
+  Spectral normalization normalizes the weight params so that the spectral
+  norm of the matrix is equal to 1. This is implemented as a layer wrapper
+  where each wrapped layer will have its params spectral normalized before
+  computing its ``__call__`` output.
+
+  .. note::
+    The initialized variables dict will contain, in addition to a 'params'
+    collection, a separate 'batch_stats' collection that will contain a
+    ``u`` vector and ``sigma`` value, which are intermediate values used
+    when performing spectral normalization. During training, we pass in
+    ``update_stats=True`` so that ``u`` and ``sigma`` are updated with
+    the most recently computed values using power iteration. This will
+    help the power iteration method approximate the true singular value
+    more accurately over time. During eval, we pass in ``update_stats=False``
+    to ensure we get deterministic behavior from the model.
+
+  Example usage::
+
+    >>> from flax import nnx
+    >>> import jax
+    >>> rngs = nnx.Rngs(0)
+    >>> x = jax.random.normal(jax.random.key(0), (3, 4))
+    >>> layer = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs),
+    ...                          rngs=rngs)
+    >>> nnx.state(layer, nnx.OfType(nnx.Param))
+    State({
+      'layer_instance': {
+        'bias': VariableState( # 5 (20 B)
+          type=Param,
+          value=Array([0., 0., 0., 0., 0.], dtype=float32)
+        ),
+        'kernel': VariableState( # 20 (80 B)
+          type=Param,
+          value=Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626, -0.46665004],
+                 [ 0.31773907,  0.38944173, -0.54608804,  0.84378934, -0.93099   ],
+                 [-0.67658   ,  0.0724705 , -0.6101737 ,  0.12972134,  0.877074  ],
+                 [ 0.27292168,  0.32105306, -0.2556603 ,  0.4896752 ,  0.19558711]], dtype=float32)
+        )
+      }
+    })
+    >>> y = layer(x, update_stats=True)
+
+  Args:
+    layer_instance: Module instance that is wrapped with SpectralNorm.
+    n_steps: How many steps of power iteration to perform to approximate the
+      singular value of the weight params.
+    epsilon: A small float added to l2-normalization to avoid dividing by zero.
+    dtype: the dtype of the result (default: infer from input and params).
+    param_dtype: the dtype passed to parameter initializers (default: float32).
+    error_on_non_matrix: Spectral normalization is only defined on matrices. By
+      default, this module will return scalars unchanged and flatten
+      higher-order tensors in their leading dimensions. Setting this flag to
+      True will instead throw an error if a weight tensor with dimension greater
+      than 2 is used by the layer.
+    collection_name: Name of the collection to store intermediate values used
+      when performing spectral normalization.
+    rngs: The rng key.
+  """
+
+  def __init__(
+    self,
+    layer_instance: Module,
+    *,
+    n_steps: int = 1,
+    epsilon: float = 1e-12,
+    dtype: tp.Optional[Dtype] = None,
+    param_dtype: Dtype = jnp.float32,
+    error_on_non_matrix: bool = False,
+    collection_name: str = 'batch_stats',
+    rngs: rnglib.Rngs,
+  ):
+    self.layer_instance = layer_instance
+    self.n_steps = n_steps
+    self.epsilon = epsilon
+    self.dtype = dtype
+    self.param_dtype = param_dtype
+    self.error_on_non_matrix = error_on_non_matrix
+    self.collection_name = collection_name
+    self.rngs = rngs
+
+  def __call__(self, x, *args, update_stats: bool, **kwargs):
+    """Compute the largest singular value of the weights in ``self.layer_instance``
+    using power iteration and normalize the weights using this value before
+    computing the ``__call__`` output.
+
+    Args:
+      x: the input array of the nested layer
+      *args: positional arguments to be passed into the call method of the
+        underlying layer instance in ``self.layer_instance``.
+      update_stats: if True, update the internal ``u`` vector and ``sigma``
+        value after computing their updated values using power iteration. This
+        will help the power iteration method approximate the true singular value
+        more accurately over time.
+      **kwargs: keyword arguments to be passed into the call method of the
+        underlying layer instance in ``self.layer_instance``.
+
+    Returns:
+      Output of the layer using spectral normalized weights.
+    """
+
+    state = nnx.state(self.layer_instance)
+
+    def spectral_normalize(path, vs):
+      value = jnp.asarray(vs.value)
+      value_shape = value.shape
+
+      # Skip and return value if input is scalar, vector or if number of power
+      # iterations is less than 1
+      if value.ndim <= 1 or self.n_steps < 1:
+        return value
+      # Handle higher-order tensors.
+      elif value.ndim > 2:
+        if self.error_on_non_matrix:
+          raise ValueError(
+            f'Input is {value.ndim}D but error_on_non_matrix is set to True'
+          )
+        else:
+          value = jnp.reshape(value, (-1, value.shape[-1]))
+
+      u_var_name = (
+        self.collection_name
+        + '/'
+        + '/'.join(str(k) for k in path)
+        + '/u'
+      )
+
+      try:
+        u = state[u_var_name].value
+      except KeyError:
+        u = jax.random.normal(
+          self.rngs.params(),
+          (1, value.shape[-1]),
+          self.param_dtype,
+        )
+
+      sigma_var_name = (
+        self.collection_name
+        + '/'
+        + '/'.join(str(k) for k in path)
+        + '/sigma'
+      )
+
+      try:
+        sigma = state[sigma_var_name].value
+      except KeyError:
+        sigma = jnp.ones((), self.param_dtype)
+
+      for _ in range(self.n_steps):
+        v = _l2_normalize(
+          jnp.matmul(u, value.transpose([1, 0])), eps=self.epsilon
+        )
+        u = _l2_normalize(jnp.matmul(v, value), eps=self.epsilon)
+
+      u = lax.stop_gradient(u)
+      v = lax.stop_gradient(v)
+
+      sigma = jnp.matmul(jnp.matmul(v, value), jnp.transpose(u))[0, 0]
+
+      value /= jnp.where(sigma != 0, sigma, 1)
+      value_bar = value.reshape(value_shape)
+
+      if update_stats:
+        state[u_var_name] = nnx.Param(u)
+        state[sigma_var_name] = nnx.Param(sigma)
+
+      dtype = dtypes.canonicalize_dtype(vs.value, u, v, sigma, dtype=self.dtype)
+      return nnx.Param(jnp.asarray(value_bar, dtype))
+
+    state = nnx.map_state(spectral_normalize, state)
+    nnx.update(self.layer_instance, state)
+
+    return self.layer_instance(x, *args, **kwargs)  # type: ignore
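For intuition about which axes ``InstanceNorm`` reduces over, here is a minimal standalone sketch (illustrative only, not part of the patch) that reproduces its statistics by hand for an NHWC input; with the default scale of 1 and bias of 0, the hand-rolled version should match the module up to numerical tolerance::

    import jax
    import jax.numpy as jnp
    from flax import nnx

    x = jax.random.normal(jax.random.key(0), (2, 3, 4, 5))  # (batch, H, W, channels)

    # With feature_axes=-1, statistics are computed per example and per channel,
    # i.e. reduced over the spatial axes (1, 2) only.
    mean = x.mean(axis=(1, 2), keepdims=True)
    var = x.var(axis=(1, 2), keepdims=True)
    y_manual = (x - mean) * jax.lax.rsqrt(var + 1e-6)  # epsilon defaults to 1e-6

    y = nnx.InstanceNorm(5, rngs=nnx.Rngs(0))(x)
    print(jnp.allclose(y, y_manual, atol=1e-5))  # expected: True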
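The heart of ``SpectralNorm.__call__`` is the power-iteration estimate of the largest singular value ``sigma``. A standalone sketch of that estimate on a plain matrix, checked against ``jnp.linalg.svd`` (illustrative only; ``w``, the keys, and the iteration count are arbitrary choices, not the module's defaults)::

    import jax
    import jax.numpy as jnp

    def _l2_normalize(x, axis=None, eps=1e-12):
      return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

    w = jax.random.normal(jax.random.key(1), (8, 5))
    u = jax.random.normal(jax.random.key(2), (1, w.shape[-1]))

    # Alternate left/right multiplication by w, normalizing each time, as in
    # the spectral_normalize closure; more steps give a closer estimate.
    for _ in range(50):
      v = _l2_normalize(jnp.matmul(u, w.transpose([1, 0])))
      u = _l2_normalize(jnp.matmul(v, w))

    sigma = jnp.matmul(jnp.matmul(v, w), jnp.transpose(u))[0, 0]
    print(sigma, jnp.linalg.svd(w, compute_uv=False)[0])  # should nearly agree
    # Dividing w by sigma then yields a matrix whose spectral norm is ~1.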