Skip to content

mark shard_map as implemented in transforms guide #4738

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions docs_nnx/guides/transforms.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,9 @@
"| Lift type | JAX transforms |\n",
"|------------------|-----------------------------------------|\n",
"| `StateAxes` | `vmap`, `pmap`, `scan` |\n",
"| `StateSharding` | `jit`, `shard_map`* |\n",
"| `StateSharding` | `jit`, `shard_map` |\n",
"| `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |\n",
"\n",
"> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.\n",
"\n",
"To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.\n",
"\n",
"Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.\n",
Expand Down
4 changes: 1 addition & 3 deletions docs_nnx/guides/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,9 @@ Certain JAX transforms allow the use of pytree prefixes to specify how different
| Lift type | JAX transforms |
|------------------|-----------------------------------------|
| `StateAxes` | `vmap`, `pmap`, `scan` |
| `StateSharding` | `jit`, `shard_map`* |
| `StateSharding` | `jit`, `shard_map` |
| `DiffState` | `grad`, `value_and_grad`, `custom_vjp` |

> **Note:** * Flax NNX `shard_map` has not been implemented yet at the time of writing this version of the document.

To specify how to vectorize different sub-states of an object in `nnx.vmap`, the Flax team created a `nnx.StateAxes`. `StateAxes` maps a set of sub-states via Flax NNX [Filters](https://flax.readthedocs.io/en/latest/guides/filters_guide.html) to their corresponding axes, and you can pass the `nnx.StateAxes` to `in_axes` and `out_axes` as if it/they were a pytree prefix.

Let's use the previous `stateful_vector_dot` example and vectorize only the `nnx.Param` variables and broadcast the `count` variable so we only keep a single count for all the batch elements.
Expand Down
Loading