1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -10,6 +10,7 @@ sphinx-copybutton
matplotlib
penzai
scikit-learn
flax==0.10.6

# install jax-ai-stack from current directory
.
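The notebooks below all share the same construct-and-update pattern around `nnx.Optimizer`. As a minimal sketch, assuming the flax 0.10.6 API pinned above (`nnx.Optimizer(model, tx)` and `optimizer.update(grads)`); `TinyModel` and every hyperparameter here are illustrative, not taken from the PR:

```python
# Minimal sketch of the construct-and-update pattern shared by these notebooks,
# assuming the flax 0.10.6 API: nnx.Optimizer(model, tx) and optimizer.update(grads).
# TinyModel and all hyperparameters are illustrative, not taken from the PR.
import jax.numpy as jnp
import optax
from flax import nnx

class TinyModel(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear = nnx.Linear(4, 2, rngs=rngs)

    def __call__(self, x):
        return self.linear(x)

model = TinyModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3))

def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x)
        return jnp.mean((pred - y) ** 2)  # simple squared error, just for the sketch

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads)  # flax 0.10.x signature assumed
    return loss

loss = train_step(model, optimizer, jnp.ones((8, 4)), jnp.zeros((8, 2)))
```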
3 changes: 2 additions & 1 deletion docs/source/JAX_Vision_transformer.ipynb
@@ -879,7 +879,7 @@
"plt.show()\n",
"\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))"
"optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))"
]
},
{
@@ -1268,6 +1268,7 @@
],
"metadata": {
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_Vision_transformer.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -619,7 +619,7 @@ plt.xlim((0, num_epochs))
plt.show()


optimizer = nnx.ModelAndOptimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))
optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum, nesterov=True))
```

Define a loss function with `optax.softmax_cross_entropy_with_integer_labels`:
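The context line above stops where the Vision Transformer guide defines its loss. As a hedged illustration (the function name, argument names, and shapes below are assumptions, not the notebook's code), a loss built on `optax.softmax_cross_entropy_with_integer_labels` typically looks like this:

```python
# Hedged sketch of a loss-and-logits helper built on
# optax.softmax_cross_entropy_with_integer_labels; names and shapes are illustrative.
import jax.numpy as jnp
import optax
from flax import nnx

def compute_losses_and_logits(model: nnx.Module, images: jnp.ndarray, labels: jnp.ndarray):
    logits = model(images)  # (batch, num_classes)
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels  # labels: integer class ids, shape (batch,)
    ).mean()
    return loss, logits  # returning logits as aux lets metrics reuse the forward pass
```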
5 changes: 3 additions & 2 deletions docs/source/JAX_basic_text_classification.ipynb
@@ -516,7 +516,7 @@
},
{
"cell_type": "code",
"execution_count": 31,
"execution_count": null,
"id": "6d9f4756-4e64-49d1-81dd-20c0e0480dd0",
"metadata": {},
"outputs": [],
@@ -528,7 +528,7 @@
"learning_rate = 0.0005\n",
"momentum = 0.9\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))"
"optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))"
]
},
{
@@ -1025,6 +1025,7 @@
],
"metadata": {
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_basic_text_classification.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -303,7 +303,7 @@ num_epochs = 10
learning_rate = 0.0005
momentum = 0.9

optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))
optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))
```

```{code-cell} ipython3
3 changes: 2 additions & 1 deletion docs/source/JAX_examples_image_segmentation.ipynb
@@ -1587,7 +1587,7 @@
"plt.show()\n",
"\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adam(lr_schedule, momentum))"
"optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum))"
]
},
{
@@ -2542,6 +2542,7 @@
],
"metadata": {
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_examples_image_segmentation.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -1067,7 +1067,7 @@ plt.xlim((0, num_epochs))
plt.show()


optimizer = nnx.ModelAndOptimizer(model, optax.adam(lr_schedule, momentum))
optimizer = nnx.Optimizer(model, optax.adam(lr_schedule, momentum))
```

Let us implement Jaccard loss and the loss function combining Cross-Entropy and Jaccard losses.
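The segmentation guide's next step, cut off by the hunk above, combines Cross-Entropy with a Jaccard loss. As a hedged sketch of one common soft-Jaccard formulation for binary masks (not necessarily the notebook's exact code):

```python
# Hedged sketch of a soft Jaccard (IoU) loss and a combined BCE + Jaccard objective
# for binary segmentation masks; the notebook's own formulation may differ.
import jax
import jax.numpy as jnp
import optax

def soft_jaccard_loss(logits, masks, eps=1e-6):
    # logits, masks: (batch, H, W, 1); masks hold 0/1 values
    probs = jax.nn.sigmoid(logits)
    intersection = jnp.sum(probs * masks, axis=(1, 2, 3))
    union = jnp.sum(probs, axis=(1, 2, 3)) + jnp.sum(masks, axis=(1, 2, 3)) - intersection
    return 1.0 - jnp.mean((intersection + eps) / (union + eps))

def combined_loss(logits, masks):
    bce = optax.sigmoid_binary_cross_entropy(logits, masks).mean()
    return bce + soft_jaccard_loss(logits, masks)
```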
2 changes: 1 addition & 1 deletion docs/source/JAX_for_LLM_pretraining.ipynb
@@ -978,7 +978,7 @@
],
"source": [
"model = create_model(rngs=nnx.Rngs(0))\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))\n",
"optimizer = nnx.Optimizer(model, optax.adam(1e-3))\n",
"metrics = nnx.MultiMetric(\n",
" loss=nnx.metrics.Average('loss'),\n",
")\n",
4 changes: 2 additions & 2 deletions docs/source/JAX_for_LLM_pretraining.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3
name: python3
@@ -476,7 +476,7 @@ id: Ysl6CsfENeJN
outputId: 5dd06dca-f030-4927-a9b6-35d412da535c
---
model = create_model(rngs=nnx.Rngs(0))
optimizer = nnx.ModelAndOptimizer(model, optax.adam(1e-3))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
metrics = nnx.MultiMetric(
loss=nnx.metrics.Average('loss'),
)
3 changes: 2 additions & 1 deletion docs/source/JAX_for_PyTorch_users.ipynb
@@ -1272,7 +1272,7 @@
"learning_rate = 0.005\n",
"momentum = 0.9\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate, momentum))"
"optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))"
]
},
{
@@ -1565,6 +1565,7 @@
"provenance": []
},
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_for_PyTorch_users.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -703,7 +703,7 @@ import optax
learning_rate = 0.005
momentum = 0.9

optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate, momentum))
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
```

```{code-cell} ipython3
3 changes: 2 additions & 1 deletion docs/source/JAX_image_captioning.ipynb
@@ -1304,7 +1304,7 @@
"momentum = 0.9\n",
"total_steps = len(train_dataset) // train_batch_size\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(\n",
"optimizer = nnx.Optimizer(\n",
" model, optax.sgd(learning_rate, momentum, nesterov=True), wrt=trainable_params_filter\n",
")"
]
@@ -2316,6 +2316,7 @@
],
"metadata": {
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_image_captioning.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -925,7 +925,7 @@ learning_rate = 0.015
momentum = 0.9
total_steps = len(train_dataset) // train_batch_size

optimizer = nnx.ModelAndOptimizer(
optimizer = nnx.Optimizer(
model, optax.sgd(learning_rate, momentum, nesterov=True), wrt=trainable_params_filter
)
```
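The image-captioning diff passes a `wrt=trainable_params_filter` argument so that only part of the model is optimized. A hedged sketch of that pattern follows: only the `wrt=` mechanism mirrors the diff, while the `CaptionHead` class, the filter, and the hyperparameters are made up for illustration.

```python
# Hedged sketch of partial-parameter training via the wrt= filter.
# Only the wrt= mechanism mirrors the diff; the model and filter are illustrative.
import optax
from flax import nnx

class CaptionHead(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.backbone = nnx.Linear(16, 16, rngs=rngs)  # stand-in for a frozen encoder
        self.decoder = nnx.Linear(16, 8, rngs=rngs)    # the part we want to train

    def __call__(self, x):
        return self.decoder(self.backbone(x))

model = CaptionHead(rngs=nnx.Rngs(0))

# Select only Params whose path contains 'decoder'; everything else stays frozen.
trainable_params_filter = nnx.All(nnx.Param, nnx.PathContains('decoder'))
optimizer = nnx.Optimizer(
    model, optax.sgd(0.015, 0.9, nesterov=True), wrt=trainable_params_filter
)
# Gradients must then be taken with the same filter, e.g. via
# nnx.grad(loss_fn, argnums=nnx.DiffState(0, trainable_params_filter)).
```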
3 changes: 2 additions & 1 deletion docs/source/JAX_machine_translation.ipynb
@@ -647,7 +647,7 @@
"outputs": [],
"source": [
"model = TransformerModel(sequence_length, vocab_size, embed_dim, latent_dim, num_heads, dropout_rate, rngs=rng)\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate))"
"optimizer = nnx.Optimizer(model, optax.adamw(learning_rate))"
]
},
{
@@ -1039,6 +1039,7 @@
],
"metadata": {
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_machine_translation.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -452,7 +452,7 @@ def evaluate_model(epoch):

```{code-cell} ipython3
model = TransformerModel(sequence_length, vocab_size, embed_dim, latent_dim, num_heads, dropout_rate, rngs=rng)
optimizer = nnx.ModelAndOptimizer(model, optax.adamw(learning_rate))
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate))
```

## Start the Training!
3 changes: 2 additions & 1 deletion docs/source/JAX_time_series_classification.ipynb
@@ -506,7 +506,7 @@
"learning_rate = 0.0005\n",
"momentum = 0.9\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))"
"optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))"
]
},
{
@@ -1516,6 +1516,7 @@
],
"metadata": {
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_time_series_classification.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: jax-env
language: python
@@ -250,7 +250,7 @@ num_epochs = 300
learning_rate = 0.0005
momentum = 0.9

optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))
optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))
```

We'll define a loss and logits computation function using Optax's
3 changes: 2 additions & 1 deletion docs/source/JAX_transformer_text_classification.ipynb
@@ -758,7 +758,7 @@
"learning_rate = 0.0001 # The learning rate.\n",
"momentum = 0.9 # Momentum for Adam.\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))"
"optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))"
]
},
{
@@ -1321,6 +1321,7 @@
],
"metadata": {
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_transformer_text_classification.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: jax-env
language: python
@@ -377,7 +377,7 @@ num_epochs = 10 # Number of epochs during training.
learning_rate = 0.0001 # The learning rate.
momentum = 0.9 # Momentum for Adam.

optimizer = nnx.ModelAndOptimizer(model, optax.adam(learning_rate, momentum))
optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))
```

Next, we define the loss function - `compute_losses_and_logits()` - using `optax.softmax_cross_entropy_with_integer_labels`:
3 changes: 2 additions & 1 deletion docs/source/JAX_visualizing_models_metrics.ipynb
@@ -248,7 +248,7 @@
"import jax\n",
"import optax\n",
"\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.sgd(learning_rate=0.05))\n",
"optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))\n",
"\n",
"def loss_fun(\n",
" model: nnx.Module,\n",
@@ -449,6 +449,7 @@
"provenance": []
},
"jupytext": {
"default_lexer": "ipython3",
"formats": "ipynb,md:myst"
},
"kernelspec": {
4 changes: 2 additions & 2 deletions docs/source/JAX_visualizing_models_metrics.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -145,7 +145,7 @@ In order to track loss across our training run, we've collected the loss functio
import jax
import optax

optimizer = nnx.ModelAndOptimizer(model, optax.sgd(learning_rate=0.05))
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.05))

def loss_fun(
model: nnx.Module,
2 changes: 1 addition & 1 deletion docs/source/digits_diffusion_model.ipynb
@@ -758,7 +758,7 @@
")\n",
"\n",
"# Optimizer configuration (AdamW) with gradient clipping.\n",
"optimizer = nnx.ModelAndOptimizer(model, optax.chain(\n",
"optimizer = nnx.Optimizer(model, optax.chain(\n",
" optax.clip_by_global_norm(0.5), # Gradient clipping.\n",
" optax.adamw(\n",
" learning_rate=schedule_fn,\n",
4 changes: 2 additions & 2 deletions docs/source/digits_diffusion_model.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.17.3
kernelspec:
display_name: Python 3
name: python3
@@ -641,7 +641,7 @@ schedule_fn = optax.join_schedules(
)

# Optimizer configuration (AdamW) with gradient clipping.
optimizer = nnx.ModelAndOptimizer(model, optax.chain(
optimizer = nnx.Optimizer(model, optax.chain(
optax.clip_by_global_norm(0.5), # Gradient clipping.
optax.adamw(
learning_rate=schedule_fn,
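For reference, the clipped-AdamW chain from the diffusion example can be reproduced standalone. This is a hedged sketch: the warmup/decay step counts, learning rates, and weight decay are made-up values, and the two schedule constructors are assumptions beyond the `optax.join_schedules`, `optax.chain`, `optax.clip_by_global_norm`, and `optax.adamw` calls that appear in the diff.

```python
# Hedged sketch of a warmup-then-decay schedule feeding a clipped AdamW chain.
# Step counts, learning rates, and weight decay are illustrative values only.
import optax

warmup_steps, total_steps = 100, 1_000
schedule_fn = optax.join_schedules(
    schedules=[
        optax.linear_schedule(0.0, 1e-3, warmup_steps),                 # warmup
        optax.cosine_decay_schedule(1e-3, total_steps - warmup_steps),  # decay
    ],
    boundaries=[warmup_steps],
)

tx = optax.chain(
    optax.clip_by_global_norm(0.5),  # clip gradients before the AdamW update
    optax.adamw(learning_rate=schedule_fn, weight_decay=1e-4),
)
# Given a constructed `model`, the notebook then builds: nnx.Optimizer(model, tx)
```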