
Commit ef8c084

[nnx] add optax optimizer
1 parent a5eebe5 commit ef8c084

6 files changed: 216 additions & 59 deletions


flax/nnx/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -139,6 +139,7 @@
 from .training.metrics import Metric as Metric
 from .training.metrics import MultiMetric as MultiMetric
 from .training.optimizer import Optimizer as Optimizer
+from .training.optimizer import OptaxOptimizer as OptaxOptimizer
 from .transforms.autodiff import DiffState as DiffState
 from .transforms.autodiff import grad as grad
 from .transforms.autodiff import value_and_grad as value_and_grad

flax/nnx/graph.py

Lines changed: 42 additions & 44 deletions

@@ -2652,10 +2652,22 @@ def clone(node: Node) -> Node:
   graphdef, state = split(node)
   return merge(graphdef, state)
 
+def find_duplicates(tree) -> tuple[str, str] | None:
+  mutable_arrays: dict[int, str] = {}
+  paths_leaves = jax.tree.leaves_with_path(tree)
+  for path, x in paths_leaves:
+    m_array_id = id(x)
+    if m_array_id in mutable_arrays:
+      current_path_str = jax.tree_util.keystr(path)
+      previous_path_str = mutable_arrays[m_array_id]
+      return current_path_str, previous_path_str
+    mutable_arrays[m_array_id] = jax.tree_util.keystr(path)
+
+  return None
 
 
 def _mutable_like(path, x):
   return (
-    isinstance(x, Variable) and x.mutable
+    isinstance(x, Variable | VariableState) and x.mutable
   ) or variablelib.is_mutable_array(x)
 
 
@@ -2698,45 +2710,36 @@ def freeze(tree: A, /, only: filterlib.Filter = _mutable_like) -> A:
   Returns:
     A pytree with the frozen arrays.
   """
+  if (duplicate := find_duplicates(tree)) is not None:
+    current_path_str, previous_path_str = duplicate
+    raise ValueError(
+      f"Found duplicate at path '{current_path_str}' "
+      f"and '{previous_path_str}'."
+    )
   freeze_filter = filterlib.to_predicate(only)
-  mutable_arrays: dict[int, str] = {}
-
-  def check_mutable_array(path, x):
-    m_array_id = id(x)
-    if m_array_id in mutable_arrays:
-      current_path_str = jax.tree_util.keystr(path)
-      previous_path_str = mutable_arrays[m_array_id]
-      raise ValueError(
-        f'Found duplicate MutableArray found at path {current_path_str} '
-        f'and {previous_path_str} at object {x}.'
-      )
-    mutable_arrays[m_array_id] = jax.tree_util.keystr(path)
 
   def _freeze_fn(jax_path, x):
-    path = tuple(_key_path_to_key(part) for part in jax_path)
+    path = jax_to_nnx_path(jax_path)
     if freeze_filter(path, x):
-      if isinstance(x, Variable):
-        check_mutable_array(jax_path, x.raw_value)
-        return x.from_metadata(x[...], x.get_metadata().copy())
-      elif variablelib.is_mutable_array(x):
-        check_mutable_array(jax_path, x)
-        return x[...]
+      x = jax.tree.map(lambda x: x[...], x)
+    elif isinstance(x, Variable | VariableState):
+      x = jax.tree.map(lambda x: x, x)
     return x
 
   tree = jax.tree.map_with_path(
-    _freeze_fn, tree, is_leaf=lambda x: isinstance(x, Variable)
+    _freeze_fn, tree, is_leaf=lambda x: isinstance(x, Variable | VariableState)
   )
   return tree
 
 
 def _array_like(path, x):
   return (
-    isinstance(x, Variable) and isinstance(x.raw_value, jax.Array)
+    isinstance(x, Variable | VariableState) and not x.mutable
   ) or isinstance(x, jax.Array)
 
 
 def mutable(tree: A, /, only: filterlib.Filter = _array_like) -> A:
-  """Converts a pytree of arrays to mutable arrays.
+  """Converts a tree of arrays to mutable arrays.
 
   Example::
 
@@ -2774,34 +2777,24 @@ def mutable(tree: A, /, only: filterlib.Filter = _array_like) -> A:
   Returns:
     A pytree with the mutable arrays.
   """
+  if (duplicate := find_duplicates(tree)) is not None:
+    current_path_str, previous_path_str = duplicate
+    raise ValueError(
+      f"Found duplicate at path '{current_path_str}' "
+      f"and '{previous_path_str}'."
+    )
   mutable_filter = filterlib.to_predicate(only)
-  arrays: dict[int, str] = {}
-
-  def check_array(path, x):
-    m_array_id = id(x)
-    if m_array_id in arrays:
-      current_path_str = jax.tree_util.keystr(path)
-      previous_path_str = arrays[m_array_id]
-      raise ValueError(
-        f'Found duplicate Array found at path {current_path_str} '
-        f'and {previous_path_str} at object {x}.'
-      )
-    arrays[m_array_id] = jax.tree_util.keystr(path)
 
   def _mutable_fn(jax_path, x):
-    path = tuple(_key_path_to_key(part) for part in jax_path)
+    path = jax_to_nnx_path(jax_path)
     if mutable_filter(path, x):
-      if isinstance(x, Variable) and isinstance(x.raw_value, jax.Array):
-        check_array(jax_path, x.raw_value)
-        mutable_array = variablelib.mutable_array(x.raw_value)
-        return x.from_metadata(mutable_array, x.get_metadata().copy())
-      elif isinstance(x, jax.Array):
-        check_array(jax_path, x)
-        return variablelib.mutable_array(x)
+      x = jax.tree.map(variablelib.mutable_array, x)
+    elif isinstance(x, Variable | VariableState):
+      x = jax.tree.map(lambda x: x, x)
     return x
 
   return jax.tree.map_with_path(
-    _mutable_fn, tree, is_leaf=lambda x: isinstance(x, Variable)
+    _mutable_fn, tree, is_leaf=lambda x: isinstance(x, Variable | VariableState)
   )
 
 
@@ -3047,6 +3040,11 @@ def _key_path_to_key(key: tp.Any) -> Key:
   else:
     return str(key)
 
+
+def jax_to_nnx_path(jax_path: tuple, /):
+  return tuple(_key_path_to_key(part) for part in jax_path)
+
+
 class IndexesPytreeDef(tp.NamedTuple):
   key_index: HashableMapping[Key, int]
   treedef: jax.tree_util.PyTreeDef

flax/nnx/nn/normalization.py

Lines changed: 10 additions & 3 deletions

@@ -18,7 +18,7 @@
 import jax.numpy as jnp
 from jax import lax
 
-from flax import nnx
+from flax import nnx, config
 from flax.nnx import rnglib
 from flax.nnx.module import Module, first_from
 from flax.nnx.nn import dtypes, initializers
 
@@ -360,11 +360,18 @@ def __call__(
         use_fast_variance=self.use_fast_variance,
         mask=mask,
       )
+      # stop_gradient only for flax_mutable_array
+      if config.flax_mutable_array:
+        stop_gradient = jax.lax.stop_gradient
+      else:
+        stop_gradient = lambda x: x
 
-      self.mean[...] = (
+      self.mean[...] = stop_gradient(
         self.momentum * self.mean.value + (1 - self.momentum) * mean
       )
-      self.var[...] = self.momentum * self.var.value + (1 - self.momentum) * var
+      self.var[...] = stop_gradient(
+        self.momentum * self.var.value + (1 - self.momentum) * var
+      )
 
     return _normalize(
       x,
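The new branch only switches between jax.lax.stop_gradient and the identity. A standalone sketch of that pattern (illustrative only; the function and argument names are made up, not code from the commit):

# Illustrative sketch of the gated running-average update.
import jax
import jax.numpy as jnp

def update_running_stat(running, batch_stat, momentum, mutable_arrays_enabled):
  # With MutableArrays the update is an in-place effect inside the traced
  # computation, so its gradient is cut explicitly; otherwise the identity
  # keeps the previous behavior.
  stop_gradient = jax.lax.stop_gradient if mutable_arrays_enabled else (lambda x: x)
  return stop_gradient(momentum * running + (1 - momentum) * batch_stat)

new_mean = update_running_stat(jnp.zeros(3), jnp.ones(3), 0.99, True)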

flax/nnx/training/optimizer.py

Lines changed: 98 additions & 5 deletions

@@ -20,8 +20,7 @@
 import optax
 
 from flax import nnx
-from flax.nnx import filterlib
-from flax.nnx import variablelib
+from flax.nnx import filterlib, graph
 from flax.nnx.object import Object
 from flax.nnx.variablelib import Variable, VariableState
 
@@ -51,7 +50,7 @@ class OptVariable(OptState):
 
 def _wrap_optimizer_state(opt_state):
   def wrap_optimizer_state_fn(x):
-    if isinstance(x, variablelib.VariableState):
+    if isinstance(x, VariableState):
       new_state = x.copy()
       new_state.source_type = x.type
       new_state.type = OptVariable
 
@@ -62,7 +61,7 @@ def wrap_optimizer_state_fn(x):
   return jax.tree.map(
     wrap_optimizer_state_fn,
     opt_state,
-    is_leaf=lambda x: isinstance(x, variablelib.VariableState),
+    is_leaf=lambda x: isinstance(x, VariableState),
   )
 
 
@@ -274,4 +273,98 @@ def update(self, grads, **kwargs):
 
     self.step.value += 1
     nnx.update(self.model, new_params)
-    _update_opt_state(self.opt_state, new_opt_state)
+    _update_opt_state(self.opt_state, new_opt_state)
+
+
+def to_opt_state(tree):
+  def _to_opt_state(x):
+    if isinstance(x, Variable | VariableState):
+      opt_state = OptVariable(x[...], **x.get_metadata())  # type: ignore
+    else:
+      opt_state = OptArray(x)
+    return opt_state
+
+  tree = jax.tree.map(
+    _to_opt_state,
+    tree,
+    is_leaf=lambda x: isinstance(x, Variable | VariableState),
+  )
+  return tree
+
+
+class OptaxOptimizer(Object):
+  """Stateful wrapper around an Optax optimizer.
+
+  Example usage::
+
+    >>> from flax import config
+    >>> if not config.flax_mutable_array:
+    ...   import pytest
+    ...   pytest.skip('MutableArrays required for this example')
+    ...
+    >>> import jax, jax.numpy as jnp
+    >>> from flax import nnx
+    >>> from flax import config
+    >>> import optax
    ...
+    >>> class Model(nnx.Module):
+    ...   __data__ = ('linear1', 'linear2', 'bn')
+    ...   def __init__(self, rngs):
+    ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
+    ...     self.bn = nnx.BatchNorm(3, rngs=rngs)
+    ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
+    ...   def __call__(self, x):
+    ...     return self.linear2(nnx.relu(self.bn(self.linear1(x))))
+    ...
+    >>> x = jax.random.normal(jax.random.key(0), (5, 2))
+    >>> y = jnp.ones((5, 4))
+    ...
+    >>> model = Model(nnx.Rngs(1))
+    >>> optimizer = nnx.OptaxOptimizer(nnx.state(model, nnx.Param), tx=optax.adam(1e-3))
+    ...
+    >>> @jax.jit
+    ... def train_step(model, optimizer, x, y):
+    ...   graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
+    ...   def loss_fn(params):
+    ...     model = nnx.merge(graphdef, params, nondiff)
+    ...     return ((model(x) - y) ** 2).mean()
+    ...
+    ...   loss, grads = jax.value_and_grad(loss_fn)(nnx.freeze(params))
+    ...   optimizer.update(params, grads)
+    ...   return loss
+    ...
+    >>> loss = train_step(model, optimizer, x, y)
+    >>> loss
+    Array(1.2029127, dtype=float32)
+
+  Args:
+    params: The parameters to be optimized.
+    tx: An optax gradient transformation.
+  """
+  __nodes__ = ('step', 'opt_state')
+
+  def __init__(self, params, tx: optax.GradientTransformation):
+    self.tx = tx
+    self.step = OptArray(jnp.array(0, dtype=jnp.uint32))
+    self.opt_state = to_opt_state(tx.init(params))
+
+  def update(self, params, grads, **kwargs):
+    param_arrays = graph.freeze(graph.pure(params))
+    grad_arrays = graph.freeze(graph.pure(grads))
+    opt_state_arrays = graph.freeze(graph.pure(self.opt_state))
+
+    updates, new_opt_state = self.tx.update(
+      grad_arrays, opt_state_arrays, param_arrays, **kwargs
+    )
+    new_params = optax.apply_updates(param_arrays, updates)
+
+    def _update_variable(param, value):
+      param[...] = value
+
+    jax.tree.map(
+      _update_variable,
+      (params, self.opt_state),
+      (new_params, new_opt_state),
+      is_leaf=lambda x: isinstance(x, Variable | VariableState),
+    )
+    self.step[...] += 1
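The docstring example above shows the intended training loop; the core mechanism in update() is the in-place write-back over (params, opt_state). A reduced sketch of that mechanism on a toy tree (hedged; it relies on indexed assignment through a Variable, the same `param[...] = value` pattern used in the diff, and is not code from the commit):

# Reduced sketch of the write-back used by OptaxOptimizer.update; toy tree.
import jax
import jax.numpy as jnp
from flax import nnx

params = {'w': nnx.Param(jnp.zeros((3,)))}
new_values = {'w': jnp.ones((3,))}

def _update_variable(param, value):
  param[...] = value  # overwrite the stored array, keeping the Variable object

jax.tree.map(
  _update_variable,
  params,
  new_values,
  is_leaf=lambda x: isinstance(x, nnx.Variable),
)
# Callers that still hold references to the same Variables (e.g. the params
# obtained from nnx.split inside train_step) observe the new values, which is
# why the docstring example needs no nnx.update(model, ...) call.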

flax/nnx/variablelib.py

Lines changed: 17 additions & 6 deletions

@@ -219,18 +219,17 @@ def __delattr__(self, name: str):
     else:
       del self._var_metadata[name]
 
-  @classmethod
-  def state(cls, value: A, **metadata) -> VariableState[A]:
-    return cls(value, **metadata).to_state()
-
   @property
-  def mutable(self) -> bool | None:
+  def mutable(self) -> bool:
     if is_mutable_array(self.raw_value):
       return True
     elif isinstance(self.raw_value, jax.Array):
       return False
     else:
-      return None
+      raise ValueError(
+        f'mutable is only supported for jax.Array and MutableArray, '
+        f'got {type(self.raw_value).__name__}'
+      )
 
   def get_metadata(self):
     return self._var_metadata
 
@@ -972,6 +971,18 @@ def raw_value(self) -> A:
   def raw_value(self, value: A) -> None:
     object.__setattr__(self, 'value', value)
 
+  @property
+  def mutable(self) -> bool:
+    if is_mutable_array(self.raw_value):
+      return True
+    elif isinstance(self.raw_value, jax.Array):
+      return False
+    else:
+      raise ValueError(
+        f'mutable is only supported for jax.Array and MutableArray, '
+        f'got {type(self.raw_value).__name__}'
+      )
+
   def __getattribute__(self, name: str) -> None:
     if name == 'value':
       value = object.__getattribute__(self, 'value')
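A small hedged sketch of the stricter mutable property (not from the commit); with MutableArrays disabled a Param wraps a plain jax.Array:

# Hedged sketch: the `mutable` property after this change.
import jax.numpy as jnp
from flax import nnx

v = nnx.Param(jnp.ones((2,)))
print(v.mutable)  # False for a plain jax.Array; True when the raw value is a MutableArray
# For any other raw value type the property now raises ValueError instead of
# returning None, and VariableState gains the same property, which lets
# graph.freeze/graph.mutable filter on `x.mutable` for both Variable and
# VariableState leaves.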
