diff --git a/docs_nnx/guides/performance.md b/docs_nnx/guides/performance.md index ffe3deecb..cc16b4622 100644 --- a/docs_nnx/guides/performance.md +++ b/docs_nnx/guides/performance.md @@ -119,7 +119,7 @@ def jax_train_step(graphdef, state, x, y): for _ in range(10): x, y = jnp.ones((32, 2)), jnp.zeros((32, 3)) - state, loss = jax_train_step(graphdef, state, x, y) + loss, state = jax_train_step(graphdef, state, x, y) # update objects after training nnx.update((model, optimizer, metrics), state)