diff --git a/docs_nnx/guides/transforms.ipynb b/docs_nnx/guides/transforms.ipynb
index d6c758ad3..d4abb64a2 100644
--- a/docs_nnx/guides/transforms.ipynb
+++ b/docs_nnx/guides/transforms.ipynb
@@ -406,11 +406,9 @@
 "| Lift type | JAX transforms |\n",
 "|------------------|-----------------------------------------|\n",
 "| `StateAxes` | `vmap`, `pmap`, `scan` |\n",
- "| `StateSharding` | `jit`, `shard_map`* |\n",
+ "| `StateSharding` | `jit`, `shard_map` |\n",
 "| `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |\n",
 "\n",
- "> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.\n",
- "\n",
 "To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.\n",
 "\n",
 "Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n",
diff --git a/docs_nnx/guides/transforms.md b/docs_nnx/guides/transforms.md
index 95961940e..e863ea7cc 100644
--- a/docs_nnx/guides/transforms.md
+++ b/docs_nnx/guides/transforms.md
@@ -200,11 +200,9 @@ Certain JAX transforms allow the use of pytree prefixes to specify how different
 | Lift type | JAX transforms |
 |------------------|-----------------------------------------|
 | `StateAxes` | `vmap`, `pmap`, `scan` |
-| `StateSharding` | `jit`, `shard_map`* |
+| `StateSharding` | `jit`, `shard_map` |
 | `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |
 
-> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.
-
 To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.
 
 Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.