
Commit adba45d

remove VariableState
1 parent 2f3b344 commit adba45d

40 files changed: +533 -751 lines changed

docs_nnx/api_reference/flax.nnx/transforms.rst

Lines changed: 0 additions & 9 deletions
@@ -4,15 +4,6 @@ transforms
 .. automodule:: flax.nnx
 .. currentmodule:: flax.nnx
 
-.. autoclass:: Jit
-  :members:
-.. autoclass:: Remat
-  :members:
-.. autoclass:: Scan
-  :members:
-.. autoclass:: Vmap
-  :members:
-
 .. autofunction:: grad
 .. autofunction:: jit
 .. autofunction:: shard_map
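
Only the function-style transforms (`nnx.grad`, `nnx.jit`, `nnx.shard_map`) remain documented on this page. A minimal sketch of that style, with an illustrative model and shapes (not part of this commit):

```python
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # illustrative model
x, y = jnp.ones((1, 2)), jnp.ones((1, 3))

@nnx.jit  # function-style transform
def train_step(model, x, y):
  def loss_fn(model):
    return ((model(x) - y) ** 2).mean()
  # nnx.grad differentiates with respect to the module's Param state
  return nnx.grad(loss_fn)(model)

grads = train_step(model, x, y)
```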

docs_nnx/api_reference/flax.nnx/variables.rst

Lines changed: 0 additions & 2 deletions
@@ -18,8 +18,6 @@ variables
   :members:
 .. autoclass:: VariableMetadata
   :members:
-.. autoclass:: VariableState
-  :members:
 
 .. autofunction:: with_metadata
 
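
With `VariableState` removed, `nnx.Variable` and its subclasses are the state containers this page documents. A small sketch of a custom `Variable` subclass (the same pattern as `SGDState` later in this commit); the `Counter` module is illustrative:

```python
import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):  # a custom Variable subclass acts as its own state collection
  pass

class Counter(nnx.Module):
  def __init__(self):
    self.count = Count(jnp.array(0))

  def __call__(self):
    self.count.value += 1

counter = Counter()
counter()
assert counter.count.value == 1
```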

docs_nnx/guides/checkpointing.ipynb

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@
 "\n",
 "## Restore checkpoints\n",
 "\n",
-"Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n",
+"Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n",
 "\n",
 "At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:\n",
 "- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.\n",

docs_nnx/guides/checkpointing.md

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ checkpointer.save(ckpt_dir / 'state', state)
 
 ## Restore checkpoints
 
-Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.
+Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.
 
 At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:
 - First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.
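
For context, the restore flow this text describes looks roughly as follows; `create_model` and `ckpt_dir` come from the guide's earlier setup and are assumed here, and the Orbax calls are a sketch rather than the guide's exact code:

```python
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()

# 1. Build an abstract model: same structure, no arrays allocated.
abstract_model = nnx.eval_shape(lambda: create_model())  # create_model assumed from the guide
graphdef, abstract_state = nnx.split(abstract_model)

# 2. Let Orbax restore the saved pytree into that abstract structure.
restored_state = checkpointer.restore(ckpt_dir / 'state', abstract_state)

# 3. Merge the restored state back into a usable module.
model = nnx.merge(graphdef, restored_state)
```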

docs_nnx/guides/filters_guide.ipynb

Lines changed: 1 addition & 1 deletion
@@ -415,7 +415,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.13"
+"version": "3.11.9"
 }
 },
 "nbformat": 4,

docs_nnx/guides/flax_gspmd.ipynb

Lines changed: 1 addition & 1 deletion
@@ -652,7 +652,7 @@
 "%%timeit\n",
 "\n",
 "def block_all(xs):\n",
-"  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)\n",
+"  jax.tree.map(lambda x: x.block_until_ready(), xs)\n",
 "  return xs\n",
 "\n",
 "with mesh:\n",

docs_nnx/guides/flax_gspmd.md

Lines changed: 1 addition & 1 deletion
@@ -295,7 +295,7 @@ If you are using a Google TPU pod or a pod slice, you can create a custom `block
 %%timeit
 
 def block_all(xs):
-  jax.tree_util.tree_map(lambda x: x.block_until_ready(), xs)
+  jax.tree.map(lambda x: x.block_until_ready(), xs)
   return xs
 
 with mesh:
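
As a self-contained reference for the pattern this hunk touches (`jax.tree.map` is the current spelling of `jax.tree_util.tree_map`): blocking on every output array makes `%%timeit` measure the actual computation rather than JAX's async dispatch. The step function and its inputs below are placeholders:

```python
import jax

def block_all(xs):
  # Wait for every array in the pytree to finish computing.
  jax.tree.map(lambda x: x.block_until_ready(), xs)
  return xs

# Placeholder usage, mirroring the guide:
# with mesh:
#   out = block_all(sharded_step(model, batch))
```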

docs_nnx/guides/haiku_to_flax.rst

Lines changed: 2 additions & 2 deletions
@@ -199,9 +199,9 @@ The dropout behavior:
     return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
 
   grads = nnx.grad(loss_fn)(model)
-  _, params, rest = nnx.split(model, nnx.Param, ...)
+  _, params, rest = nnx.pure(nnx.split(model, nnx.Param, ...))
   params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
-  nnx.update(model, nnx.merge_state(params, rest))
+  nnx.update(model, params, rest)
 
 .. testcode:: Haiku
   :hide:
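
Taken together, the new lines give this shape of training step (a condensed sketch using a stand-in single-layer model and dummy data; the `nnx.pure` / `nnx.update(model, params, rest)` spelling is exactly what the updated guide shows):

```python
import jax
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))  # stand-in for the guide's model
x = jnp.ones((8, 4))
labels = jnp.zeros((8,), dtype=jnp.int32)

def loss_fn(model):
  logits = model(x)
  return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.pure(nnx.split(model, nnx.Param, ...))    # pure pytree of Params plus the rest
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)  # plain SGD step
nnx.update(model, params, rest)                                 # write the new values back
```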

docs_nnx/guides/jax_and_nnx_transforms.rst

Lines changed: 8 additions & 8 deletions
@@ -56,8 +56,8 @@ Notice that:
   def loss_fn(model):
     return ((model(x) - y) ** 2).mean()
   grads = nnx.grad(loss_fn)(model)
-  params = nnx.state(model, nnx.Param)
-  params = jax.tree_util.tree_map(
+  params = nnx.pure(nnx.state(model, nnx.Param))
+  params = jax.tree.map(
     lambda p, g: p - 0.1 * g, params, grads
   )
   nnx.update(model, params)
@@ -74,8 +74,8 @@ Notice that:
   grads = jax.grad(loss_fn, argnums=1)(graphdef, state) #!
 
   model = nnx.merge(graphdef, state) #!
-  params = nnx.state(model, nnx.Param)
-  params = jax.tree_util.tree_map(
+  params = nnx.pure(nnx.state(model, nnx.Param))
+  params = jax.tree.map(
     lambda p, g: p - 0.1 * g, params, grads
   )
   nnx.update(model, params)
@@ -102,8 +102,8 @@ in your code is pure and has valid argument types that are recognized by JAX.
     model = nnx.merge(graphdef, state)
     return ((model(x) - y) ** 2).mean()
   grads = jax.grad(loss_fn, 1)(*nnx.split(model)) #!
-  params = nnx.state(model, nnx.Param)
-  params = jax.tree_util.tree_map(
+  params = nnx.pure(nnx.state(model, nnx.Param))
+  params = jax.tree.map(
     lambda p, g: p - 0.1 * g, params, grads
   )
   nnx.update(model, params)
@@ -118,8 +118,8 @@ in your code is pure and has valid argument types that are recognized by JAX.
   def loss_fn(model):
     return ((model(x) - y) ** 2).mean()
   grads = nnx.grad(loss_fn)(model)
-  params = nnx.state(model, nnx.Param)
-  params = jax.tree_util.tree_map(
+  params = nnx.pure(nnx.state(model, nnx.Param))
+  params = jax.tree.map(
     lambda p, g: p - 0.1 * g, params, grads
   )
   nnx.update(model, params)
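
For orientation, the second variant in this guide (splitting the module and pushing the state through `jax.grad`) assembles into roughly the following; the model and data are stand-ins:

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))  # stand-in model
x, y = jnp.ones((1, 2)), jnp.ones((1, 3))

graphdef, state = nnx.split(model)

def loss_fn(graphdef, state):
  model = nnx.merge(graphdef, state)  # rebuild the module inside the pure function
  return ((model(x) - y) ** 2).mean()

grads = jax.grad(loss_fn, argnums=1)(graphdef, state)

model = nnx.merge(graphdef, state)
params = nnx.pure(nnx.state(model, nnx.Param))                  # pure pytree of the Params
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)  # manual SGD step
nnx.update(model, params)
```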

docs_nnx/guides/linen_to_nnx.rst

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ Dropout behavior:
     return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
 
   grads = nnx.grad(loss_fn)(model)
-  _, params, rest = nnx.split(model, nnx.Param, ...)
+  _, params, rest = nnx.pure(nnx.split(model, nnx.Param, ...))
   params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
   nnx.update(model, nnx.merge_state(params, rest))
 

docs_nnx/linen_intro.ipynb

Lines changed: 7 additions & 7 deletions
@@ -276,7 +276,7 @@
 "init_variables = model.init(key2, x)\n",
 "y = model.apply(init_variables, x)\n",
 "\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
 "print('output:\n', y)"
 ]
 },
@@ -337,7 +337,7 @@
 "init_variables = model.init(key2, x)\n",
 "y = model.apply(init_variables, x)\n",
 "\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
 "print('output:\n', y)"
 ]
 },
@@ -706,7 +706,7 @@
 "\n",
 "print('updated variables:\n', updated_variables)\n",
 "print('initialized variable shapes:\n',\n",
-"  jax.tree_util.tree_map(jnp.shape, init_variables))\n",
+"  jax.tree.map(jnp.shape, init_variables))\n",
 "print('output:\n', y)\n",
 "\n",
 "# Let's run these model variables during \"evaluation\":\n",
@@ -784,7 +784,7 @@
 "init_variables = model.init(key2, x)\n",
 "y = model.apply(init_variables, x)\n",
 "\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
 "print('output:\n', y)"
 ]
 },
@@ -850,7 +850,7 @@
 "init_variables = model.init(key2, x)\n",
 "y = model.apply(init_variables, x)\n",
 "\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
 "print('output:\n', y)"
 ]
 },
@@ -1001,7 +1001,7 @@
 "  batch_axes=(0,))\n",
 "\n",
 "init_variables = model(train=False).init({'params': key2}, x, x)\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
 "\n",
 "y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n",
 "print('output:\n', y.shape)"
@@ -1076,7 +1076,7 @@
 "model = SimpleScan(2)\n",
 "init_variables = model.init(key2, xs)\n",
 "\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
 "\n",
 "y = model.apply(init_variables, xs)\n",
 "print('output:\n', y)"

docs_nnx/linen_intro.md

Lines changed: 7 additions & 7 deletions
@@ -136,7 +136,7 @@ model = ExplicitMLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)
 
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```
 
@@ -168,7 +168,7 @@ model = SimpleMLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)
 
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```
 
@@ -338,7 +338,7 @@ updated_variables = flax.core.freeze(dict(params=init_params,
 
 print('updated variables:\n', updated_variables)
 print('initialized variable shapes:\n',
-  jax.tree_util.tree_map(jnp.shape, init_variables))
+  jax.tree.map(jnp.shape, init_variables))
 print('output:\n', y)
 
 # Let's run these model variables during "evaluation":
@@ -383,7 +383,7 @@ model = MLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)
 
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```
 
@@ -420,7 +420,7 @@ model = RematMLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)
 
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```
 
@@ -545,7 +545,7 @@ model = functools.partial(
   batch_axes=(0,))
 
 init_variables = model(train=False).init({'params': key2}, x, x)
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
 
 y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
 print('output:\n', y.shape)
@@ -590,7 +590,7 @@ xs = random.uniform(key1, (1, 5, 2))
 model = SimpleScan(2)
 init_variables = model.init(key2, xs)
 
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree.map(jnp.shape, flax.core.unfreeze(init_variables)))
 
 y = model.apply(init_variables, xs)
 print('output:\n', y)

docs_nnx/nnx_glossary.rst

Lines changed: 1 addition & 4 deletions
@@ -37,7 +37,4 @@ For additional terms, refer to the `JAX glossary <https://jax.readthedocs.io/en/
     A Flax NNX transformation (transform) is a wrapped version of a `JAX transformation <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ that allows the function that is being transformed to take the Flax NNX :term:`Module<Module>` (``nnx.Module``) as input or output. For example, a "lifted" version of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`__ is :meth:`nnx.jit <flax.nnx.jit>`. Check out the `Flax NNX transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ to learn more.
 
   Variable
-    The weights / parameters / data / array :class:`nnx.Variable <flax.nnx.Variable>` residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.
-
-  Variable state
-    :class:`nnx.VariableState <flax.nnx.VariableState>` is a purely functional `JAX pytree <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__ of all the :term:`Variables<Variable>` inside a :term:`Module<Module>`. Since it is pure, it can be an input or output of a `JAX transformation <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ function. ``nnx.VariableState`` is obtained by using :meth:`nnx.split <flax.nnx.split>` on the :class:`nnx.Module <flax.nnx.Module>`. (Refer to :term:`splitting<Split and merge>` and :term:`Module<Module>` to learn more.)
+    The weights / parameters / data / array :class:`nnx.Variable <flax.nnx.Variable>` residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.

examples/nnx_toy_examples/10_fsdp_and_optimizer.py

Lines changed: 24 additions & 15 deletions
@@ -71,49 +71,58 @@ def __call__(self, x: jax.Array):
     return nnx.relu(x @ self.w1 + self.b1) @ self.w2
 
 
-class SGDState(nnx.Variable):
+class SGDState(nnx.Variable[jax.Array]):
   pass
 
 
 class SGD(nnx.Object):
-  def __init__(self, params: nnx.State, lr, decay=0.9):
+  def __init__(self, params, lr, decay=0.9):
     def init_optimizer_state(variable: nnx.Variable):
       return SGDState(
         jnp.zeros_like(variable.value), **variable.get_metadata()
       )
 
     self.lr = lr
-    self.params = params
-    self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
+    self.momentum: nnx.State = jax.tree.map(
+      init_optimizer_state,
+      params,
+      is_leaf=lambda x: isinstance(x, nnx.Variable),
+    )
     self.decay = decay
 
-  def update(self, grads: nnx.State):
+  def update(self, params, grads):
     def update_fn(
-      params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
+      params: nnx.Variable[jax.Array], momentum: SGDState, grad: jax.Array
    ):
      # v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
-      momentum.value = self.decay * momentum + (1 - self.decay) * grad.value
+      momentum.value = self.decay * momentum[...] + (1 - self.decay) * grad[...]
      # θ_{t+1} = θ_t - α * v_t
-      params.value -= self.lr * momentum
-
-    jax.tree.map(update_fn, self.params, self.momentum, grads)
+      params.value -= self.lr * momentum[...]
+
+    jax.tree.map(
+      update_fn,
+      params,
+      self.momentum,
+      grads,
+      is_leaf=lambda x: isinstance(x, nnx.Variable),
+    )
 
 
 @nnx.jit
 def create_model():
   model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
   optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)
-  state = nnx.state(optimizer)
+  state = nnx.variables(optimizer)
   sharded_state = jax.lax.with_sharding_constraint(
     state, nnx.get_named_sharding(state, mesh)
   )
 
-  def get_named_shardings(path: tuple, value: nnx.VariableState):
+  def get_named_shardings(path: tuple, value: nnx.Variable):
     if path[0] == 'params':
-      return value.replace(NamedSharding(mesh, P(*value.sharding)))
+      return NamedSharding(mesh, P(*value.sharding))
    elif path[0] == 'momentum':
      # currently the same as above but in general it could be different
-      return value.replace(NamedSharding(mesh, P(*value.sharding)))
+      return NamedSharding(mesh, P(*value.sharding))
    else:
      raise ValueError(f'Unknown path: {path}')
 
@@ -137,7 +146,7 @@ def loss_fn(model):
     return loss
 
   loss, grad = nnx.value_and_grad(loss_fn)(model)
-  optimizer.update(grad)
+  optimizer.update(nnx.variables(model, nnx.Param), grad)
   return loss
 
 
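
End to end, the new calling convention (the optimizer no longer stores the params; they are passed to `update` each step) looks roughly like this, reusing the `SGD` class above with a stand-in model:

```python
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(1, 1, rngs=nnx.Rngs(0))  # stand-in for the example's MLP
optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)  # SGD as defined above

x, y = jnp.ones((4, 1)), jnp.zeros((4, 1))

def loss_fn(model):
  return ((model(x) - y) ** 2).mean()

loss, grad = nnx.value_and_grad(loss_fn)(model)
# Pass the current Param variables explicitly; SGD mutates them in place.
optimizer.update(nnx.variables(model, nnx.Param), grad)
```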

flax/nnx/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -43,6 +43,7 @@
 from .graph import split as split
 from .graph import update as update
 from .graph import clone as clone
+from .graph import pure as pure
 from .graph import pop as pop
 from .graph import state as state
 from .graph import graphdef as graphdef
@@ -161,7 +162,6 @@
 from .variablelib import Intermediate as Intermediate
 from .variablelib import Perturbation as Perturbation
 from .variablelib import Variable as Variable
-from .variablelib import VariableState as VariableState
 from .variablelib import VariableMetadata as VariableMetadata
 from .variablelib import with_metadata as with_metadata
 from .variablelib import variable_type_from_name as variable_type_from_name
