Remove VariableState #4729

Draft: wants to merge 1 commit into base: main

9 changes: 0 additions & 9 deletions docs_nnx/api_reference/flax.nnx/transforms.rst
@@ -4,15 +4,6 @@ transforms
.. automodule:: flax.nnx
.. currentmodule:: flax.nnx

.. autoclass:: Jit
:members:
.. autoclass:: Remat
:members:
.. autoclass:: Scan
:members:
.. autoclass:: Vmap
:members:

.. autofunction:: grad
.. autofunction:: jit
.. autofunction:: shard_map
2 changes: 0 additions & 2 deletions docs_nnx/api_reference/flax.nnx/variables.rst
@@ -18,8 +18,6 @@ variables
:members:
.. autoclass:: VariableMetadata
:members:
.. autoclass:: VariableState
:members:

.. autofunction:: with_metadata

2 changes: 1 addition & 1 deletion docs_nnx/guides/checkpointing.ipynb
@@ -131,7 +131,7 @@
"\n",
"## Restore checkpoints\n",
"\n",
"Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n",
"Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n",
"\n",
"At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:\n",
"- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.\n",
2 changes: 1 addition & 1 deletion docs_nnx/guides/checkpointing.md
@@ -82,7 +82,7 @@ checkpointer.save(ckpt_dir / 'state', state)

## Restore checkpoints

Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.
Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.

At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:
- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.
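
For context, the restore flow described above looks roughly like the sketch below. `MLP` and `ckpt_dir` are placeholders standing in for the guide's model class and checkpoint directory, and the Orbax calls follow the guide's existing save/restore pattern; with `VariableState` removed, the abstract state handed to Orbax is expected to carry `nnx.Variable` leaves.

```python
# Sketch only: `MLP` and `ckpt_dir` are placeholders, not part of this diff.
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()

# 1. Build an abstract model: same structure, no array memory allocated.
abs_model = nnx.eval_shape(lambda: MLP(rngs=nnx.Rngs(0)))
graphdef, abs_state = nnx.split(abs_model)

# 2. Hand the abstract state to Orbax as the restore target.
restored_state = checkpointer.restore(ckpt_dir / 'state', abs_state)

# 3. Merge the restored state back into a usable model.
model = nnx.merge(graphdef, restored_state)
```
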
2 changes: 1 addition & 1 deletion docs_nnx/guides/filters_guide.ipynb
@@ -415,7 +415,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.11.9"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion docs_nnx/guides/flax_gspmd.ipynb
@@ -652,7 +652,7 @@
"%%timeit\n",
"\n",
"def block_all(xs):\n",
" jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)\n",
" jax.tree.map(lambda x: x.block_until_ready(), xs)\n",
" return xs\n",
"\n",
"with mesh:\n",
2 changes: 1 addition & 1 deletion docs_nnx/guides/flax_gspmd.md
@@ -295,7 +295,7 @@ If you are using a Google TPU pod or a pod slice, you can create a custom `block
%%timeit

def block_all(xs):
jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
jax.tree.map(lambda x: x.block_until_ready(), xs)
return xs

with mesh:
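
A quick note on the `jax.tree_util.tree_map` to `jax.tree.map` substitutions made throughout this PR: on a reasonably recent JAX release the two spellings are equivalent, with `jax.tree.map` being the newer alias. A tiny self-contained check (the dict of arrays is invented here):

```python
import jax
import jax.numpy as jnp

xs = {'a': jnp.ones((2,)), 'b': jnp.zeros((3, 4))}
old = jax.tree_util.tree_map(jnp.shape, xs)   # legacy spelling
new = jax.tree.map(jnp.shape, xs)             # alias used throughout this PR
assert old == new == {'a': (2,), 'b': (3, 4)}
```
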
4 changes: 2 additions & 2 deletions docs_nnx/guides/haiku_to_flax.rst
@@ -199,9 +199,9 @@ The dropout behavior:
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
_, params, rest = nnx.pure(nnx.split(model, nnx.Param, ...))
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.merge_state(params, rest))
nnx.update(model, params, rest)

.. testcode:: Haiku
:hide:
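
For reviewers, here is a self-contained sketch of the updated training-step pattern shown in this hunk. The toy model and batch are invented, and it assumes the `nnx.pure` helper added by this PR behaves as used in the doc change above.

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))    # toy stand-in for the guide's model
x, y = jnp.ones((4, 2)), jnp.zeros((4, 3))    # dummy batch

def loss_fn(model):
  return ((model(x) - y) ** 2).mean()

grads = nnx.grad(loss_fn)(model)
# Split Params out as a pure pytree, apply a plain SGD step, write it back.
_, params, rest = nnx.pure(nnx.split(model, nnx.Param, ...))
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, params, rest)
```
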
16 changes: 8 additions & 8 deletions docs_nnx/guides/jax_and_nnx_transforms.rst
@@ -56,8 +56,8 @@ Notice that:
def loss_fn(model):
return ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
params = nnx.pure(nnx.state(model, nnx.Param))
params = jax.tree.map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
@@ -74,8 +74,8 @@ Notice that:
grads = jax.grad(loss_fn, argnums=1)(graphdef, state) #!

model = nnx.merge(graphdef, state) #!
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
params = nnx.pure(nnx.state(model, nnx.Param))
params = jax.tree.map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
@@ -102,8 +102,8 @@ in your code is pure and has valid argument types that are recognized by JAX.
model = nnx.merge(graphdef, state)
return ((model(x) - y) ** 2).mean()
grads = jax.grad(loss_fn, 1)(*nnx.split(model)) #!
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
params = nnx.pure(nnx.state(model, nnx.Param))
params = jax.tree.map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
@@ -118,8 +118,8 @@
def loss_fn(model):
return ((model(x) - y) ** 2).mean()
grads = nnx.grad(loss_fn)(model)
params = nnx.state(model, nnx.Param)
params = jax.tree_util.tree_map(
params = nnx.pure(nnx.state(model, nnx.Param))
params = jax.tree.map(
lambda p, g: p - 0.1 * g, params, grads
)
nnx.update(model, params)
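
Likewise, the plain-JAX variant from this file, fleshed out into a runnable sketch; the model and batch are invented, and the update lines mirror the diff above rather than defining a new API.

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))     # toy stand-in
x, y = jnp.ones((4, 2)), jnp.zeros((4, 3))     # dummy batch

graphdef, state = nnx.split(model)              # pytree-friendly halves

def loss_fn(graphdef, state):
  model = nnx.merge(graphdef, state)            # rebuild inside the pure fn
  return ((model(x) - y) ** 2).mean()

grads = jax.grad(loss_fn, argnums=1)(graphdef, state)  # plain JAX transform

model = nnx.merge(graphdef, state)
params = nnx.pure(nnx.state(model, nnx.Param))
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, params)
```
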
2 changes: 1 addition & 1 deletion docs_nnx/guides/linen_to_nnx.rst
@@ -182,7 +182,7 @@ Dropout behavior:
return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
_, params, rest = nnx.pure(nnx.split(model, nnx.Param, ...))
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.merge_state(params, rest))

14 changes: 7 additions & 7 deletions docs_nnx/linen_intro.ipynb
@@ -276,7 +276,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -337,7 +337,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -706,7 +706,7 @@
"\n",
"print('updated variables:\\n', updated_variables)\n",
"print('initialized variable shapes:\\n',\n",
" jax.tree_util.tree_map(jnp.shape, init_variables))\n",
" jax.tree.map(jnp.shape, init_variables))\n",
"print('output:\\n', y)\n",
"\n",
"# Let's run these model variables during \"evaluation\":\n",
@@ -784,7 +784,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -850,7 +850,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -1001,7 +1001,7 @@
" batch_axes=(0,))\n",
"\n",
"init_variables = model(train=False).init({'params': key2}, x, x)\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"\n",
"y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n",
"print('output:\\n', y.shape)"
@@ -1076,7 +1076,7 @@
"model = SimpleScan(2)\n",
"init_variables = model.init(key2, xs)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"\n",
"y = model.apply(init_variables, xs)\n",
"print('output:\\n', y)"
14 changes: 7 additions & 7 deletions docs_nnx/linen_intro.md
@@ -136,7 +136,7 @@ model = ExplicitMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -168,7 +168,7 @@ model = SimpleMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -338,7 +338,7 @@ updated_variables = flax.core.freeze(dict(params=init_params,

print('updated variables:\n', updated_variables)
print('initialized variable shapes:\n',
jax.tree_util.tree_map(jnp.shape, init_variables))
jax.tree.map(jnp.shape, init_variables))
print('output:\n', y)

# Let's run these model variables during "evaluation":
@@ -383,7 +383,7 @@ model = MLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -420,7 +420,7 @@ model = RematMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -545,7 +545,7 @@ model = functools.partial(
batch_axes=(0,))

init_variables = model(train=False).init({'params': key2}, x, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))

y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
print('output:\n', y.shape)
@@ -590,7 +590,7 @@ xs = random.uniform(key1, (1, 5, 2))
model = SimpleScan(2)
init_variables = model.init(key2, xs)

print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))

y = model.apply(init_variables, xs)
print('output:\n', y)
5 changes: 1 addition & 4 deletions docs_nnx/nnx_glossary.rst
@@ -37,7 +37,4 @@ For additional terms, refer to the `JAX glossary <https://jax.readthedocs.io/en/
A Flax NNX transformation (transform) is a wrapped version of a `JAX transformation <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ that allows the function that is being transformed to take the Flax NNX :term:`Module<Module>` (``nnx.Module``) as input or output. For example, a "lifted" version of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`__ is :meth:`nnx.jit <flax.nnx.jit>`. Check out the `Flax NNX transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ to learn more.

Variable
The weights / parameters / data / array :class:`nnx.Variable <flax.nnx.Variable>` residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.

Variable state
:class:`nnx.VariableState <flax.nnx.VariableState>` is a purely functional `JAX pytree <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__ of all the :term:`Variables<Variable>` inside a :term:`Module<Module>`. Since it is pure, it can be an input or output of a `JAX transformation <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ function. ``nnx.VariableState`` is obtained by using :meth:`nnx.split <flax.nnx.split>` on the :class:`nnx.Module <flax.nnx.Module>`. (Refer to :term:`splitting<Split and merge>` and :term:`Module<Module>` to learn more.)
The weights / parameters / data / array :class:`nnx.Variable <flax.nnx.Variable>` residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.
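
To illustrate the surviving glossary entry, a minimal sketch of a `Variable` subclass living on a Module; `Count` and `Block` are invented names, not part of this diff.

```python
import jax
import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):   # a custom Variable subclass (invented name)
  pass

class Block(nnx.Module):
  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
    key = rngs.params()
    self.w = nnx.Param(jax.random.normal(key, (din, dout)))  # Param is a Variable subclass
    self.calls = Count(jnp.array(0))

  def __call__(self, x):
    self.calls.value += 1    # Variables hold mutable state on the Module
    return x @ self.w

block = Block(2, 3, rngs=nnx.Rngs(0))
y = block(jnp.ones((1, 2)))
```
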
39 changes: 24 additions & 15 deletions examples/nnx_toy_examples/10_fsdp_and_optimizer.py
@@ -71,49 +71,58 @@ def __call__(self, x: jax.Array):
return nnx.relu(x @ self.w1 + self.b1) @ self.w2


class SGDState(nnx.Variable):
class SGDState(nnx.Variable[jax.Array]):
pass


class SGD(nnx.Object):
def __init__(self, params: nnx.State, lr, decay=0.9):
def __init__(self, params, lr, decay=0.9):
def init_optimizer_state(variable: nnx.Variable):
return SGDState(
jnp.zeros_like(variable.value), **variable.get_metadata()
)

self.lr = lr
self.params = params
self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
self.momentum: nnx.State = jax.tree.map(
init_optimizer_state,
params,
is_leaf=lambda x: isinstance(x, nnx.Variable),
)
self.decay = decay

def update(self, grads: nnx.State):
def update(self, params, grads):
def update_fn(
params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
params: nnx.Variable[jax.Array], momentum: SGDState, grad: jax.Array
):
# v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
momentum.value = self.decay * momentum + (1 - self.decay) * grad.value
momentum.value = self.decay * momentum[...] + (1 - self.decay) * grad[...]
# θ_{t+1} = θ_t - α * v_t
params.value -= self.lr * momentum

jax.tree.map(update_fn, self.params, self.momentum, grads)
params.value -= self.lr * momentum[...]

jax.tree.map(
update_fn,
params,
self.momentum,
grads,
is_leaf=lambda x: isinstance(x, nnx.Variable),
)


@nnx.jit
def create_model():
model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)
state = nnx.state(optimizer)
state = nnx.variables(optimizer)
sharded_state = jax.lax.with_sharding_constraint(
state, nnx.get_named_sharding(state, mesh)
)

def get_named_shardings(path: tuple, value: nnx.VariableState):
def get_named_shardings(path: tuple, value: nnx.Variable):
if path[0] == 'params':
return value.replace(NamedSharding(mesh, P(*value.sharding)))
return NamedSharding(mesh, P(*value.sharding))
elif path[0] == 'momentum':
# currently the same as above but in general it could be different
return value.replace(NamedSharding(mesh, P(*value.sharding)))
return NamedSharding(mesh, P(*value.sharding))
else:
raise ValueError(f'Unknown path: {path}')

@@ -137,7 +146,7 @@ def loss_fn(model):
return loss

loss, grad = nnx.value_and_grad(loss_fn)(model)
optimizer.update(grad)
optimizer.update(nnx.variables(model, nnx.Param), grad)
return loss


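
The optimizer's calling convention changes here: `update` now receives the current `Param` variables instead of capturing them at construction. A stripped-down usage sketch, leaving out the mesh/sharding setup; the data `x`, `y` are placeholders, while `MLP` and `SGD` refer to the classes defined in this example file.

```python
import jax.numpy as jnp
from flax import nnx

x = jnp.linspace(-1.0, 1.0, 32)[:, None]        # placeholder data
y = 0.8 * x**2 + 0.1                             # placeholder targets

model = MLP(1, 32, 1, rngs=nnx.Rngs(0))          # MLP/SGD as defined above
optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)

def loss_fn(model):
  return jnp.mean((model(x) - y) ** 2)

loss, grads = nnx.value_and_grad(loss_fn)(model)
# Momentum lives on the optimizer; the current Params are passed in per step.
optimizer.update(nnx.variables(model, nnx.Param), grads)
```
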
4 changes: 1 addition & 3 deletions flax/nnx/__init__.py
@@ -28,8 +28,6 @@
from .filterlib import Everything as Everything
from .filterlib import Nothing as Nothing
from .graph import GraphDef as GraphDef
from .graph import GraphState as GraphState
from .graph import PureState as PureState
from .object import Object as Object
from .helpers import Dict as Dict
from .helpers import Sequential as Sequential
@@ -43,6 +41,7 @@
from .graph import split as split
from .graph import update as update
from .graph import clone as clone
from .graph import pure as pure
from .graph import pop as pop
from .graph import state as state
from .graph import graphdef as graphdef
@@ -161,7 +160,6 @@
from .variablelib import Intermediate as Intermediate
from .variablelib import Perturbation as Perturbation
from .variablelib import Variable as Variable
from .variablelib import VariableState as VariableState
from .variablelib import VariableMetadata as VariableMetadata
from .variablelib import with_metadata as with_metadata
from .variablelib import variable_type_from_name as variable_type_from_name
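
Since `nnx.VariableState` is no longer exported, downstream code that referenced it will need a small migration. A hedged before/after sketch, mirroring the substitutions this PR makes in the docs and examples; `model` is a placeholder and the exact `nnx.pure` semantics are assumed from its usage above.

```python
import jax
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))      # placeholder model

# Before: state leaves were nnx.VariableState, e.g.
#   isinstance(leaf, nnx.VariableState)
# After: state leaves are nnx.Variable; convert with the new nnx.pure helper
# when a plain pytree of arrays is needed for jax.tree.map and friends.
graphdef, state = nnx.split(model)
pure_state = nnx.pure(state)
doubled = jax.tree.map(lambda a: a * 2.0, pure_state)
nnx.update(model, doubled)
```
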