
Commit d3b8a8e

add HijaxVariable
1 parent 8103bff commit d3b8a8e

File tree

24 files changed: +1180 -1020 lines changed


docs_nnx/api_reference/flax.nnx/graph.rst

Lines changed: 3 additions & 2 deletions

@@ -31,7 +31,8 @@ graph
 .. autofunction:: find_duplicates
 .. autofunction:: pure
 .. autofunction:: as_ref
-.. autofunction:: as_hijax
-.. autofunction:: as_lojax
+.. autofunction:: as_mutable
+.. autofunction:: as_immutable
+.. autofunction:: as_pytree
 .. autofunction:: flatten
 .. autofunction:: unflatten
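For orientation, the new entries replace the earlier as_hijax / as_lojax pair. A minimal sketch of how the renamed helpers might be used, assuming as_immutable behaves as in the hijax guide added below and that as_mutable is its inverse (that round-trip, and as_pytree's behavior, are assumptions not confirmed by this commit):

  from flax import nnx

  model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
  graphdef, params = nnx.split(model, nnx.Param)

  # freeze the params state before handing it to a JAX transform,
  # as done in the hijax guide added below
  frozen = nnx.as_immutable(params)

  # assumed inverse of as_immutable (these names replace as_hijax / as_lojax)
  params = nnx.as_mutable(frozen)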

docs_nnx/guides/hijax.ipynb

Lines changed: 0 additions & 573 deletions
This file was deleted.

docs_nnx/hijax/index.rst

Lines changed: 58 additions & 0 deletions (new file)

@@ -0,0 +1,58 @@

Hijax (experimental)
====================


----

Basic usage
^^^^^^^^^^^^

.. testsetup::

  import jax
  import jax.numpy as jnp
  from flax import nnx

  current_mode = nnx.current_variable_mode()

.. testcode::

  from flax import nnx
  import optax

  nnx.variable_mode('mutable')  # | 'ref'

  class Model(nnx.Module):
    def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
      self.linear = nnx.Linear(din, dmid, rngs=rngs)
      self.bn = nnx.BatchNorm(dmid, rngs=rngs)
      self.dropout = nnx.Dropout(0.2)
      self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x, rngs):
      x = nnx.relu(self.dropout(self.bn(self.linear(x)), rngs=rngs))
      return self.linear_out(x)

  model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

  @jax.jit
  def train_step(model, optimizer, rngs, x, y):
    graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
    def loss_fn(params):
      model = nnx.merge(graphdef, params, nondiff)
      return ((model(x, rngs) - y) ** 2).mean()
    loss, grads = jax.value_and_grad(loss_fn)(nnx.as_immutable(params))
    optimizer.update(model, grads)  # in-place updates
    return loss

  nnx.variable_mode(current_mode)  # clean up for CI tests


----

.. toctree::
  :hidden:
  :maxdepth: 2

  variable