diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index 938c939a9..3b3ddcb77 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -99,6 +99,7 @@
 from .nn.linear import Einsum as Einsum
 from .nn.lora import LoRA as LoRA
 from .nn.lora import LoRALinear as LoRALinear
+from .nn.lora import LoRALinearGeneral as LoRALinearGeneral
 from .nn.lora import LoRAParam as LoRAParam
 from .nn.normalization import BatchNorm as BatchNorm
 from .nn.normalization import LayerNorm as LayerNorm
diff --git a/flax/nnx/nn/linear.py b/flax/nnx/nn/linear.py
index 22bfb1e1b..365ea4ced 100644
--- a/flax/nnx/nn/linear.py
+++ b/flax/nnx/nn/linear.py
@@ -146,7 +146,7 @@ def __init__(
     in_features: Size | tp.Sequence[Size],
     out_features: Size | tp.Sequence[Size],
     *,
-    axis: Axis | tp.Sequence[Axis] = -1,
+    axis: Axis | tp.Sequence[Axis] | None = None,
     batch_axis: tp.Mapping[Axis, Size] = FrozenDict({}),
     use_bias: bool = True,
     dtype: Dtype | None = None,
@@ -162,6 +162,8 @@
   ):
     self.in_features = _canonicalize_tuple(in_features)
     self.out_features = _canonicalize_tuple(out_features)
+    if axis is None:
+      axis = tuple(range(-len(self.in_features), 0))
     self.axis = _canonicalize_tuple(axis)
     self.batch_axis = FrozenDict[Axis, Size](batch_axis)
     self.use_bias = use_bias
diff --git a/flax/nnx/nn/lora.py b/flax/nnx/nn/lora.py
index 36f19d8b5..b8e5db915 100644
--- a/flax/nnx/nn/lora.py
+++ b/flax/nnx/nn/lora.py
@@ -18,11 +18,12 @@
 from flax.nnx import rnglib, variablelib
 from flax.nnx.module import Module
 from flax.nnx.nn import initializers
-from flax.nnx.nn.linear import Linear
+from flax.nnx.nn.linear import Linear, LinearGeneral
 from flax.nnx.nn.dtypes import promote_dtype
 from flax.typing import Dtype, Initializer
 import jax
 import jax.numpy as jnp
+import numpy as np
 
 Array = jax.Array
 Axis = int
@@ -33,6 +34,11 @@
 default_b_initializer = initializers.zeros
 
 
+def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
+  # A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
+  return tuple(sorted(ax if ax >= 0 else ndim + ax for ax in axes))
+
+
 class LoRAParam(variablelib.Param[A]):
   pass
 
@@ -78,18 +84,18 @@ class LoRA(Module):
   """
 
   def __init__(
-      self,
-      in_features: int,
-      lora_rank: int,
-      out_features: int,
-      *,
-      base_module: tp.Optional[Module] = None,
-      dtype: tp.Optional[Dtype] = None,
-      param_dtype: Dtype = jnp.float32,
-      a_initializer: Initializer = default_a_initializer,
-      b_initializer: Initializer = default_b_initializer,
-      lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
-      rngs: rnglib.Rngs,
+    self,
+    in_features: int,
+    lora_rank: int,
+    out_features: int,
+    *,
+    base_module: tp.Optional[Module] = None,
+    dtype: tp.Optional[Dtype] = None,
+    param_dtype: Dtype = jnp.float32,
+    a_initializer: Initializer = default_a_initializer,
+    b_initializer: Initializer = default_b_initializer,
+    lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
+    rngs: rnglib.Rngs,
   ):
     self.in_features = in_features
     self.out_features = out_features
@@ -99,10 +105,10 @@ def __init__(
     self.base_module = base_module
 
     self.lora_a = lora_param_type(
-        a_initializer(rngs.params(), (in_features, lora_rank), param_dtype)
+      a_initializer(rngs.params(), (in_features, lora_rank), param_dtype)
     )
     self.lora_b = lora_param_type(
-        b_initializer(rngs.params(), (lora_rank, out_features), param_dtype)
+      b_initializer(rngs.params(), (lora_rank, out_features), param_dtype)
     )
 
   def __call__(self, x: jax.Array):
@@ -144,11 +150,9 @@ class LoRALinear(Linear):
     in_features: the number of input features.
     out_features: the number of output features.
     lora_rank: the rank of the LoRA dimension.
-    base_module: a base module to call and substitute, if possible.
-    dtype: the dtype of the computation (default: infer from input and params).
-    param_dtype: the dtype passed to parameter initializers (default: float32).
-    precision: numerical precision of the computation see `jax.lax.Precision`
-      for details.
+    lora_base_module: a base module to call and substitute, if possible.
+    lora_dtype: the dtype of the computation (default: infer from input and params).
+    lora_param_dtype: the dtype passed to parameter initializers (default: float32).
     a_initializer: initializer function for the fan-in matrices. Default to
       `he_uniform`.
     b_initializer: initializer function for the fan-out matrices. Default to
@@ -157,33 +161,126 @@ class LoRALinear(Linear):
   """
 
   def __init__(
-      self,
-      in_features: int,
-      out_features: int,
-      *,
-      lora_rank: int,
-      lora_dtype: tp.Optional[Dtype] = None,
-      lora_param_dtype: Dtype = jnp.float32,
-      a_initializer: Initializer = default_a_initializer,
-      b_initializer: Initializer = default_b_initializer,
-      lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
-      rngs: rnglib.Rngs,
-      **kwargs,
+    self,
+    in_features: int,
+    out_features: int,
+    *,
+    lora_rank: int,
+    lora_base_module: tp.Optional[Module] = None,
+    lora_dtype: tp.Optional[Dtype] = None,
+    lora_param_dtype: Dtype = jnp.float32,
+    a_initializer: Initializer = default_a_initializer,
+    b_initializer: Initializer = default_b_initializer,
+    lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
+    rngs: rnglib.Rngs,
+    **kwargs,
   ):
     super().__init__(in_features, out_features, rngs=rngs, **kwargs)
     self.lora = LoRA(
-        in_features,
-        lora_rank,
-        out_features,
-        dtype=lora_dtype,
-        param_dtype=lora_param_dtype,
-        a_initializer=a_initializer,
-        b_initializer=b_initializer,
-        lora_param_type=lora_param_type,
-        rngs=rngs,
+      in_features,
+      lora_rank,
+      out_features,
+      base_module=lora_base_module,
+      dtype=lora_dtype,
+      param_dtype=lora_param_dtype,
+      a_initializer=a_initializer,
+      b_initializer=b_initializer,
+      lora_param_type=lora_param_type,
+      rngs=rngs,
     )
 
   def __call__(self, x: jax.Array):
