We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 5131718 commit 866fe95Copy full SHA for 866fe95
flax/nnx/variablelib.py
@@ -784,13 +784,12 @@ def _variable_unflatten(
784
):
785
return cls.from_metadata(value=children[0], attributes=dict(static))
786
787
-if config.flax_mutable_array:
788
- jax.tree_util.register_pytree_with_keys(
789
- Variable,
790
- flatten_with_keys=_variable_flatten_with_keys,
791
- unflatten_func=partial(_variable_unflatten, Variable), # type: ignore
792
- flatten_func=_variable_flatten,
793
- )
+jax.tree_util.register_pytree_with_keys(
+ Variable,
+ flatten_with_keys=_variable_flatten_with_keys,
+ unflatten_func=partial(_variable_unflatten, Variable), # type: ignore
+ flatten_func=_variable_flatten,
+)
794
795
796
class Param(Variable[A]):
0 commit comments