Implement spectral and instance norm in NNX #4623

Open
mattbahr wants to merge 1 commit into main from implement-spectral-and-instance-norm
Conversation

@mattbahr (Contributor) commented Mar 14, 2025

What does this PR do?

Fixes #4684. Brings the instance and spectral norm implementations over to NNX from Linen.
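
Rough usage sketch of the two layers as ported (the SpectralNorm call matches the docstring example below; the InstanceNorm constructor is how I expect it to look, mirroring the other NNX norm layers, so the exact arguments may differ):

import jax
from flax import nnx

rngs = nnx.Rngs(0)

# SpectralNorm wraps an existing layer and normalizes its weights on each call.
x = jax.random.normal(jax.random.key(0), (3, 4))
sn = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs), rngs=rngs)
y = sn(x, update_stats=True)  # pass update_stats=False for deterministic eval

# InstanceNorm normalizes each example over its spatial axes, per channel.
imgs = jax.random.normal(jax.random.key(1), (2, 8, 8, 3))
inorm = nnx.InstanceNorm(num_features=3, rngs=rngs)
z = inorm(imgs)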

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other
    checks if that's the case).
  • This change is discussed in No SpectralNorm for NNX #4684
  • The documentation and docstrings adhere to the
    documentation guidelines.
  • This change includes necessary high-coverage tests.
    (No quality testing = no merge!)

@mattbahr force-pushed the implement-spectral-and-instance-norm branch 3 times, most recently from 32d8f2d to d40a1d9 (March 21, 2025 03:50)
@mattbahr (Contributor, Author) commented:

@cgarciae Whenever you get the time, would you be willing to review this PR? Thanks!

@mattbahr (Contributor, Author) commented Apr 4, 2025

@cgarciae Just another friendly ping :) Would be great to have your review!

@h-0-0 commented Apr 7, 2025

Hi, I tried using the SpectralNorm you've implemented in this PR, but when I wrap a layer with it and use that layer in a jitted function to take gradients and update it, I get an UnexpectedTracerError. When I run the function with jax.checking_leaks, I get the following:

<Param 5474125456> is referred to by <list 14978391808>[1]
<list 14978391808> is referred to by <Store 14979622864>
<Store 14979622864> is referred to by <tuple 14973474736>[1]
<tuple 14973474736> is referred to by <WrappedFun 14973206592>

I think it may have to do with spectral_normalize modifying the state variable, which is defined outside its scope.
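
To illustrate the kind of leak I mean (this is just a toy example, not the PR code): a jitted function that stashes a traced value in an object from an enclosing scope, where the value is then used after the trace finishes, hits the same error:

import jax
import jax.numpy as jnp

captured = {}  # stands in for state captured from an enclosing scope

@jax.jit
def f(x):
    y = x * 2.0
    captured['y'] = y  # the tracer for y escapes the trace here
    return y.sum()

f(jnp.ones(3))
captured['y'] + 1.0  # raises jax.errors.UnexpectedTracerError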

@mattbahr (Contributor, Author) commented Apr 7, 2025

Hi @h-0-0, thanks for checking out the PR! Could you send the code you used to get the UnexpectedTracerError? I'm still relatively new to the JAX/Flax ecosystem, but my understanding is that this may not be jit-compatible, since it uses state and isn't a pure function, as you've pointed out. We may need a maintainer to weigh in here, but if you send your code, I'd be happy to debug on my end. For transparency, my implementation was based in part on the weight norm implementation in #4568. I don't know of a way to update the parameters in nnx without using state, but if you have any suggestions, feel free to post them here.

@h-0-0 commented Apr 9, 2025

Hi @mattbahr, the code I'm using to get the error is somewhat involved, so I've tried to reproduce it with minimal code. I'm also quite new to the ecosystem; I believe there are ways to work around the issues with state, as detailed here, but I'm not exactly sure what would be best to do in this case or what exactly is causing the error.
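
For what it's worth, a plain in-place Variable update does seem fine under nnx.jit in a toy case (this is just my own sanity check based on the NNX docs, not the PR code), so state updates on their own may not be the whole story:

from flax import nnx
import jax.numpy as jnp

class Counter(nnx.Module):
    def __init__(self):
        # BatchStat used here just as a mutable, non-differentiable variable
        self.count = nnx.BatchStat(jnp.array(0))

    def __call__(self):
        # in-place update of module state; nnx.jit propagates it back out
        self.count.value = self.count.value + 1
        return self.count.value

@nnx.jit
def step(c: Counter):
    return c()

c = Counter()
step(c)
step(c)
print(c.count.value)  # 2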

Anyways here's some code I put together:

# Imports
import jax
import typing as tp
from jax import lax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import rnglib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import dtypes, initializers
from flax.typing import (
  Array,
  Dtype,
  Initializer,
  Axes,
)

# ------ Code from PR ------
def _l2_normalize(x, axis=None, eps=1e-12):
  """Normalizes along dimension `axis` using an L2 norm.
  This specialized function exists for numerical stability reasons.
  Args:
    x: An input ndarray.
    axis: Dimension along which to normalize, e.g. `1` to separately normalize
      vectors in a batch. Passing `None` views `x` as a flattened vector when
      calculating the norm (equivalent to Frobenius norm).
    eps: Epsilon to avoid dividing by zero.
  Returns:
    An array of the same shape as 'x' L2-normalized along 'axis'.
  """
  return x * jax.lax.rsqrt((x * x).sum(axis=axis, keepdims=True) + eps)

class SpectralNorm(Module):
  """Spectral normalization.
  See:
  - https://arxiv.org/abs/1802.05957
  - https://arxiv.org/abs/1805.08318
  - https://arxiv.org/abs/1809.11096
  Spectral normalization normalizes the weight params so that the spectral
  norm of the matrix is equal to 1. This is implemented as a layer wrapper
  where each wrapped layer will have its params spectral normalized before
  computing its ``__call__`` output.
  .. note::
    The initialized variables dict will contain, in addition to a 'params'
    collection, a separate 'batch_stats' collection that will contain a
    ``u`` vector and ``sigma`` value, which are intermediate values used
    when performing spectral normalization. During training, we pass in
    ``update_stats=True`` so that ``u`` and ``sigma`` are updated with
    the most recently computed values using power iteration. This will
    help the power iteration method approximate the true singular value
    more accurately over time. During eval, we pass in ``update_stats=False``
    to ensure we get deterministic behavior from the model.
  Example usage::
    >>> from flax import nnx
    >>> import jax
    >>> rngs = nnx.Rngs(0)
    >>> x = jax.random.normal(jax.random.key(0), (3, 4))
    >>> layer = nnx.SpectralNorm(nnx.Linear(4, 5, rngs=rngs),
    ...                          rngs=rngs)
    >>> nnx.state(layer, nnx.OfType(nnx.Param))
    State({
      'layer_instance': {
        'bias': VariableState( # 5 (20 B)
          type=Param,
          value=Array([0., 0., 0., 0., 0.], dtype=float32)
        ),
        'kernel': VariableState( # 20 (80 B)
          type=Param,
          value=Array([[ 0.5350889 , -0.48486355, -0.4022262 , -0.61925626, -0.46665004],
                 [ 0.31773907,  0.38944173, -0.54608804,  0.84378934, -0.93099   ],
                 [-0.67658   ,  0.0724705 , -0.6101737 ,  0.12972134,  0.877074  ],
                 [ 0.27292168,  0.32105306, -0.2556603 ,  0.4896752 ,  0.19558711]],      dtype=float32)
        )
      }
    })
    >>> y = layer(x, update_stats=True)
  Args:
    layer_instance: Module instance that is wrapped with SpectralNorm
    n_steps: How many steps of power iteration to perform to approximate the
      singular value of the weight params.
    epsilon: A small float added to l2-normalization to avoid dividing by zero.
    dtype: the dtype of the result (default: infer from input and params).
    param_dtype: the dtype passed to parameter initializers (default: float32).
    error_on_non_matrix: Spectral normalization is only defined on matrices. By
      default, this module will return scalars unchanged and flatten
      higher-order tensors in their leading dimensions. Setting this flag to
      True will instead throw an error if a weight tensor with dimension greater
      than 2 is used by the layer.
    collection_name: Name of the collection to store intermediate values used
      when performing spectral normalization.
    rngs: The rng key.
  """

  def __init__(
    self,
    layer_instance: Module,
    *,
    n_steps: int = 1,
    epsilon: float = 1e-12,
    dtype: tp.Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    error_on_non_matrix: bool = False,
    collection_name: str = 'batch_stats',
    rngs: rnglib.Rngs,
  ):
    self.layer_instance = layer_instance
    self.n_steps = n_steps
    self.epsilon = epsilon
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.error_on_non_matrix = error_on_non_matrix
    self.collection_name = collection_name
    self.rngs = rngs

  def __call__(self, x, *args, update_stats, **kwargs):
    """Compute the largest singular value of the weights in ``self.layer_instance``
    using power iteration and normalize the weights using this value before
    computing the ``__call__`` output.
    Args:
      x: the input array of the nested layer
      *args: positional arguments to be passed into the call method of the
        underlying layer instance in ``self.layer_instance``.
      update_stats: if True, update the internal ``u`` vector and ``sigma``
        value after computing their updated values using power iteration. This
        will help the power iteration method approximate the true singular value
        more accurately over time.
      **kwargs: keyword arguments to be passed into the call method of the
        underlying layer instance in ``self.layer_instance``.
    Returns:
      Output of the layer using spectral normalized weights.
    """

    state = nnx.state(self.layer_instance)

    def spectral_normalize(path, vs):
      value = jnp.asarray(vs.value)
      value_shape = value.shape

      # Skip and return value if input is scalar, vector or if number of power
      # iterations is less than 1
      if value.ndim <= 1 or self.n_steps < 1:
        return value
      # Handle higher-order tensors.
      elif value.ndim > 2:
        if self.error_on_non_matrix:
          raise ValueError(
            f'Input is {value.ndim}D but error_on_non_matrix is set to True'
          )
        else:
          value = jnp.reshape(value, (-1, value.shape[-1]))

      u_var_name = (
        self.collection_name
        + '/'
        + '/'.join(str(k) for k in path)
        + '/u'
      )

      try:
        u = state[u_var_name].value
      except KeyError:
        u = jax.random.normal(
          self.rngs.params(),
          (1, value.shape[-1]),
          self.param_dtype,
        )

      sigma_var_name = (
        self.collection_name
        + '/'
        + '/'.join(str(k) for k in path)
        + '/sigma'
      )

      try:
        sigma = state[sigma_var_name].value
      except KeyError:
        sigma = jnp.ones((), self.param_dtype)

      for _ in range(self.n_steps):
        v = _l2_normalize(
          jnp.matmul(u, value.transpose([1, 0])), eps=self.epsilon
        )
        u = _l2_normalize(jnp.matmul(v, value), eps=self.epsilon)

      u = lax.stop_gradient(u)
      v = lax.stop_gradient(v)

      sigma = jnp.matmul(jnp.matmul(v, value), jnp.transpose(u))[0, 0]

      value /= jnp.where(sigma != 0, sigma, 1)
      value_bar = value.reshape(value_shape)

      if update_stats:
        state[u_var_name] = nnx.Param(u)
        state[sigma_var_name] = nnx.Param(sigma)

      dtype = dtypes.canonicalize_dtype(vs.value, u, v, sigma, dtype=self.dtype)
      return nnx.Param(jnp.asarray(value_bar, dtype))

    state = nnx.map_state(spectral_normalize, state)
    nnx.update(self.layer_instance, state)

    return self.layer_instance(x, *args, **kwargs) 

# ------ One layer neural network ------
class SmallMLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs = None, *args, **kwargs):
        self._built = False
        self.net = SpectralNorm(nnx.Linear(2, 6, rngs=rngs), rngs=rngs)

    def __call__(self, x, update_stats=True):
        return self.net(x, update_stats=update_stats)

# ------ Recreating error ------
def get_loss(x, y):
    def fun(net):
        return jnp.sum((net(x, update_stats=True) - y) ** 2)
    return fun

def test_spectral_norm(net, x, y):
    init_loss = get_loss(x, y)
    loss, grads = nnx.value_and_grad(init_loss)(net)
    return loss, grads

# Jitted function
jitted_test_spectral_norm = nnx.jit(test_spectral_norm)

# Create a linear layer
rngs = nnx.Rngs(0)
net = SmallMLP(rngs=rngs)
import optax
optimizer = nnx.Optimizer(net, optax.adam(1e-3)) 

# Create test input
x = jax.random.normal(jax.random.PRNGKey(0), (4, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (4, 1))

# Test the non-jitted function
loss, grads = test_spectral_norm(net, x, y)
# optimizer.update(grads)
print("Successfully used non-jitted function!")

# Test the jitted function
loss, grads = jitted_test_spectral_norm(net, x, y)
print("Successfully used jitted function with SpectralNorm layer!")

Even without jit, you seem to get an error when using the optimiser to update the network (i.e. if you uncomment optimizer.update(grads)). If you leave optimizer.update(grads) commented out, as above, you'll see that the jitted function still errors.

@mattbahr (Contributor, Author) commented Apr 9, 2025

Thanks for the info! I'll play around with it a bit and see what I come up with.

@h-0-0 commented Apr 9, 2025

By the way, if you run the above code but swap SpectralNorm for WeightNorm from #4568, you get the same error.

@mattbahr force-pushed the implement-spectral-and-instance-norm branch from d40a1d9 to d53c9fe (April 11, 2025 03:00)
@mattbahr (Contributor, Author) commented Apr 11, 2025

@h-0-0 Can you try your test again with the latest code? I just had to return vs.replace(jnp.asarray(value_bar, dtype)) from the spectral_normalize function instead of nnx.Param(jnp.asarray(value_bar, dtype)). Your code sample is working for me now.
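
In other words, the tail of spectral_normalize now looks roughly like this (paraphrasing my own change; the exact diff is in d53c9fe):

# end of spectral_normalize(path, vs):
dtype = dtypes.canonicalize_dtype(vs.value, u, v, sigma, dtype=self.dtype)
# before: return nnx.Param(jnp.asarray(value_bar, dtype))
# now: keep the existing VariableState and only swap its value
return vs.replace(jnp.asarray(value_bar, dtype))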

@h-0-0 commented Apr 11, 2025

Hi @mattbahr, I tried re-running both the code above and my own, and both are now working! Seems it's fixed, thanks.

Successfully merging this pull request may close these issues: No SpectralNorm for NNX (#4684)