-    y = super().__call__(x)
-    y += self.lora(x)
-    return y
+    out = super().__call__(x)
+    out += self.lora(x)
+    return out
+
+
+class LoRALinearGeneral(LinearGeneral):
+  """An `nnx.LinearGeneral` layer in which the output will be LoRAified.
+
+  The model state structure will be compatible with that of LinearGeneral.
+
+  Example usage::
+
+    >>> from flax import nnx
+    >>> import jax, jax.numpy as jnp
+    ...
+    >>> # input features (2, 3), output features (4, 5)
+    >>> # apply the transformation on the second and last axes
+    >>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
+    >>> lora_layer = nnx.LoRALinearGeneral((2, 3), (4, 5), axis=(1, -1), lora_rank=1, rngs=nnx.Rngs(0))
+    >>> layer.kernel.value.shape
+    (2, 3, 4, 5)
+    >>> layer.bias.value.shape
+    (4, 5)
+    >>> lora_layer.kernel.value.shape
+    (2, 3, 4, 5)
+    >>> lora_layer.lora.lora_a.value.shape
+    (6, 1)
+    >>> lora_layer.lora.lora_b.value.shape
+    (1, 20)
+    >>> jnp.allclose(layer.kernel.value, lora_layer.kernel.value)
+    Array(True, dtype=bool)
+    >>> y = lora_layer(jnp.ones((16, 2, 3)))
+    >>> y.shape
+    (16, 4, 5)
+
+  Args:
+    in_features: int or tuple with number of input features.
+    out_features: int or tuple with number of output features.
+    lora_rank: the rank of the LoRA dimension.
+    lora_base_module: a base module to call and substitute, if possible.
+    lora_dtype: the dtype of the computation (default: infer from input and params).
+    lora_param_dtype: the dtype passed to parameter initializers (default: float32).
+    a_initializer: initializer function for the fan-in matrices. Default to
+      `he_uniform`.
+    b_initializer: initializer function for the fan-out matrices. Default to
+      `zero initializer`.
+    lora_param_type: the type of the LoRA params.
+""" + def __init__( + self, + in_features: int | tp.Sequence[int], + out_features: int | tp.Sequence[int], + *, + lora_rank: int, + lora_base_module: tp.Optional[Module] = None, + lora_dtype: tp.Optional[Dtype] = None, + lora_param_dtype: Dtype = jnp.float32, + a_initializer: Initializer = default_a_initializer, + b_initializer: Initializer = default_b_initializer, + lora_param_type: tp.Type[variablelib.Variable] = LoRAParam, + rngs: rnglib.Rngs, + **kwargs + ): + super().__init__(in_features, out_features, rngs=rngs, **kwargs) + self.last_axis = tuple(range(-len(self.in_features), 0)) + + total_in_features = np.prod(self.in_features) + total_out_features = np.prod(self.out_features) + self.lora = LoRA( + total_in_features, + lora_rank, + total_out_features, + base_module=lora_base_module, + dtype=lora_dtype, + param_dtype=lora_param_dtype, + a_initializer=a_initializer, + b_initializer=b_initializer, + lora_param_type=lora_param_type, + rngs=rngs, + ) + + def __call__(self, x: jax.Array): + ndim = x.ndim + axis = _normalize_axes(self.axis, ndim) + last_axis = _normalize_axes(self.last_axis, ndim) + assert axis == last_axis, ( + 'LoRALinearGeneral only supports applying the transformation on ' + 'the last axes of the input' + ) + batch_shape = x.shape[:-len(self.in_features)] + out = self.lora(x.reshape((*batch_shape, -1))) + out = out.reshape((*batch_shape, *self.out_features)) + out += super().__call__(x) + return out