Commit a5eebe5

Merge pull request #4741 from google:mutable-array-p2

Author: Flax Authors
PiperOrigin-RevId: 759812978
2 parents: 65dfd46 + f50b4d9

31 files changed: +1485 −618 lines

docs_nnx/api_reference/flax.nnx/transforms.rst
Lines changed: 0 additions & 10 deletions

@@ -3,16 +3,6 @@ transforms
 
 .. automodule:: flax.nnx
 .. currentmodule:: flax.nnx
-
-.. autoclass:: Jit
-  :members:
-.. autoclass:: Remat
-  :members:
-.. autoclass:: Scan
-  :members:
-.. autoclass:: Vmap
-  :members:
-
 .. autofunction:: grad
 .. autofunction:: jit
 .. autofunction:: shard_map
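This hunk drops the class-based transform wrappers (Jit, Remat, Scan, Vmap) from the API reference, leaving only the function-style transforms (grad, jit, shard_map). For orientation, a minimal sketch of the function-style usage that remains documented; the model shape and inputs are illustrative, not from this commit:

import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

@nnx.jit  # function-style transform; the Jit class wrapper is no longer documented
def forward(model, x):
  return model(x)

y = forward(model, jnp.ones((1, 2)))  # shape (1, 3)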

docs_nnx/api_reference/flax.nnx/variables.rst
Lines changed: 0 additions & 2 deletions

@@ -8,8 +8,6 @@ variables
   :members:
 .. autoclass:: Cache
   :members:
-.. autoclass:: Empty
-  :members:
 .. autoclass:: Intermediate
   :members:
 .. autoclass:: Param

examples/nnx_toy_examples/01_functional_api.py
Lines changed: 4 additions & 7 deletions

@@ -20,8 +20,8 @@
 
 from flax import nnx
 
-X = np.linspace(0, 1, 100)[:, None]
-Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
+X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
+Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)
 
 
 def dataset(batch_size):
@@ -50,11 +50,8 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
     self.linear2 = Linear(dhidden, dout, rngs=rngs)
 
   def __call__(self, x):
-    self.count.value += 1
-    x = self.linear1(x)
-    x = jax.nn.relu(x)
-    x = self.linear2(x)
-    return x
+    self.count[...] += 1
+    return self.linear2(jax.nn.relu(self.linear1(x) * 0.5))
 
 
 graphdef, params, counts = nnx.split(
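The substantive change here is the switch from attribute-style Variable updates (self.count.value += 1) to NumPy-style index updates (self.count[...] += 1), the new mutable-array idiom this PR series introduces. A minimal standalone sketch of that idiom, assuming a Variable subclass like the example's Count and the [...] read/write support added by this work:

import jax.numpy as jnp
from flax import nnx

class Count(nnx.Variable):  # counter Variable type, as in the example
  pass

count = Count(jnp.array(0))
count[...] += 1   # in-place update via [...], replacing count.value += 1
print(count[...]) # reads the current array value: 1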

examples/nnx_toy_examples/mutable_array_basic.py
Lines changed: 5 additions & 40 deletions

@@ -23,42 +23,7 @@
 import numpy as np
 
 from flax import nnx
-from flax.nnx.variablelib import is_mutable_array
-
-
-def mutable_like(path, x):
-  return (
-    isinstance(x, nnx.Variable) and x.mutable
-  ) or nnx.variablelib.is_mutable_array(x)
-
-
-def freeze(x, only: nnx.filterlib.Filter = mutable_like):
-  freeze_filter = nnx.filterlib.to_predicate(only)
-  mutable_arrays: set[int] = set()
-
-  def check_mutable_array(path, x):
-    m_array_id = id(x)
-    if m_array_id in mutable_arrays:
-      path_str = jax.tree_util.keystr(path)
-      raise ValueError(
-        f'Found duplicate MutableArray found at path {path_str}: {x}'
-      )
-    mutable_arrays.add(m_array_id)
-
-  def _freeze_fn(jax_path, x):
-    path = tuple(nnx.graph._key_path_to_key(part) for part in jax_path)
-    if freeze_filter(path, x):
-      if isinstance(x, nnx.Variable):
-        check_mutable_array(jax_path, x.raw_value)
-        return x.from_metadata(x[...], x.get_metadata().copy())
-      elif nnx.variablelib.is_mutable_array(x):
-        check_mutable_array(jax_path, x)
-        return x[...]
-    return x
-
-  return jax.tree.map_with_path(
-    _freeze_fn, x, is_leaf=lambda x: isinstance(x, nnx.Variable)
-  )
+
 
 X = np.linspace(-jnp.pi, jnp.pi, 100)[:, None]
 Y = 0.8 * jnp.sin(X) + 0.1 + np.random.normal(0, 0.1, size=X.shape)
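The deleted helper snapshotted every mutable leaf (a mutable nnx.Variable or a MutableArray) into an immutable value, raising on aliased MutableArrays; this commit replaces it with the library-provided nnx.freeze, used in the next hunk. The core idea, as a much-reduced sketch that ignores the filtering, metadata handling, and duplicate detection the deleted code performed:

import jax

def freeze_sketch(tree):
  # Reading a mutable leaf with [...] yields an immutable jax.Array snapshot;
  # mapping that over the tree "freezes" all state before it crosses a
  # functional boundary such as jax.grad. (Assumes every leaf supports [...].)
  return jax.tree.map(lambda leaf: leaf[...], tree)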
@@ -94,21 +59,21 @@ def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs):
 
   def __call__(self, x):
     self.count[...] += 1
-    return self.linear2(jax.nn.gelu(self.linear1(x)) * 0.5)
+    return self.linear2(jax.nn.relu(self.linear1(x)) * 0.5)
 
 
-model = MLP(din=1, dhidden=64, dout=1, rngs=nnx.Rngs(0))
+model = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0))
 
 
 @jax.jit
 def train_step(model, x, y):
-  graphdef, params, counts = nnx.split(model, nnx.Param, Count)
+  graphdef, params, counts = nnx.pure(nnx.split(model, nnx.Param, Count))
 
   def loss_fn(params):
     model = nnx.merge(graphdef, params, counts)
     return jnp.mean((y - model(x)) ** 2)
 
-  grads = jax.grad(loss_fn)(freeze(params))
+  grads = jax.grad(loss_fn)(nnx.freeze(params))
 
   def sgd(w, g):
     w[...] -= 0.1 * g[...]
w[...] -= 0.1 * g[...]
