Hijax (experimental)
====================



----

Basic usage
^^^^^^^^^^^^

.. testsetup::

  import jax
  import jax.numpy as jnp
  from flax import nnx

  current_mode = nnx.current_variable_mode()

.. testcode::

  from flax import nnx
  import optax

  nnx.variable_mode('mutable')  # or 'ref'

  class Model(nnx.Module):
    def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
      self.linear = nnx.Linear(din, dmid, rngs=rngs)
      self.bn = nnx.BatchNorm(dmid, rngs=rngs)
      self.dropout = nnx.Dropout(0.2)
      self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)

    def __call__(self, x, rngs):
      x = nnx.relu(self.dropout(self.bn(self.linear(x)), rngs=rngs))
      return self.linear_out(x)

  model = Model(2, 64, 3, rngs=nnx.Rngs(0))  # eager initialization
  optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)

  @jax.jit
  def train_step(model, optimizer, rngs, x, y):
    # split the differentiable Params from the rest of the state
    graphdef, params, nondiff = nnx.split(model, nnx.Param, ...)
    def loss_fn(params):
      # rebuild the model inside the traced function
      model = nnx.merge(graphdef, params, nondiff)
      return ((model(x, rngs) - y) ** 2).mean()
    loss, grads = jax.value_and_grad(loss_fn)(nnx.as_immutable(params))
    optimizer.update(model, grads)  # in-place updates
    return loss

  nnx.variable_mode(current_mode)  # clean up for CI tests

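As a rough usage sketch, the jitted ``train_step`` above could be driven as follows; the batch shapes, step count, and RNG seed are illustrative assumptions, not part of the API:

.. code-block:: python

  # Illustrative only: dummy inputs matching Model(din=2, dmid=64, dout=3) above.
  x = jnp.ones((32, 2))
  y = jnp.ones((32, 3))
  rngs = nnx.Rngs(dropout=1)  # stream consumed by nnx.Dropout in Model.__call__

  for step in range(10):
    loss = train_step(model, optimizer, rngs, x, y)

Because ``model`` and ``optimizer`` are updated in place, only the loss needs to be returned from ``train_step``.
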
----

.. toctree::
  :hidden:
  :maxdepth: 2

  variable