
FEAT: add LoRALinearGeneral #4718

Open
wants to merge 1 commit into main
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
@@ -99,6 +99,7 @@
from .nn.linear import Einsum as Einsum
from .nn.lora import LoRA as LoRA
from .nn.lora import LoRALinear as LoRALinear
from .nn.lora import LoRALinearGeneral as LoRALinearGeneral
from .nn.lora import LoRAParam as LoRAParam
from .nn.normalization import BatchNorm as BatchNorm
from .nn.normalization import LayerNorm as LayerNorm
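For context, the re-export above makes the new layer available directly under the `nnx` namespace; a minimal usage sketch, assuming this PR is applied:

from flax import nnx
layer = nnx.LoRALinearGeneral((2, 3), (4, 5), axis=(1, -1), lora_rank=1, rngs=nnx.Rngs(0))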
4 changes: 3 additions & 1 deletion flax/nnx/nn/linear.py
@@ -146,7 +146,7 @@ def __init__(
in_features: Size | tp.Sequence[Size],
out_features: Size | tp.Sequence[Size],
*,
axis: Axis | tp.Sequence[Axis] = -1,
axis: Axis | tp.Sequence[Axis] = None,
batch_axis: tp.Mapping[Axis, Size] = FrozenDict({}),
use_bias: bool = True,
dtype: Dtype | None = None,
@@ -162,6 +162,8 @@
):
self.in_features = _canonicalize_tuple(in_features)
self.out_features = _canonicalize_tuple(out_features)
if axis is None:
axis = tuple(range(-len(self.in_features), 0))
self.axis = _canonicalize_tuple(axis)
self.batch_axis = FrozenDict[Axis, Size](batch_axis)
self.use_bias = use_bias
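The `axis=None` default above lets a multi-axis `LinearGeneral` derive its contraction axes from `in_features` instead of requiring an explicit `axis`. A small sketch of the intended behavior, assuming this change:

from flax import nnx
import jax.numpy as jnp

# With in_features=(2, 3) and no explicit axis, the contraction now defaults
# to the last len(in_features) input axes, i.e. the same as axis=(-2, -1).
layer = nnx.LinearGeneral((2, 3), (4, 5), rngs=nnx.Rngs(0))
y = layer(jnp.ones((16, 2, 3)))
assert y.shape == (16, 4, 5)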
185 changes: 141 additions & 44 deletions flax/nnx/nn/lora.py
@@ -18,11 +18,12 @@
from flax.nnx import rnglib, variablelib
from flax.nnx.module import Module
from flax.nnx.nn import initializers
from flax.nnx.nn.linear import Linear
from flax.nnx.nn.linear import Linear, LinearGeneral
from flax.nnx.nn.dtypes import promote_dtype
from flax.typing import Dtype, Initializer
import jax
import jax.numpy as jnp
import numpy as np

Array = jax.Array
Axis = int
Expand All @@ -33,6 +34,11 @@
default_b_initializer = initializers.zeros


def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
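  # e.g. _normalize_axes((1, -1), ndim=3) -> (1, 2)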
return tuple(sorted(ax if ax >= 0 else ndim + ax for ax in axes))


class LoRAParam(variablelib.Param[A]):
pass

@@ -78,18 +84,18 @@ class LoRA(Module):
"""

def __init__(
self,
in_features: int,
lora_rank: int,
out_features: int,
*,
base_module: tp.Optional[Module] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
a_initializer: Initializer = default_a_initializer,
b_initializer: Initializer = default_b_initializer,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
self,
in_features: int,
lora_rank: int,
out_features: int,
*,
base_module: tp.Optional[Module] = None,
dtype: tp.Optional[Dtype] = None,
param_dtype: Dtype = jnp.float32,
a_initializer: Initializer = default_a_initializer,
b_initializer: Initializer = default_b_initializer,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
):
self.in_features = in_features
self.out_features = out_features
@@ -99,10 +105,10 @@ def __init__(
self.base_module = base_module

self.lora_a = lora_param_type(
a_initializer(rngs.params(), (in_features, lora_rank), param_dtype)
a_initializer(rngs.params(), (in_features, lora_rank), param_dtype)
)
self.lora_b = lora_param_type(
b_initializer(rngs.params(), (lora_rank, out_features), param_dtype)
b_initializer(rngs.params(), (lora_rank, out_features), param_dtype)
)

def __call__(self, x: jax.Array):
@@ -144,11 +150,9 @@ class LoRALinear(Linear):
in_features: the number of input features.
out_features: the number of output features.
lora_rank: the rank of the LoRA dimension.
base_module: a base module to call and substitute, if possible.
dtype: the dtype of the computation (default: infer from input and params).
param_dtype: the dtype passed to parameter initializers (default: float32).
precision: numerical precision of the computation see `jax.lax.Precision`
for details.
lora_base_module: a base module to call and substitute, if possible.
lora_dtype: the dtype of the computation (default: infer from input and params).
lora_param_dtype: the dtype passed to parameter initializers (default: float32).
a_initializer: initializer function for the fan-in matrices. Default to
`he_uniform`.
b_initializer: initializer function for the fan-out matrices. Default to
      `zero initializer`.
    lora_param_type: the type of the LoRA params.
@@ -157,33 +161,126 @@
"""

def __init__(
self,
in_features: int,
out_features: int,
*,
lora_rank: int,
lora_dtype: tp.Optional[Dtype] = None,
lora_param_dtype: Dtype = jnp.float32,
a_initializer: Initializer = default_a_initializer,
b_initializer: Initializer = default_b_initializer,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
**kwargs,
self,
in_features: int,
out_features: int,
*,
lora_rank: int,
lora_base_module: tp.Optional[Module] = None,
lora_dtype: tp.Optional[Dtype] = None,
lora_param_dtype: Dtype = jnp.float32,
a_initializer: Initializer = default_a_initializer,
b_initializer: Initializer = default_b_initializer,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
**kwargs,
):
super().__init__(in_features, out_features, rngs=rngs, **kwargs)
self.lora = LoRA(
in_features,
lora_rank,
out_features,
dtype=lora_dtype,
param_dtype=lora_param_dtype,
a_initializer=a_initializer,
b_initializer=b_initializer,
lora_param_type=lora_param_type,
rngs=rngs,
in_features,
lora_rank,
out_features,
base_module=lora_base_module,
dtype=lora_dtype,
param_dtype=lora_param_dtype,
a_initializer=a_initializer,
b_initializer=b_initializer,
lora_param_type=lora_param_type,
rngs=rngs,
)

def __call__(self, x: jax.Array):
y = super().__call__(x)
y += self.lora(x)
return y
out = super().__call__(x)
out += self.lora(x)
return out


class LoRALinearGeneral(LinearGeneral):
"""An `nnx.LinearGeneral` layer in which the output will be LoRAified.

The model state structure will be compatible with that of LinearGeneral.

Example usage::

>>> from flax import nnx
>>> import jax, jax.numpy as jnp
...
>>> # input features (2, 3), output features (4, 5)
>>> # apply the transformation on the second and last axes
>>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0))
>>> lora_layer = nnx.LoRALinearGeneral((2, 3), (4, 5), axis=(1, -1), lora_rank=1, rngs=nnx.Rngs(0))
>>> layer.kernel.value.shape
(2, 3, 4, 5)
>>> layer.bias.value.shape
(4, 5)
>>> lora_layer.kernel.value.shape
(2, 3, 4, 5)
>>> lora_layer.lora.lora_a.value.shape
(6, 1)
>>> lora_layer.lora.lora_b.value.shape
(1, 20)
>>> jnp.allclose(layer.kernel.value, lora_layer.kernel.value)
Array(True, dtype=bool)
>>> y = lora_layer(jnp.ones((16, 2, 3)))
>>> y.shape
(16, 4, 5)

Args:
in_features: int or tuple with number of input features.
out_features: int or tuple with number of output features.
lora_rank: the rank of the LoRA dimension.
lora_base_module: a base module to call and substitute, if possible.
lora_dtype: the dtype of the computation (default: infer from input and params).
lora_param_dtype: the dtype passed to parameter initializers (default: float32).
a_initializer: initializer function for the fan-in matrices. Default to
`he_uniform`.
b_initializer: initializer function for the fan-out matrices. Default to
`zero initializer`.
lora_param_type: the type of the LoRA params.
"""
def __init__(
self,
in_features: int | tp.Sequence[int],
out_features: int | tp.Sequence[int],
*,
lora_rank: int,
lora_base_module: tp.Optional[Module] = None,
lora_dtype: tp.Optional[Dtype] = None,
lora_param_dtype: Dtype = jnp.float32,
a_initializer: Initializer = default_a_initializer,
b_initializer: Initializer = default_b_initializer,
lora_param_type: tp.Type[variablelib.Variable] = LoRAParam,
rngs: rnglib.Rngs,
**kwargs
):
super().__init__(in_features, out_features, rngs=rngs, **kwargs)
self.last_axis = tuple(range(-len(self.in_features), 0))

total_in_features = np.prod(self.in_features)
total_out_features = np.prod(self.out_features)
self.lora = LoRA(
total_in_features,
lora_rank,
total_out_features,
base_module=lora_base_module,
dtype=lora_dtype,
param_dtype=lora_param_dtype,
a_initializer=a_initializer,
b_initializer=b_initializer,
lora_param_type=lora_param_type,
rngs=rngs,
)

def __call__(self, x: jax.Array):
ndim = x.ndim
axis = _normalize_axes(self.axis, ndim)
last_axis = _normalize_axes(self.last_axis, ndim)
assert axis == last_axis, (
'LoRALinearGeneral only supports applying the transformation on '
'the last axes of the input'
)
batch_shape = x.shape[:-len(self.in_features)]
out = self.lora(x.reshape((*batch_shape, -1)))
out = out.reshape((*batch_shape, *self.out_features))
out += super().__call__(x)
return out
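
As a closing illustration (not part of the diff), the flattened low-rank update can be folded back into the dense kernel, since reshaping `lora_a @ lora_b` to the kernel's shape yields an equivalent additive update; a minimal sketch using the shapes from the docstring example:

from flax import nnx
import jax.numpy as jnp

layer = nnx.LoRALinearGeneral((2, 3), (4, 5), axis=(1, -1), lora_rank=1, rngs=nnx.Rngs(0))

# Low-rank update on the flattened features: (6, 1) @ (1, 20) -> (6, 20).
delta = layer.lora.lora_a.value @ layer.lora.lora_b.value
# Reshaped to (2, 3, 4, 5), it can be merged into the kernel for inference.
merged = layer.kernel.value + delta.reshape(layer.kernel.value.shape)

With the default `b_initializer = initializers.zeros`, `delta` is all zeros at initialization, so the merged kernel initially equals the base kernel.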