Skip to content

Commit 73a1b15

Browse files
committed
remove VariableState
1 parent 2f3b344 commit 73a1b15

31 files changed

+437
-641
lines changed

docs_nnx/api_reference/flax.nnx/variables.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ variables
1818
:members:
1919
.. autoclass:: VariableMetadata
2020
:members:
21-
.. autoclass:: VariableState
22-
:members:
2321

2422
.. autofunction:: with_metadata
2523

docs_nnx/guides/checkpointing.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@
131131
"\n",
132132
"## Restore checkpoints\n",
133133
"\n",
134-
"Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n",
134+
"Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.\n",
135135
"\n",
136136
"At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:\n",
137137
"- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.\n",

docs_nnx/guides/checkpointing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ checkpointer.save(ckpt_dir / 'state', state)
8282

8383
## Restore checkpoints
8484

85-
Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.VariableState`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.VariableState) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.
85+
Note that you saved the checkpoint as a Flax class of [`nnx.State`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/state.html#flax.nnx.State), which is also nested with [`nnx.Variable`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Variable) and [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param) classes.
8686

8787
At checkpoint restoration time, you need to have these classes ready in your runtime, and instruct the checkpointing library (Orbax) to restore your pytree back to that structure. This can be achieved as follows:
8888
- First, create an abstract Flax NNX model (without allocating any memory for arrays), and show its abstract variable state to the checkpointing library.

docs_nnx/guides/filters_guide.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@
415415
"name": "python",
416416
"nbconvert_exporter": "python",
417417
"pygments_lexer": "ipython3",
418-
"version": "3.10.13"
418+
"version": "3.11.9"
419419
}
420420
},
421421
"nbformat": 4,

docs_nnx/nnx_glossary.rst

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,4 @@ For additional terms, refer to the `JAX glossary <https://jax.readthedocs.io/en/
3737
A Flax NNX transformation (transform) is a wrapped version of a `JAX transformation <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ that allows the function that is being transformed to take the Flax NNX :term:`Module<Module>` (``nnx.Module``) as input or output. For example, a "lifted" version of `jax.jit <https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit>`__ is :meth:`nnx.jit <flax.nnx.jit>`. Check out the `Flax NNX transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__ to learn more.
3838

3939
Variable
40-
The weights / parameters / data / array :class:`nnx.Variable <flax.nnx.Variable>` residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.
41-
42-
Variable state
43-
:class:`nnx.VariableState <flax.nnx.VariableState>` is a purely functional `JAX pytree <https://jax.readthedocs.io/en/latest/working-with-pytrees.html>`__ of all the :term:`Variables<Variable>` inside a :term:`Module<Module>`. Since it is pure, it can be an input or output of a `JAX transformation <https://jax.readthedocs.io/en/latest/key-concepts.html#transformations>`__ function. ``nnx.VariableState`` is obtained by using :meth:`nnx.split <flax.nnx.split>` on the :class:`nnx.Module <flax.nnx.Module>`. (Refer to :term:`splitting<Split and merge>` and :term:`Module<Module>` to learn more.)
40+
The weights / parameters / data / array :class:`nnx.Variable <flax.nnx.Variable>` residing in a Flax :term:`Module<Module>`. Variables are defined inside modules as :class:`nnx.Variable <flax.nnx.Variable>` or its subclasses.

examples/nnx_toy_examples/10_fsdp_and_optimizer.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,49 +71,58 @@ def __call__(self, x: jax.Array):
7171
return nnx.relu(x @ self.w1 + self.b1) @ self.w2
7272

7373

74-
class SGDState(nnx.Variable):
74+
class SGDState(nnx.Variable[jax.Array]):
7575
pass
7676

7777

7878
class SGD(nnx.Object):
79-
def __init__(self, params: nnx.State, lr, decay=0.9):
79+
def __init__(self, params, lr, decay=0.9):
8080
def init_optimizer_state(variable: nnx.Variable):
8181
return SGDState(
8282
jnp.zeros_like(variable.value), **variable.get_metadata()
8383
)
8484

8585
self.lr = lr
86-
self.params = params
87-
self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
86+
self.momentum: nnx.State = jax.tree.map(
87+
init_optimizer_state,
88+
params,
89+
is_leaf=lambda x: isinstance(x, nnx.Variable),
90+
)
8891
self.decay = decay
8992

90-
def update(self, grads: nnx.State):
93+
def update(self, params, grads):
9194
def update_fn(
92-
params: nnx.Variable, momentum: SGDState, grad: nnx.VariableState
95+
params: nnx.Variable[jax.Array], momentum: SGDState, grad: jax.Array
9396
):
9497
# v_t = β * v_{t-1} + (1 - β) * ∇J(θ_t)
95-
momentum.value = self.decay * momentum + (1 - self.decay) * grad.value
98+
momentum.value = self.decay * momentum[...] + (1 - self.decay) * grad[...]
9699
# θ_{t+1} = θ_t - α * v_t
97-
params.value -= self.lr * momentum
98-
99-
jax.tree.map(update_fn, self.params, self.momentum, grads)
100+
params.value -= self.lr * momentum[...]
101+
102+
jax.tree.map(
103+
update_fn,
104+
params,
105+
self.momentum,
106+
grads,
107+
is_leaf=lambda x: isinstance(x, nnx.Variable),
108+
)
100109

