Skip to content

Commit 866fe95

Browse files
Jake VanderPlasFlax Authors
Jake VanderPlas
authored and
Flax Authors
committed
[flax] unconditionally register nnx.Variable as a pytree
Required by changes to JAX in jax-ml/jax#28630 PiperOrigin-RevId: 758847244
1 parent 5131718 commit 866fe95

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

flax/nnx/variablelib.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -784,13 +784,12 @@ def _variable_unflatten(
784784
):
785785
return cls.from_metadata(value=children[0], attributes=dict(static))
786786

787-
if config.flax_mutable_array:
788-
jax.tree_util.register_pytree_with_keys(
789-
Variable,
790-
flatten_with_keys=_variable_flatten_with_keys,
791-
unflatten_func=partial(_variable_unflatten, Variable), # type: ignore
792-
flatten_func=_variable_flatten,
793-
)
787+
jax.tree_util.register_pytree_with_keys(
788+
Variable,
789+
flatten_with_keys=_variable_flatten_with_keys,
790+
unflatten_func=partial(_variable_unflatten, Variable), # type: ignore
791+
flatten_func=_variable_flatten,
792+
)
794793

795794

796795
class Param(Variable[A]):

0 commit comments

Comments
 (0)