101110

102111
@nnx.jit
103112
def create_model():
104113
model = MLP(1, 32, 1, rngs=nnx.Rngs(0))
105114
optimizer = SGD(nnx.variables(model, nnx.Param), 0.01, decay=0.9)
106-
state = nnx.state(optimizer)
115+
state = nnx.variables(optimizer)
107116
sharded_state = jax.lax.with_sharding_constraint(
108117
state, nnx.get_named_sharding(state, mesh)
109118
)
110119

111-
def get_named_shardings(path: tuple, value: nnx.VariableState):
120+
def get_named_shardings(path: tuple, value: nnx.Variable):
112121
if path[0] == 'params':
113-
return value.replace(NamedSharding(mesh, P(*value.sharding)))
122+
return NamedSharding(mesh, P(*value.sharding))
114123
elif path[0] == 'momentum':
115124
# currently the same as above but in general it could be different
116-
return value.replace(NamedSharding(mesh, P(*value.sharding)))
125+
return NamedSharding(mesh, P(*value.sharding))
117126
else:
118127
raise ValueError(f'Unknown path: {path}')
119128

@@ -137,7 +146,7 @@ def loss_fn(model):
137146
return loss
138147

139148
loss, grad = nnx.value_and_grad(loss_fn)(model)
140-
optimizer.update(grad)
149+
optimizer.update(nnx.variables(model, nnx.Param), grad)
141150
return loss
142151

143152

flax/nnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .graph import split as split
4444
from .graph import update as update
4545
from .graph import clone as clone
46+
from .graph import pure as pure
4647
from .graph import pop as pop
4748
from .graph import state as state
4849
from .graph import graphdef as graphdef
@@ -161,7 +162,6 @@
161162
from .variablelib import Intermediate as Intermediate
162163
from .variablelib import Perturbation as Perturbation
163164
from .variablelib import Variable as Variable
164-
from .variablelib import VariableState as VariableState
165165
from .variablelib import VariableMetadata as VariableMetadata
166166
from .variablelib import with_metadata as with_metadata
167167
from .variablelib import variable_type_from_name as variable_type_from_name

flax/nnx/bridge/module.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -374,28 +374,27 @@ def _get_variables(self) -> tp.Mapping:
374374
state = graph.state(self)
375375
_variables: dict = {}
376376

377-
variable_state: variablelib.VariableState
378-
for path, variable_state in statelib.to_flat_state(state):
379-
380-
if issubclass(variable_state.type, rnglib.RngState):
377+
variable: variablelib.Variable
378+
for path, variable in statelib.to_flat_state(state):
379+
if isinstance(variable, rnglib.RngState):
381380
# Don't return RNG states, since Linen doesn't have them.
382381
continue
383382

384383
try:
385-
collection = variablelib.variable_name_from_type(variable_state.type)
384+
collection = variablelib.variable_name_from_type(type(variable))
386385
except ValueError:
387-
collection = variable_state.type.__name__
386+
collection = type(variable).__name__
388387

389388
if collection not in _variables:
390389
_variables[collection] = {}
391390

392391
if (
393-
isinstance(variable_state, variablelib.VariableState)
394-
and not variable_state._var_metadata
392+
isinstance(variable, variablelib.Variable)
393+
and not variable._var_metadata
395394
):
396-
leaf = variable_state.value
395+
leaf = variable.value
397396
else:
398-
leaf = bridge_variables.to_linen_var(variable_state)
397+
leaf = bridge_variables.to_linen_var(variable)
399398

400399
_variables[collection][path] = leaf
401400

flax/nnx/bridge/variables.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _variable_parents_count(t: type):
4343

4444

4545
class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
46-
"""Default Flax metadata class for `nnx.VariableState`."""
46+
"""Default Flax metadata class for `nnx.Variable`."""
4747

4848
var_type: type[variablelib.Variable[tp.Any]] = struct.field(pytree_node=False)
4949
value: Any = struct.field(pytree_node=True)
@@ -65,14 +65,14 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
6565

6666
def get_partition_spec(self) -> jax.sharding.PartitionSpec:
6767
"""Returns the ``Partitionspec`` for this partitioned value."""
68-
nnx_var = self.to_nnx_variable().to_state()
69-
return spmd.get_partition_spec(nnx_var).value
68+
nnx_var = self.to_nnx_variable()
69+
return spmd.get_partition_spec(nnx_var)
7070

7171
def to_nnx_variable(self) -> variablelib.Variable:
7272
return self.var_type(self.value, **self.metadata)
7373

7474

75-
def is_vanilla_variable(vs: variablelib.VariableState) -> bool:
75+
def is_vanilla_variable(vs: variablelib.Variable) -> bool:
7676
"""A variables state is vanilla if its metadata is essentially blank.
7777
7878
Returns False only if it has non-empty hooks or any non-built-in attribute.
@@ -86,7 +86,7 @@ def is_vanilla_variable(vs: variablelib.VariableState) -> bool:
8686
return True
8787

8888

89-
def to_linen_var(vs: variablelib.VariableState) -> meta.AxisMetadata:
89+
def to_linen_var(vs: variablelib.Variable) -> meta.AxisMetadata:
9090
metadata = vs.get_metadata()
9191
if 'linen_meta_type' in metadata:
9292
linen_type = metadata['linen_meta_type']
@@ -95,7 +95,7 @@ def to_linen_var(vs: variablelib.VariableState) -> meta.AxisMetadata:
9595
return linen_type(vs.value, **metadata)
9696
if is_vanilla_variable(vs):
9797
return vs.value
98-
return NNXMeta(vs.type, vs.value, metadata)
98+
return NNXMeta(type(vs), vs.value, metadata)
9999

100100

101101
def get_col_name(keypath: tp.Sequence[Any]) -> str:
@@ -150,10 +150,7 @@ def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
150150
for kp, v in traversals.flatten_mapping(nnx_attrs).items():
151151
if isinstance(v, variablelib.Variable):
152152
col_name = variablelib.variable_name_from_type(type(v))
153-
v = to_linen_var(v.to_state())
154-
elif isinstance(v, variablelib.VariableState):
155-
col_name = variablelib.variable_name_from_type(v.type)
156-
v = to_linen_var(v)
153+
v = to_linen_var(v.copy())
157154
elif isinstance(v, graph.GraphDef):
158155
col_name = 'nnx' # an nnx.GraphDef for some ToLinen submodule
159156
else:

flax/nnx/bridge/wrappers.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ class ToLinen(linen.Module):
239239
args: tp.Sequence = ()
240240
kwargs: tp.Mapping[str, tp.Any] = FrozenDict({})
241241
skip_rng: bool = False
242-
metadata_fn: tp.Callable[[variablelib.VariableState], tp.Any] | None = (
243-
bv.to_linen_var
242+
metadata_fn: tp.Callable[[variablelib.Variable], tp.Any] | None = (
243+
bv.to_linen_var
244244
)
245245

246246
@linen.compact
@@ -297,7 +297,7 @@ def _update_variables(self, module):
297297

298298
# group state by collection
299299
for path, leaf in nnx.to_flat_state(state):
300-
type_ = leaf.type if isinstance(leaf, nnx.VariableState) else type(leaf)
300+
type_ = type(leaf)
301301
collection = variablelib.variable_name_from_type(
302302
type_, allow_register=True
303303
)
@@ -310,7 +310,7 @@ def _update_variables(self, module):
310310
if self.is_mutable_collection(collection):
311311

312312
def _to_linen_var(x):
313-
if isinstance(x, nnx.VariableState):
313+
if isinstance(x, nnx.Variable):
314314
if self.metadata_fn:
315315
return self.metadata_fn(x)
316316
else:
@@ -319,22 +319,22 @@ def _to_linen_var(x):
319319

320320
collection_state = nnx.traversals.unflatten_mapping(flat_state)
321321
collection_state = jax.tree.map(
322-
_to_linen_var,
323-
collection_state,
324-
is_leaf=lambda x: isinstance(x, nnx.VariableState),
322+
_to_linen_var,
323+
collection_state,
324+
is_leaf=lambda x: isinstance(x, nnx.Variable),
325325
)
326326
for k, v in collection_state.items():
327327
self.put_variable(collection, k, v)
328328

329329

330330
def to_linen(
331-
nnx_class: tp.Callable[..., Module],
332-
*args,
333-
metadata_fn: (
334-
tp.Callable[[variablelib.VariableState], tp.Any] | None
335-
) = bv.to_linen_var,
336-
name: str | None = None,
337-
**kwargs,
331+
nnx_class: tp.Callable[..., Module],
332+
*args,
333+
metadata_fn: (
334+
tp.Callable[[variablelib.Variable], tp.Any] | None
335+
) = bv.to_linen_var,
336+
name: str | None = None,
337+
**kwargs,
338338
):
339339
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
340340
return ToLinen(

flax/nnx/extract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def from_prefixes(
179179
def default_split_fn(
180180
ctx: graph.SplitContext, path: KeyPath, prefix: Prefix, leaf: Leaf
181181
) -> tp.Any:
182-
return NodeStates.from_split(*ctx.split(leaf))
182+
return NodeStates.from_split(*graph.pure(ctx.split(leaf)))
183183

184184

185185
def to_tree(

0 commit comments

Comments
 (0)