diff --git a/docs/learning_jax/README.md b/docs/learning_jax/README.md new file mode 100644 index 0000000..547cd59 --- /dev/null +++ b/docs/learning_jax/README.md @@ -0,0 +1,14 @@ +# Learning-JAX +Slide decks, coding exercises, and quick references for learning the JAX AI Stack. The coding exercises are designed to be runnable in a free Colab instance. + +For more comprehensive documentation please see the individual websites: + +* https://jaxstack.ai +* https://jax.dev +* https://flax.readthedocs.io +* https://orbax.readthedocs.io +* https://optax.readthedocs.io +* https://google-grain.readthedocs.io +* https://chex.readthedocs.io + +[Join our growing community on Discord](https://goo.gle/jax-community) and connect with other developers! diff --git a/docs/learning_jax/code-exercises/1 - JAX AI Stack.ipynb b/docs/learning_jax/code-exercises/1 - JAX AI Stack.ipynb new file mode 100644 index 0000000..6e1e8f6 --- /dev/null +++ b/docs/learning_jax/code-exercises/1 - JAX AI Stack.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"183UawZ8L3Tbm1ueDynDqO_TyGysgZ8rt","timestamp":1755114181793}],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Introduction\n","\n","**Welcome to the JAX AI Stack Exercises!**\n","\n","This notebook is designed to accompany the \"Leveraging the JAX AI Stack\" lecture. You'll get hands-on experience with core JAX concepts, Flax NNX for model building, Optax for optimization, and Orbax for checkpointing.\n","\n","The exercises will guide you through implementing key components, drawing parallels to PyTorch where appropriate, to solidify your understanding.\n","\n","Let's get started!"],"metadata":{"id":"AEYnLrsY27El"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"OPA5MMD621LQ"},"outputs":[],"source":["# @title Setup: Install and Import Libraries\n","# Install necessary libraries\n","!pip install -Uq flax optax orbax-checkpoint chex\n","\n","import jax\n","import jax.numpy as jnp\n","import flax\n","from flax import nnx\n","import optax\n","import orbax.checkpoint as ocp # For Orbax\n","from typing import Any, Dict, Tuple # For type hints\n","\n"," # Helper to print PyTrees more nicely for demonstration\n"," import pprint\n","import os # For Orbax directory management\n","import shutil # For cleaning up Orbax directory\n","\n","print(f\"JAX version: {jax.__version__}\")\n","print(f\"Flax version: {flax.__version__}\")\n","print(f\"Optax version: {optax.__version__}\")\n","print(f\"Orbax version: {ocp.__version__}\")\n","\n","# Global JAX PRNG key for reproducibility in exercises\n","# Students can learn to split this key for different operations.\n","main_key = jax.random.key(0)"]},{"cell_type":"markdown","source":["## Exercise 1: JAX Core & NumPy API\n","\n","**Goal**: Get familiar with jax.numpy and JAX's functional programming style.\n","\n","### Instructions:\n","\n","1. Create two JAX arrays, a (a 2x2 matrix of random numbers) and b (a 2x2 matrix of ones) using jax.numpy (jnp). You'll need a jax.random.key for creating random numbers.\n","2. Perform element-wise addition of a and b.\n","3. Perform matrix multiplication of a and b.\n","4. Demonstrate JAX's immutability:\n"," - Store the Python id() of array a.\n"," - Perform an operation like a = a + 1.\n"," - Print the new id() of a and observe that it has changed, indicating a new array was created."],"metadata":{"id":"3gC7luR35tJd"}},{"cell_type":"code","source":["# Instructions for Exercise 1\n","key_ex1, main_key = jax.random.split(main_key) # Split the main key\n","\n","# 1. Create JAX arrays a and b\n","# TODO: Create array 'a' (2x2 random normal) and 'b' (2x2 ones)\n","a = None # Placeholder\n","b = None # Placeholder\n","\n","print(\"Array a:\\n\", a)\n","print(\"Array b:\\n\", b)\n","\n","# 2. Perform element-wise addition\n","# TODO: Add a and b\n","c = None # Placeholder\n","print(\"Element-wise sum c = a + b:\\n\", c)\n","\n","# 3. Perform matrix multiplication\n","# TODO: Matrix multiply a and b\n","d = None # Placeholder\n","print(\"Matrix product d = a @ b:\\n\", d)\n","\n","# 4. Demonstrate immutability\n","# original_a_id = id(a)\n","# print(f\"Original id(a): {original_a_id}\")\n","\n","# TODO: Perform an operation that reassigns 'a', e.g., a = a + 1\n","# a_new_ref = None # Placeholder\n","# new_a_id = id(a_new_ref)\n","# print(f\"New id(a) after 'a = a + 1': {new_a_id}\")\n","\n","# TODO: Check if original_a_id is different from new_a_id\n","# print(f\"IDs are different: {None}\") # Placeholder"],"metadata":{"id":"8Tq_WFzc5Ycl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 1: JAX Core & NumPy API\n","key_ex1_sol, main_key = jax.random.split(main_key)\n","\n","# 1. Create JAX arrays a and b\n","a_sol = jax.random.normal(key_ex1_sol, (2, 2))\n","b_sol = jnp.ones((2, 2))\n","\n","print(\"Array a:\\n\", a_sol)\n","print(\"Array b:\\n\", b_sol)\n","\n","# 2. Perform element-wise addition\n","c_sol = a_sol + b_sol\n","print(\"Element-wise sum c = a + b:\\n\", c_sol)\n","\n","# 3. Perform matrix multiplication\n","d_sol = jnp.dot(a_sol, b_sol) # or d = a @ b\n","print(\"Matrix product d = a @ b:\\n\", d_sol)\n","\n","# 4. Demonstrate immutability\n","original_a_id_sol = id(a_sol)\n","print(f\"Original id(a_sol): {original_a_id_sol}\")\n","\n","a_sol_new_ref = a_sol + 1 # This creates a new array and rebinds the Python variable.\n","new_a_id_sol = id(a_sol_new_ref)\n","print(f\"New id(a_sol_new_ref) after 'a_sol = a_sol + 1': {new_a_id_sol}\")\n","print(f\"IDs are different: {original_a_id_sol != new_a_id_sol}\")\n","print(\"This shows that the original array was not modified in-place; a new array was created.\")"],"metadata":{"id":"0p2HrUzH6NYQ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 2: jax.jit (Just-In-Time Compilation)\n","\n","**Goal**: Understand how to use jax.jit to compile JAX functions for performance.\n","\n","### Instructions:\n","\n","1. Define a Python function compute_heavy_stuff(x, w, b) that performs a sequence of jnp operations:\n"," - y = jnp.dot(x, w)\n"," - y = y + b\n"," - y = jnp.tanh(y)\n"," - result = jnp.sum(y)\n"," - Return result.\n","2. Create a JIT-compiled version of this function, fast_compute_heavy_stuff, using jax.jit.\n","3. Create some large dummy JAX arrays for x, w, and b.\n","4. Call both the original and JIT-compiled functions with the dummy data.\n","5. (Optional) Use the `%timeit` magic command in Colab (in separate cells) to compare their execution speeds. Remember that the first call to a JIT-compiled function includes compilation time."],"metadata":{"id":"MK4rErEp6WPx"}},{"cell_type":"code","source":["# Instructions for Exercise 2\n","key_ex2_main, main_key = jax.random.split(main_key)\n","key_ex2_x, key_ex2_w, key_ex2_b = jax.random.split(key_ex2_main, 3)\n","\n","# 1. Define the Python function\n","def compute_heavy_stuff(x, w, b):\n"," # TODO: Implement the operations\n"," y1 = None # Placeholder\n"," y2 = None # Placeholder\n"," y3 = None # Placeholder\n"," result = None # Placeholder\n"," return result\n","\n","# 2. Create a JIT-compiled version\n","# TODO: Use jax.jit to compile compute_heavy_stuff\n","fast_compute_heavy_stuff = None # Placeholder\n","\n","# 3. Create dummy data\n","dim1, dim2, dim3 = 500, 1000, 500\n","x_data = jax.random.normal(key_ex2_x, (dim1, dim2))\n","w_data = jax.random.normal(key_ex2_w, (dim2, dim3))\n","b_data = jax.random.normal(key_ex2_b, (dim3,))\n","\n","# 4. Call both functions\n","result_original = None # Placeholder compute_heavy_stuff(x_data, w_data, b_data)\n","result_fast_first_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # First call (compiles)\n","result_fast_second_call = None # Placeholder fast_compute_heavy_stuff(x_data, w_data, b_data) # Second call (uses compiled)\n","\n","print(f\"Result (original): {result_original}\")\n","print(f\"Result (fast, 1st call): {result_fast_first_call}\")\n","print(f\"Result (fast, 2nd call): {result_fast_second_call}\")\n","\n","# if result_original is not None and result_fast_first_call is not None:\n","# assert jnp.allclose(result_original, result_fast_first_call), \"Results should match!\"\n","# print(\"\\nResults from original and JIT-compiled functions match.\")\n","\n","# 5. Optional: Timing (use %timeit in separate cells for accuracy)\n","# print(\"\\nTo see the speed difference, run these in separate cells:\")\n","# print(\"%timeit compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()\")\n","# print(\"%timeit fast_compute_heavy_stuff(x_data, w_data, b_data).block_until_ready()\")"],"metadata":{"id":"SNwAyNyO6SM3"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 2: `jax.jit` (Just-In-Time Compilation)\n","key_ex2_sol_main, main_key = jax.random.split(main_key)\n","key_ex2_sol_x, key_ex2_sol_w, key_ex2_sol_b = jax.random.split(key_ex2_sol_main, 3)\n","\n","# 1. Define the Python function\n","def compute_heavy_stuff_sol(x, w, b):\n"," y = jnp.dot(x, w)\n"," y = y + b\n"," y = jnp.tanh(y)\n"," result = jnp.sum(y)\n"," return result\n","\n","# 2. Create a JIT-compiled version\n","fast_compute_heavy_stuff_sol = jax.jit(compute_heavy_stuff_sol)\n","\n","# 3. Create dummy data\n","dim1_sol, dim2_sol, dim3_sol = 500, 1000, 500\n","x_data_sol = jax.random.normal(key_ex2_sol_x, (dim1_sol, dim2_sol))\n","w_data_sol = jax.random.normal(key_ex2_sol_w, (dim2_sol, dim3_sol))\n","b_data_sol = jax.random.normal(key_ex2_sol_b, (dim3_sol,))\n","\n","# 4. Call both functions\n","# Call original once to ensure it's not timed with any JAX overhead if it were the first JAX op\n","result_original_sol = compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","\n","# First call to JITed function includes compilation time\n","result_fast_sol_first_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","\n","# Subsequent calls use the cached compiled code\n","result_fast_sol_second_call = fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","\n","print(f\"Result (original): {result_original_sol}\")\n","print(f\"Result (fast, 1st call): {result_fast_sol_first_call}\")\n","print(f\"Result (fast, 2nd call): {result_fast_sol_second_call}\")\n","\n","assert jnp.allclose(result_original_sol, result_fast_sol_first_call), \"Results should match!\"\n","print(\"\\nResults from original and JIT-compiled functions match.\")\n","\n","# 5. Optional: Timing\n","# To accurately measure, run these in separate Colab cells:\n","# Cell 1:\n","# %timeit compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","# Cell 2:\n","# %timeit fast_compute_heavy_stuff_sol(x_data_sol, w_data_sol, b_data_sol).block_until_ready()\n","# You should observe that the JIT-compiled version is significantly faster after the initial compilation.\n","print(\"\\nTo see the speed difference, run the %timeit commands (provided in comments above) in separate cells.\")"],"metadata":{"id":"xOLQxFay61ls"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 3: jax.grad (Automatic Differentiation)\n","\n","**Goal**: Learn to use jax.grad to compute gradients of functions.\n","\n","### Instructions:\n","\n","1. Define a Python function scalar_loss(params, x, y_true) that:\n"," - Takes a dictionary params with keys 'w' and 'b'.\n"," - Computes y_pred = params['w'] * x + params['b'].\n"," - Returns a scalar loss, e.g., jnp.mean((y_pred - y_true)**2).\n","2. Use jax.grad to create a new function, compute_gradients, that computes the gradient of scalar_loss with respect to its first argument (params).\n","3. Initialize some dummy params, x_input, and y_target values.\n","4. Call compute_gradients to get the gradients. Print the gradients."],"metadata":{"id":"MNZqLNB57CpS"}},{"cell_type":"code","source":["# Instructions for Exercise 3\n","\n","# 1. Define the scalar_loss function\n","def scalar_loss(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:\n"," # TODO: Implement the prediction and loss calculation\n"," y_pred = None # Placeholder\n"," loss = None # Placeholder\n"," return loss\n","\n","# 2. Create the gradient function using jax.grad\n","# TODO: Gradient of scalar_loss w.r.t. 'params' (argnums=0)\n","compute_gradients = None # Placeholder\n","\n","# 3. Initialize dummy data\n","params_init = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}\n","x_input_data = jnp.array([1.0, 2.0, 3.0])\n","y_target_data = jnp.array([7.0, 9.0, 11.0]) # Targets for y = 3x + 4 (to make non-zero loss with init_params)\n","\n","# 4. Call the gradient function\n","gradients = None # Placeholder compute_gradients(params_init, x_input_data, y_target_data)\n","print(\"Initial params:\", params_init)\n","print(\"Gradients w.r.t params:\\n\", gradients)\n","\n","# Expected gradients (manual calculation for y_pred = wx+b, loss = mean((y_pred - y_true)^2)):\n","# dL/dw = mean(2 * (wx+b - y_true) * x)\n","# dL/db = mean(2 * (wx+b - y_true) * 1)\n","# For params_init={'w': 2.0, 'b': 1.0}, x=[1,2,3], y_true=[7,9,11]\n","# x=1: y_pred = 2*1+1 = 3. Error = 3-7 = -4. dL/dw_i_term = 2*(-4)*1 = -8. dL/db_i_term = 2*(-4)*1 = -8\n","# x=2: y_pred = 2*2+1 = 5. Error = 5-9 = -4. dL/dw_i_term = 2*(-4)*2 = -16. dL/db_i_term = 2*(-4)*1 = -8\n","# x=3: y_pred = 2*3+1 = 7. Error = 7-11 = -4. dL/dw_i_term = 2*(-4)*3 = -24. dL/db_i_term = 2*(-4)*1 = -8\n","# Mean gradients: dL/dw = (-8-16-24)/3 = -48/3 = -16. dL/db = (-8-8-8)/3 = -24/3 = -8.\n","# if gradients is not None:\n","# assert jnp.isclose(gradients['w'], -16.0)\n","# assert jnp.isclose(gradients['b'], -8.0)\n","# print(\"\\nGradients match expected values.\")"],"metadata":{"id":"g8S-6snP69KI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 3: `jax.grad` (Automatic Differentiation)\n","\n","# 1. Define the scalar_loss function\n","def scalar_loss_sol(params: Dict[str, jnp.ndarray], x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:\n"," y_pred = params['w'] * x + params['b']\n"," loss = jnp.mean((y_pred - y_true)**2)\n"," return loss\n","\n","# 2. Create the gradient function using jax.grad\n","# Gradient of scalar_loss w.r.t. 'params' (which is the 0-th argument)\n","compute_gradients_sol = jax.grad(scalar_loss_sol, argnums=0)\n","\n","# 3. Initialize dummy data\n","params_init_sol = {'w': jnp.array(2.0), 'b': jnp.array(1.0)}\n","x_input_data_sol = jnp.array([1.0, 2.0, 3.0])\n","y_target_data_sol = jnp.array([7.0, 9.0, 11.0])\n","\n","# 4. Call the gradient function\n","gradients_sol = compute_gradients_sol(params_init_sol, x_input_data_sol, y_target_data_sol)\n","print(\"Initial params:\", params_init_sol)\n","print(\"Gradients w.r.t params:\\n\", pprint.pformat(gradients_sol))\n","\n","# Verify with expected values (calculated in instructions)\n","expected_dL_dw = -16.0\n","expected_dL_db = -8.0\n","assert jnp.isclose(gradients_sol['w'], expected_dL_dw), f\"Grad w.r.t 'w' is {gradients_sol['w']}, expected {expected_dL_dw}\"\n","assert jnp.isclose(gradients_sol['b'], expected_dL_db), f\"Grad w.r.t 'b' is {gradients_sol['b']}, expected {expected_dL_db}\"\n","print(\"\\nGradients match expected values.\")"],"metadata":{"id":"jcjiql4O7ZQy"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 4: jax.vmap (Automatic Vectorization)\n","\n","**Goal**: Use jax.vmap to automatically batch operations.\n","\n","### Instructions:\n","\n","1. Define a function apply_affine(vector, matrix, bias) that takes a single 1D vector, a 2D matrix, and a 1D bias. It should compute jnp.dot(matrix, vector) + bias.\n","2. You have a batch of vectors (a 2D array where each row is a vector), but a single matrix and a single bias that should be applied to each vector in the batch.\n","3. Use jax.vmap to create batched_apply_affine that efficiently applies apply_affine to each vector in the batch.\n"," - Hint: in_axes for jax.vmap should specify 0 for the batched vector argument, and None for matrix and bias as they are not batched (broadcasted). The out_axes should be 0 to indicate the output is batched along the first axis.\n","4. Test batched_apply_affine with sample data."],"metadata":{"id":"XWoB6bD-7g2M"}},{"cell_type":"code","source":["# Instructions for Exercise 4\n","key_ex4_main, main_key = jax.random.split(main_key)\n","key_ex4_vec, key_ex4_mat, key_ex4_bias = jax.random.split(key_ex4_main, 3)\n","\n","# 1. Define apply_affine for a single vector\n","def apply_affine(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n"," # TODO: Compute jnp.dot(matrix, vector) + bias\n"," result = None # Placeholder\n"," return result\n","\n","# 2. Prepare data\n","batch_size = 4\n","input_features = 3\n","output_features = 2\n","\n","# batch_of_vectors: (batch_size, input_features)\n","# single_matrix: (output_features, input_features)\n","# single_bias: (output_features,)\n","batch_of_vectors = jax.random.normal(key_ex4_vec, (batch_size, input_features))\n","single_matrix = jax.random.normal(key_ex4_mat, (output_features, input_features))\n","single_bias = jax.random.normal(key_ex4_bias, (output_features,))\n","\n","\n","# 3. Use jax.vmap to create batched_apply_affine\n","# TODO: Specify in_axes correctly: vector is batched, matrix and bias are not. out_axes should be 0.\n","batched_apply_affine = None # Placeholder jax.vmap(apply_affine, in_axes=(..., ... , ...), out_axes=...)\n","\n","\n","# 4. Test batched_apply_affine\n","result_vmap = None # Placeholder batched_apply_affine(batch_of_vectors, single_matrix, single_bias)\n","print(\"Batch of vectors shape:\", batch_of_vectors.shape)\n","print(\"Single matrix shape:\", single_matrix.shape)\n","print(\"Single bias shape:\", single_bias.shape)\n","if result_vmap is not None:\n"," print(\"Result using vmap shape:\", result_vmap.shape) # Expected: (batch_size, output_features)\n","\n"," # For comparison, a manual loop (less efficient):\n"," # manual_results = []\n"," # for i in range(batch_size):\n"," # manual_results.append(apply_affine(batch_of_vectors[i], single_matrix, single_bias))\n"," # result_manual_loop = jnp.stack(manual_results)\n"," # assert jnp.allclose(result_vmap, result_manual_loop)\n"," # print(\"vmap result matches manual loop result.\")\n","else:\n"," print(\"result_vmap is None.\")"],"metadata":{"id":"vA9mu1si7dii"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 4: `jax.vmap` (Automatic Vectorization)\n","key_ex4_sol_main, main_key = jax.random.split(main_key)\n","key_ex4_sol_vec, key_ex4_sol_mat, key_ex4_sol_bias = jax.random.split(key_ex4_sol_main, 3)\n","\n","# 1. Define apply_affine for a single vector\n","def apply_affine_sol(vector: jnp.ndarray, matrix: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:\n"," return jnp.dot(matrix, vector) + bias\n","\n","# 2. Prepare data\n","batch_size_sol = 4\n","input_features_sol = 3\n","output_features_sol = 2\n","\n","batch_of_vectors_sol = jax.random.normal(key_ex4_sol_vec, (batch_size_sol, input_features_sol))\n","single_matrix_sol = jax.random.normal(key_ex4_sol_mat, (output_features_sol, input_features_sol))\n","single_bias_sol = jax.random.normal(key_ex4_sol_bias, (output_features_sol,))\n","\n","# 3. Use jax.vmap to create batched_apply_affine\n","# Vector is batched along axis 0, matrix and bias are not batched (broadcasted).\n","# out_axes=0 means the output will also be batched along its first axis.\n","batched_apply_affine_sol = jax.vmap(apply_affine_sol, in_axes=(0, None, None), out_axes=0)\n","\n","# 4. Test batched_apply_affine\n","result_vmap_sol = batched_apply_affine_sol(batch_of_vectors_sol, single_matrix_sol, single_bias_sol)\n","print(\"Batch of vectors shape:\", batch_of_vectors_sol.shape)\n","print(\"Single matrix shape:\", single_matrix_sol.shape)\n","print(\"Single bias shape:\", single_bias_sol.shape)\n","print(\"Result using vmap shape:\", result_vmap_sol.shape) # Expected: (batch_size, output_features)\n","assert result_vmap_sol.shape == (batch_size_sol, output_features_sol)\n","\n","# For comparison, a manual loop (less efficient):\n","manual_results_sol = []\n","for i in range(batch_size_sol):\n"," manual_results_sol.append(apply_affine_sol(batch_of_vectors_sol[i], single_matrix_sol, single_bias_sol))\n","result_manual_loop_sol = jnp.stack(manual_results_sol)\n","\n","assert jnp.allclose(result_vmap_sol, result_manual_loop_sol)\n","print(\"\\nvmap result matches manual loop result, demonstrating correct vectorization.\")"],"metadata":{"id":"q1QkKEtF76yo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 5: Flax NNX - Defining a Model\n","\n","**Goal**: Learn to define a simple neural network model using Flax NNX.\n","\n","### Instructions:\n","\n","1. Define a Flax NNX model class SimpleNNXModel that inherits from nnx.Module.\n","2. In its __init__, define one nnx.Linear layer. The layer should take din (input features) and dout (output features) as arguments. Remember to pass the rngs argument to nnx.Linear for parameter initialization (e.g., rngs=rngs).\n","3. Implement the __call__ method (the forward pass) which takes an input x and passes it through the linear layer.\n","4. Instantiate your SimpleNNXModel. You'll need to create an nnx.Rngs object using a JAX PRNG key (e.g., nnx.Rngs(params=jax.random.key(seed))). The key name params is conventional for nnx.Linear.\n","5. Test your model instance with a dummy input batch. Print the output and the model's state (parameters) using nnx.display()."],"metadata":{"id":"3LAlhdzq8D_S"}},{"cell_type":"code","source":["# Instructions for Exercise 5\n","key_ex5_model_init, main_key = jax.random.split(main_key)\n","\n","# 1. & 2. & 3. Define the SimpleNNXModel\n","class SimpleNNXModel(nnx.Module):\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," # TODO: Define an nnx.Linear layer named 'dense_layer'\n"," # self.dense_layer = nnx.Linear(...)\n"," self.some_attribute = None # Placeholder, remove later\n"," pass # Remove this placeholder if class is not empty\n","\n"," def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n"," # TODO: Pass input x through the dense_layer\n"," # return self.dense_layer(x)\n"," return x # Placeholder\n","\n","# 4. Instantiate the model\n","model_din = 3\n","model_dout = 2\n","# TODO: Create nnx.Rngs for parameter initialization. Use 'params' as the key name.\n","model_rngs = None # Placeholder nnx.Rngs(params=key_ex5_model_init)\n","my_model = None # Placeholder SimpleNNXModel(din=model_din, dout=model_dout, rngs=model_rngs)\n","\n","# 5. Test with dummy data\n","dummy_batch_size = 4\n","dummy_input_ex5 = jnp.ones((dummy_batch_size, model_din))\n","\n","model_output = None # Placeholder\n","if my_model is not None:\n"," model_output = my_model(dummy_input_ex5)\n"," print(f\"Model output shape: {model_output.shape}\")\n"," print(f\"Model output:\\n{model_output}\")\n","\n"," model_state = my_model.get_state()\n"," print(f\"\\nModel state (parameters, etc.):\")\n"," pprint.pprint(model_state)\n","else:\n"," print(\"my_model is None.\")"],"metadata":{"id":"BzUjMHll7--R"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 5: Flax NNX - Defining a Model\n","key_ex5_sol_model_init, main_key = jax.random.split(main_key)\n","\n","# 1. & 2. & 3. Define the SimpleNNXModel\n","class SimpleNNXModel_Sol(nnx.Module): # Renamed for solution cell\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," # nnx.Linear will use the 'params' key from rngs by default for its parameters\n"," self.dense_layer = nnx.Linear(din, dout, rngs=rngs)\n","\n"," def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n"," return self.dense_layer(x)\n","\n","# 4. Instantiate the model\n","model_din_sol = 3\n","model_dout_sol = 2\n","# Create nnx.Rngs for parameter initialization.\n","# 'params' is the default key nnx.Linear looks for in the rngs object.\n","model_rngs_sol = nnx.Rngs(params=key_ex5_sol_model_init)\n","my_model_sol = SimpleNNXModel_Sol(din=model_din_sol, dout=model_dout_sol, rngs=model_rngs_sol)\n","\n","# 5. Test with dummy data\n","dummy_batch_size_sol = 4\n","dummy_input_ex5_sol = jnp.ones((dummy_batch_size_sol, model_din_sol))\n","\n","model_output_sol = my_model_sol(dummy_input_ex5_sol)\n","print(f\"Model output shape: {model_output_sol.shape}\")\n","print(f\"Model output:\\n{model_output_sol}\")\n","\n","# model_state_sol = my_model_sol.get_state()\n","_, model_state_sol = nnx.split(my_model_sol)\n","print(f\"\\nModel state (parameters, etc.):\")\n","nnx.display(model_state_sol)\n","\n","# Check that parameters are present\n","assert 'dense_layer' in model_state_sol, \"Key 'dense_layer' not in model_state\"\n","assert 'kernel' in model_state_sol['dense_layer'], \"Key 'kernel' not in model_state['dense_layer']\"\n","assert 'bias' in model_state_sol['dense_layer'], \"Key 'bias' not in model_state['dense_layer']\"\n","print(\"\\nModel parameters (kernel and bias for dense_layer) are present in the state.\")"],"metadata":{"id":"QbBqSZse8V9y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 6: Optax & Flax NNX - Creating an Optimizer\n","\n","**Goal**: Set up an Optax optimizer and wrap it with nnx.Optimizer for use with a Flax NNX model.\n","\n","### Instructions:\n","1. Use the SimpleNNXModel_Sol class and an instance my_model_sol from the previous exercise's solution. (If running standalone, re-instantiate it).\n","2. Create an Optax optimizer, for example, optax.adam with a learning rate of 0.001.\n","3. Create an nnx.Optimizer instance. This wrapper links the Optax optimizer with your Flax NNX model (my_model_sol).\n","4. Print the nnx.Optimizer instance and its state attribute to see the initialized optimizer state (e.g., Adam's momentum terms)."],"metadata":{"id":"i4kuv2IH-FbA"}},{"cell_type":"code","source":["# Instructions for Exercise 6\n","\n","# 1. Assume my_model_sol is available from Exercise 5 solution\n","# (If running standalone, re-instantiate it)\n","if 'my_model_sol' not in globals():\n"," print(\"Re-initializing model from Ex5 solution for Ex6.\")\n"," key_ex6_model_init, main_key = jax.random.split(main_key)\n"," _model_din_ex6 = 3\n"," _model_dout_ex6 = 2\n"," _model_rngs_ex6 = nnx.Rngs(params=key_ex6_model_init)\n"," # Use solution class name if defined, otherwise student's class name\n"," _ModelClass = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n"," model_for_opt = _ModelClass(din=_model_din_ex6, dout=_model_dout_ex6, rngs=_model_rngs_ex6)\n"," print(\"Model for optimizer created.\")\n","else:\n"," model_for_opt = my_model_sol # Use the one from previous solution\n"," print(\"Using model 'my_model_sol' from previous exercise for 'model_for_opt'.\")\n","\n","\n","# 2. Create an Optax optimizer\n","learning_rate = 0.001\n","# TODO: Create an optax.adam optimizer transform\n","optax_tx = None # Placeholder optax.adam(...)\n","\n","# 3. Create an nnx.Optimizer wrapper\n","# TODO: Wrap the model (model_for_opt) and the optax transform (optax_tx)\n","# The `wrt` argument is now required to specify what to differentiate with respect to.\n","nnx_optimizer = None # Placeholder nnx.Optimizer(...)\n","\n","# 4. Print the optimizer and its state\n","print(\"\\nFlax NNX Optimizer wrapper:\")\n","nnx.display(nnx_optimizer)\n","\n","print(\"\\nInitial Optimizer State (Optax state, e.g., Adam's momentum):\")\n","if nnx_optimizer is not None and hasattr(nnx_optimizer, 'opt_state'):\n"," pprint.pprint(nnx_optimizer.state)\n"," # if hasattr(nnx_optimizer, 'opt_state'):\n"," # adam_state = nnx_optimizer.opt_state\n"," # assert len(adam_state) > 0 and hasattr(adam_state[0], 'count')\n"," # print(\"\\nOptimizer state structure looks plausible for Adam.\")\n","else:\n"," print(\"nnx_optimizer or its state is None or not structured as expected.\")"],"metadata":{"id":"ytaIj3xK8ZMI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 6: Optax & Flax NNX - Creating an Optimizer\n","\n","# 1. Use my_model_sol from Exercise 5 solution\n","# If not run sequentially, ensure my_model_sol is defined:\n","if 'my_model_sol' not in globals():\n"," print(\"Re-initializing model from Ex5 solution for Ex6.\")\n"," key_ex6_sol_model_init, main_key = jax.random.split(main_key)\n"," _model_din_sol_ex6 = 3\n"," _model_dout_sol_ex6 = 2\n"," _model_rngs_sol_ex6 = nnx.Rngs(params=key_ex6_sol_model_init)\n"," # Ensure SimpleNNXModel_Sol is used\n"," my_model_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex6, dout=_model_dout_sol_ex6, rngs=_model_rngs_sol_ex6)\n"," print(\"Model for optimizer re-created as 'my_model_sol'.\")\n","else:\n"," print(\"Using model 'my_model_sol' from previous exercise.\")\n","\n","\n","# 2. Create an Optax optimizer\n","learning_rate_sol = 0.001\n","# Create an optax.adam optimizer transform\n","optax_tx_sol = optax.adam(learning_rate=learning_rate_sol)\n","\n","# 3. Create an nnx.Optimizer wrapper\n","# This links the model and the Optax optimizer.\n","# The optimizer state will be initialized based on the model's parameters.\n","nnx_optimizer_sol = nnx.Optimizer(my_model_sol, optax_tx_sol, wrt=nnx.Param)\n","\n","# 4. Print the optimizer and its state\n","print(\"\\nFlax NNX Optimizer wrapper:\")\n","nnx.display(nnx_optimizer_sol) # Shows the model it's associated with and the Optax transform\n","\n","print(\"\\nInitial Optimizer State (Optax state, e.g., Adam's momentum):\")\n","# nnx.Optimizer stores the actual Optax state in its .opt_state attribute.\n","# This state is a PyTree that matches the structure of the model's parameters.\n","pprint.pprint(nnx_optimizer_sol.opt_state)\n","\n","# Verify the structure of the optimizer state for Adam (count, mu, nu for each param)\n","assert hasattr(nnx_optimizer_sol, 'opt_state'), \"Optax opt_state not found in nnx.Optimizer\"\n","# The opt_state is a tuple, typically (CountState(), ScaleByAdamState()) for adam\n","adam_optax_internal_state = nnx_optimizer_sol.opt_state\n","assert len(adam_optax_internal_state) > 0 and hasattr(adam_optax_internal_state[0], 'count'), \"Adam 'count' state not found.\"\n","# The second element of the tuple is often where parameter-specific states like mu and nu reside\n","if len(adam_optax_internal_state) > 1 and hasattr(adam_optax_internal_state[1], 'mu'):\n"," param_specific_state = adam_optax_internal_state[1]\n"," assert 'dense_layer' in param_specific_state.mu and 'kernel' in param_specific_state.mu['dense_layer'], \"Adam 'mu' state for kernel not found.\"\n"," print(\"\\nOptimizer state structure looks correct for Adam.\")\n","else:\n"," print(\"\\nWarning: Optimizer state structure for Adam might be different or not fully verified.\")"],"metadata":{"id":"f1ccATgB-Zed"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 7: Training Step with Flax NNX and Optax\n","\n","**Goal**: Implement a complete JIT-compiled training step for a Flax NNX model using Optax.\n","\n","### Instructions:\n","\n","1. You'll need:\n"," - An instance of your model class (e.g., my_model_sol from Ex 5/6 solution).\n"," - An instance of nnx.Optimizer (e.g., nnx_optimizer_sol from Ex 6 solution).\n","2. Define a train_step function that is decorated with @nnx.jit. This function should take the model, optimizer, input x_batch, and target y_batch as arguments.\n","3. Inside train_step:\n"," - Define an inner loss_fn_for_grad. This function must take the model as its first argument. Inside, it computes the model's predictions for x_batch and then calculates the mean squared error (MSE) against y_batch.\n"," - Use nnx.value_and_grad(loss_fn_for_grad)(model_arg) to compute both the loss value and the gradients with respect to the model passed to loss_fn_for_grad. (model_arg is the model instance passed into train_step).\n"," - Update the model's parameters (and the optimizer's state) using optimizer_arg.update(model_arg, grads). The update method takes the model and gradients, and updates the model's state in-place.\n"," - Return the computed loss_value.\n","4. Create dummy x_batch and y_batch data.\n","5. Call your train_step function. Print the returned loss.\n","6. (Optional) Verify that the model's parameters have changed after the train_step by comparing a parameter value before and after the call."],"metadata":{"id":"i7jXowc9ACNB"}},{"cell_type":"code","source":["# Instructions for Exercise 7\n","key_ex7_main, main_key = jax.random.split(main_key)\n","key_ex7_x, key_ex7_y = jax.random.split(key_ex7_main, 2)\n","\n","# 1. Use model and optimizer from previous exercises' solutions\n","# Ensure my_model_sol and nnx_optimizer_sol are available\n","if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7.\")\n"," key_ex7_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_ex7 = 3\n"," _model_dout_ex7 = 2\n"," _model_rngs_ex7 = nnx.Rngs(params=key_ex7_model_fallback)\n"," # Ensure SimpleNNXModel_Sol is used\n"," my_model_ex7 = SimpleNNXModel_Sol(din=_model_din_ex7, dout=_model_dout_ex7, rngs=_model_rngs_ex7)\n"," _optax_tx_ex7 = optax.adam(learning_rate=0.001)\n"," nnx_optimizer_ex7 = nnx.Optimizer(my_model_ex7, _optax_tx_ex7)\n"," print(\"Model and optimizer re-created for Ex7.\")\n","else:\n"," my_model_ex7 = my_model_sol\n"," nnx_optimizer_ex7 = nnx_optimizer_sol\n"," print(\"Using 'my_model_sol' and 'nnx_optimizer_sol' for 'my_model_ex7' and 'nnx_optimizer_ex7'.\")\n","\n","\n","# 2. & 3. Define the train_step function\n","# TODO: Decorate with @nnx.jit\n","# def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Type hint with base nnx.Module\n","# x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n","\n"," # TODO: Define inner loss_fn_for_grad(current_model_state_for_grad_fn)\n"," # def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Type hint with base nnx.Module\n"," # y_pred = model_in_grad_fn(x_batch)\n"," # loss = jnp.mean((y_pred - y_batch)**2)\n"," # return loss\n"," # return jnp.array(0.0) # Placeholder\n","\n"," # TODO: Compute loss value and gradients using nnx.value_and_grad\n"," # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg) # Pass model_arg\n","\n"," # TODO: Update the optimizer (which updates the model_arg in-place)\n"," # optimizer_arg.update(model_arg, grads)\n","\n"," # return loss_value\n","# return jnp.array(0.0) # Placeholder defined train_step function\n","\n","# For the student to define:\n","# Make sure the function signature is correct for nnx.jit\n","@nnx.jit\n","def train_step(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer,\n"," x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n"," # Placeholder implementation for student\n"," def loss_fn_for_grad(model_in_grad_fn: nnx.Module):\n"," # y_pred = model_in_grad_fn(x_batch)\n"," # loss = jnp.mean((y_pred - y_batch)**2)\n"," # return loss\n"," return jnp.array(0.0) # Student TODO: replace this\n","\n"," # loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)\n"," # optimizer_arg.update(grads)\n"," # return loss_value\n"," return jnp.array(-1.0) # Student TODO: replace this\n","\n","\n","# 4. Create dummy data\n","batch_s = 8\n","# Access features_in and features_out carefully\n","_din_from_model_ex7 = my_model_ex7.dense_layer.in_features if hasattr(my_model_ex7, 'dense_layer') else 3\n","_dout_from_model_ex7 = my_model_ex7.dense_layer.out_features if hasattr(my_model_ex7, 'dense_layer') else 2\n","\n","x_batch_data = jax.random.normal(key_ex7_x, (batch_s, _din_from_model_ex7))\n","y_batch_data = jax.random.normal(key_ex7_y, (batch_s, _dout_from_model_ex7))\n","\n","# Optional: Store initial param value for comparison\n","initial_kernel_val = None\n","if hasattr(my_model_ex7, 'get_state'):\n"," _current_model_state_ex7 = my_model_ex7.get_state()\n"," if 'dense_layer' in _current_model_state_ex7:\n"," initial_kernel_val = _current_model_state_ex7['dense_layer']['kernel'].value[0,0].copy()\n","print(f\"Initial kernel value (sample): {initial_kernel_val}\")\n","\n","# 5. Call the train_step\n","# loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data) # Student will uncomment\n","loss_after_step = jnp.array(-1.0) # Placeholder until student implements train_step\n","if train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data).item() != -1.0: # Check if student implemented\n"," loss_after_step = train_step(my_model_ex7, nnx_optimizer_ex7, x_batch_data, y_batch_data)\n"," print(f\"Loss after one training step: {loss_after_step}\")\n","else:\n"," print(\"Student needs to implement `train_step` function.\")\n","\n","\n","# # 6. Optional: Verify parameter change\n","# updated_kernel_val_sol = None\n","# _, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update\n","# if 'dense_layer' in updated_model_state_sol:\n","# updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]\n","# print(f\"Updated kernel value (sample): {updated_kernel_val_sol}\")\n","\n","# if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:\n","# assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), \"Kernel parameter did not change!\"\n","# print(\"Kernel parameter changed as expected after the training step.\")\n","# else:\n","# print(\"Could not verify kernel change (initial or updated value was None).\")"],"metadata":{"id":"KEQCcmBI-ce2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 7: Training Step with Flax NNX and Optax\n","key_ex7_sol_main, main_key = jax.random.split(main_key)\n","key_ex7_sol_x, key_ex7_sol_y = jax.random.split(key_ex7_sol_main, 2)\n","\n","# 1. Use model and optimizer from previous exercises' solutions\n","if 'my_model_sol' not in globals() or 'nnx_optimizer_sol' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex5/Ex6 solutions for Ex7 solution.\")\n"," key_ex7_sol_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_sol_ex7 = 3\n"," _model_dout_sol_ex7 = 2\n"," _model_rngs_sol_ex7 = nnx.Rngs(params=key_ex7_sol_model_fallback)\n"," # Ensure SimpleNNXModel_Sol is used for the solution\n"," my_model_sol_ex7 = SimpleNNXModel_Sol(din=_model_din_sol_ex7, dout=_model_dout_sol_ex7, rngs=_model_rngs_sol_ex7)\n"," _optax_tx_sol_ex7 = optax.adam(learning_rate=0.001)\n"," nnx_optimizer_sol_ex7 = nnx.Optimizer(my_model_sol_ex7, _optax_tx_sol_ex7)\n"," print(\"Model and optimizer re-created for Ex7 solution.\")\n","else:\n"," # If solutions are run sequentially, these will be the correct instances\n"," my_model_sol_ex7 = my_model_sol\n"," nnx_optimizer_sol_ex7 = nnx_optimizer_sol\n"," print(\"Using 'my_model_sol' and 'nnx_optimizer_sol' for Ex7 solution.\")\n","\n","\n","# 2. & 3. Define the train_step function\n","@nnx.jit # Decorate with @nnx.jit for JIT compilation\n","def train_step_sol(model_arg: nnx.Module, optimizer_arg: nnx.Optimizer, # Use base nnx.Module for generality\n"," x_batch: jnp.ndarray, y_batch: jnp.ndarray) -> jnp.ndarray:\n","\n"," # Define inner loss_fn_for_grad. It takes the model as its first argument.\n"," # It captures x_batch and y_batch from the outer scope.\n"," def loss_fn_for_grad(model_in_grad_fn: nnx.Module): # Use base nnx.Module\n"," y_pred = model_in_grad_fn(x_batch) # Use the model passed to this inner function\n"," loss = jnp.mean((y_pred - y_batch)**2)\n"," return loss\n","\n"," # Compute loss value and gradients using nnx.value_and_grad.\n"," # This will differentiate loss_fn_for_grad with respect to its first argument (model_in_grad_fn).\n"," # We pass the current state of our model (model_arg) to it.\n"," loss_value, grads = nnx.value_and_grad(loss_fn_for_grad)(model_arg)\n","\n"," # Update the optimizer. This updates the model_arg (which nnx_optimizer_sol_ex7 references) in-place.\n"," optimizer_arg.update(model_arg, grads)\n","\n"," return loss_value\n","\n","\n","# 4. Create dummy data\n","batch_s_sol = 8\n","# Ensure din and dout match the model instantiation from Ex5/Ex6\n","# my_model_sol_ex7.dense_layer is an nnx.Linear object\n","din_from_model_sol = my_model_sol_ex7.dense_layer.in_features\n","dout_from_model_sol = my_model_sol_ex7.dense_layer.out_features\n","\n","x_batch_data_sol = jax.random.normal(key_ex7_sol_x, (batch_s_sol, din_from_model_sol))\n","y_batch_data_sol = jax.random.normal(key_ex7_sol_y, (batch_s_sol, dout_from_model_sol))\n","\n","# Optional: Store initial param value for comparison\n","initial_kernel_val_sol = None\n","_, current_model_state_sol = nnx.split(my_model_sol_ex7)\n","if 'dense_layer' in current_model_state_sol:\n"," initial_kernel_val_sol = current_model_state_sol['dense_layer']['kernel'].value[0,0].copy()\n","print(f\"Initial kernel value (sample): {initial_kernel_val_sol}\")\n","\n","\n","# 5. Call the train_step\n","# First call will JIT compile the train_step_sol function.\n","loss_after_step_sol = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)\n","print(f\"Loss after one training step (1st call, JIT): {loss_after_step_sol}\")\n","# Second call to show it's faster (though %timeit is better for measurement)\n","loss_after_step_sol_2 = train_step_sol(my_model_sol_ex7, nnx_optimizer_sol_ex7, x_batch_data_sol, y_batch_data_sol)\n","print(f\"Loss after one training step (2nd call, cached): {loss_after_step_sol_2}\")\n","\n","\n","# 6. Optional: Verify parameter change\n","updated_kernel_val_sol = None\n","_, updated_model_state_sol = nnx.split(my_model_sol_ex7) # Get state again after update\n","if 'dense_layer' in updated_model_state_sol:\n"," updated_kernel_val_sol = updated_model_state_sol['dense_layer']['kernel'].value[0,0]\n"," print(f\"Updated kernel value (sample): {updated_kernel_val_sol}\")\n","\n","if initial_kernel_val_sol is not None and updated_kernel_val_sol is not None:\n"," assert not jnp.allclose(initial_kernel_val_sol, updated_kernel_val_sol), \"Kernel parameter did not change!\"\n"," print(\"Kernel parameter changed as expected after the training step.\")\n","else:\n"," print(\"Could not verify kernel change (initial or updated value was None).\")"],"metadata":{"id":"7bVlg9_-Ae6Z"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 8: Orbax - Saving and Restoring Checkpoints\n","\n","**Goal**: Learn to use Orbax to save and restore JAX PyTrees, specifically Flax NNX model states and Optax optimizer states.\n","\n","### Instructions:\n","1. You'll need your model (e.g., my_model_sol_ex7) and optimizer (e.g., nnx_optimizer_sol_ex7) from the previous exercise's solution.\n","2. Define a checkpoint directory (e.g., /tmp/my_nnx_checkpoint/).\n","3. Create an Orbax CheckpointManagerOptions and then a CheckpointManager.\n","4. Bundle the states you want to save into a dictionary. For NNX, this is my_model_sol_ex7.get_state() for the model, and nnx_optimizer_sol_ex7.state for the optimizer's internal state. Also include a training step counter.\n","5. Use checkpoint_manager.save() with ocp.args.StandardSave() to save the bundled state. Call checkpoint_manager.wait_until_finished() to ensure saving completes.\n","6. To restore:\n"," - Create new instances of your model (restored_model) and Optax transform (restored_optax_tx). The new model should have a different PRNG key for its initial parameters to demonstrate that restoration works.\n"," - Use checkpoint_manager.restore() with ocp.args.StandardRestore() to load the bundled state.\n"," - Apply the loaded model state to restored_model using restored_model.update_state(loaded_bundle['model']).\n"," - Create a new nnx.Optimizer (restored_optimizer) associating restored_model and restored_optax_tx.\n"," - Assign the loaded optimizer state to the new optimizer: restored_optimizer.state = loaded_bundle['optimizer'].\n","7. Verify that a parameter from restored_model matches the corresponding parameter from the original my_model_sol_ex7 (before saving, or from the saved state). Also, compare optimizer states if possible.\n","8. Clean up the checkpoint directory."],"metadata":{"id":"_t8KGFhqDoSu"}},{"cell_type":"code","source":["# Instructions for Exercise 8\n","# import orbax.checkpoint as ocp # Already imported\n","# import os, shutil # Already imported\n","\n","# 1. Use model and optimizer from previous exercise solution\n","if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex7 solution for Ex8.\")\n"," key_ex8_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_ex8 = 3\n"," _model_dout_ex8 = 2\n"," _model_rngs_ex8 = nnx.Rngs(params=key_ex8_model_fallback)\n"," _ModelClassEx8 = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n"," model_to_save = _ModelClassEx8(din=_model_din_ex8, dout=_model_dout_ex8, rngs=_model_rngs_ex8)\n"," _optax_tx_ex8 = optax.adam(learning_rate=0.001)\n"," optimizer_to_save = nnx.Optimizer(model_to_save, _optax_tx_ex8)\n"," print(\"Model and optimizer re-created for Ex8.\")\n","else:\n"," model_to_save = my_model_sol_ex7\n"," optimizer_to_save = nnx_optimizer_sol_ex7\n"," print(\"Using model and optimizer from Ex7 solution for Ex8.\")\n","\n","# 2. Define checkpoint directory\n","# TODO: Define checkpoint_dir\n","checkpoint_dir = None # Placeholder e.g., \"/tmp/my_nnx_checkpoint_exercise/\"\n","# if checkpoint_dir and os.path.exists(checkpoint_dir):\n","# shutil.rmtree(checkpoint_dir) # Clean up previous runs for safety\n","# if checkpoint_dir:\n","# os.makedirs(checkpoint_dir, exist_ok=True)\n","\n","\n","# 3. Create Orbax CheckpointManager\n","# TODO: Create options and manager\n","# options = ocp.CheckpointManagerOptions(...)\n","# mngr = ocp.CheckpointManager(...)\n","options = None\n","mngr = None\n","\n","# 4. Bundle states\n","# current_step = 100 # Example step\n","# TODO: Get model_state and optimizer_state\n","# model_state_to_save = nnx.split(model_to_save)\n","# The optimizer state is now accessed via the .state attribute.\n","# opt_state_to_save = optimizer_to_save.state\n","# save_bundle = {\n","# 'model': model_state_to_save,\n","# 'optimizer': opt_state_to_save,\n","# 'step': current_step\n","# }\n","save_bundle = None\n","\n","# 5. Save the checkpoint\n","# if mngr and save_bundle:\n","# TODO: Save checkpoint\n","# mngr.save(...)\n","# mngr.wait_until_finished()\n","# print(f\"Checkpoint saved at step {current_step} to {checkpoint_dir}\")\n","# else:\n","# print(\"Checkpoint manager or save_bundle not initialized.\")\n","\n","# --- Restoration ---\n","# 6.a Create new model and Optax transform (for restoration)\n","# key_ex8_restore_model, main_key = jax.random.split(main_key)\n","# din_restore = model_to_save.dense_layer.in_features if hasattr(model_to_save, 'dense_layer') else 3\n","# dout_restore = model_to_save.dense_layer.out_features if hasattr(model_to_save, 'dense_layer') else 2\n","# _ModelClassRestore = SimpleNNXModel_Sol if 'SimpleNNXModel_Sol' in globals() else SimpleNNXModel\n","# restored_model = _ModelClassRestore(\n","# din=din_restore, dout=dout_restore,\n","# rngs=nnx.Rngs(params=key_ex8_restore_model) # New key for different initial params\n","# )\n","# restored_optax_tx = optax.adam(learning_rate=0.001) # Same Optax config\n","restored_model = None\n","restored_optax_tx = None\n","\n","# 6.b Restore the checkpoint\n","# loaded_bundle = None\n","# if mngr:\n","# TODO: Restore checkpoint\n","# latest_step = mngr.latest_step()\n","# if latest_step is not None:\n","# loaded_bundle = mngr.restore(...)\n","# print(f\"Checkpoint restored from step {latest_step}\")\n","# else:\n","# print(\"No checkpoint found to restore.\")\n","# else:\n","# print(\"Checkpoint manager not initialized for restore.\")\n","\n","# 6.c Apply loaded states\n","# if loaded_bundle and restored_model:\n","# TODO: Update restored_model state\n","# nnx.update(restored_model, ...)\n","# print(\"Restored model state applied.\")\n","\n"," # TODO: Create new nnx.Optimizer and assign its state\n","# restored_optimizer = nnx.Optimizer(...)\n","# restored_optimizer.state = ...\n","# print(\"Restored optimizer state applied.\")\n","# else:\n","# print(\"Loaded_bundle or restored_model is None, cannot apply states.\")\n","restored_optimizer = None\n","\n","# 7. Verify restoration\n","# if loaded_bundle and restored_model and save_bundle:\n","# original_kernel = save_bundle['model']['dense_layer']['kernel']\n","# _, restored_model_state = nnx.split(restored_model_sol)\n","# kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']\n","# assert jnp.array_equal(original_kernel.value, kernel_after_restore.value)\n","# print(\"Model parameters successfully restored and verified.\")\n","\n","# Optax state can be complex, a basic check on 'count' if Adam\n","# if hasattr(save_bundle['optimizer'].opt_state[0], 'count') and \\\n","# hasattr(restored_optimizer.state.opt_state[0], 'count'):\n","# original_opt_count = save_bundle['optimizer'].opt_state[0].count\n","# restored_opt_count = restored_optimizer.state.opt_state[0].count\n","# assert jnp.array_equal(original_opt_count, restored_opt_count)\n","# print(\"Optimizer state (count) successfully restored and verified.\")\n","# else:\n","# print(\"Verification skipped as some components are None.\")\n","\n","\n","# 8. Clean up\n","# if mngr:\n","# mngr.close()\n","# if checkpoint_dir and os.path.exists(checkpoint_dir):\n","# shutil.rmtree(checkpoint_dir)\n","# print(f\"Cleaned up checkpoint directory: {checkpoint_dir}\")"],"metadata":{"id":"V7XdNy-vAjpG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 8: Orbax - Saving and Restoring Checkpoints\n","\n","# 1. Use model and optimizer from previous exercise solution\n","if 'my_model_sol_ex7' not in globals() or 'nnx_optimizer_sol_ex7' not in globals():\n"," print(\"Re-initializing model and optimizer from Ex7 solution for Ex8 solution.\")\n"," key_ex8_sol_model_fallback, main_key = jax.random.split(main_key)\n"," _model_din_sol_ex8 = 3\n"," _model_dout_sol_ex8 = 2\n"," _model_rngs_sol_ex8 = nnx.Rngs(params=key_ex8_sol_model_fallback)\n"," # Ensure SimpleNNXModel_Sol is used for the solution\n"," model_to_save_sol = SimpleNNXModel_Sol(din=_model_din_sol_ex8,\n"," dout=_model_dout_sol_ex8,\n"," rngs=_model_rngs_sol_ex8)\n"," _optax_tx_sol_ex8 = optax.adam(learning_rate=0.001) # Store the transform for later\n"," optimizer_to_save_sol = nnx.Optimizer(model_to_save_sol, _optax_tx_sol_ex8)\n"," print(\"Model and optimizer re-created for Ex8 solution.\")\n","else:\n"," model_to_save_sol = my_model_sol_ex7\n"," optimizer_to_save_sol = nnx_optimizer_sol_ex7\n"," # We need the optax transform used to create the optimizer for restoration\n"," _optax_tx_sol_ex8 = optimizer_to_save_sol.tx # Access the original Optax transform\n"," print(\"Using model and optimizer from Ex7 solution for Ex8 solution.\")\n","\n","# 2. Define checkpoint directory\n","checkpoint_dir_sol = \"/tmp/my_nnx_checkpoint_exercise_solution/\"\n","if os.path.exists(checkpoint_dir_sol):\n"," shutil.rmtree(checkpoint_dir_sol) # Clean up previous runs\n","os.makedirs(checkpoint_dir_sol, exist_ok=True)\n","print(f\"Orbax checkpoint directory: {checkpoint_dir_sol}\")\n","\n","# 3. Create Orbax CheckpointManager\n","options_sol = ocp.CheckpointManagerOptions(save_interval_steps=1, max_to_keep=1)\n","mngr_sol = ocp.CheckpointManager(checkpoint_dir_sol, options=options_sol)\n","\n","# 4. Bundle states\n","current_step_sol = 100 # Example step\n","_, model_state_to_save_sol = nnx.split(model_to_save_sol)\n","# The optimizer state is now a PyTree directly available in the .state attribute.\n","_opt_state_to_save_sol = optimizer_to_save_sol.state\n","save_bundle_sol = {\n"," 'model': model_state_to_save_sol,\n"," 'optimizer': opt_state_to_save_sol,\n"," 'step': current_step_sol\n","}\n","print(\"\\nState bundle to be saved:\")\n","pprint.pprint(f\"Model state keys: {model_state_to_save_sol.keys()}\")\n","pprint.pprint(f\"Optimizer state type: {type(opt_state_to_save_sol)}\")\n","\n","\n","# 5. Save the checkpoint\n","mngr_sol.save(current_step_sol, args=ocp.args.StandardSave(save_bundle_sol))\n","mngr_sol.wait_until_finished()\n","print(f\"\\nCheckpoint saved at step {current_step_sol} to {checkpoint_dir_sol}\")\n","\n","# --- Restoration ---\n","# 6.a Create new model and Optax transform (for restoration)\n","key_ex8_sol_restore_model, main_key = jax.random.split(main_key)\n","# Ensure din/dout are correctly obtained from the saved model's structure if possible\n","# Assuming model_to_save_sol is SimpleNNXModel_Sol which has a dense_layer\n","din_restore_sol = model_to_save_sol.dense_layer.in_features\n","dout_restore_sol = model_to_save_sol.dense_layer.out_features\n","\n","restored_model_sol = SimpleNNXModel_Sol( # Use the solution's model class\n"," din=din_restore_sol, dout=dout_restore_sol,\n"," rngs=nnx.Rngs(params=key_ex8_sol_restore_model) # New key for different initial params\n",")\n","# We need the original Optax transform definition for the new nnx.Optimizer\n","# _optax_tx_sol_ex8 was stored earlier, or can be re-created if config is known\n","restored_optax_tx_sol = _optax_tx_sol_ex8\n","\n","# Print a param from new model BEFORE restoration to show it's different\n","_, kernel_before_restore_sol = nnx.split(restored_model_sol)\n","print(f\"\\nSample kernel from 'restored_model_sol' BEFORE restoration:\")\n","nnx.display(kernel_before_restore_sol['dense_layer']['kernel'])\n","\n","# 6.b Restore the checkpoint\n","loaded_bundle_sol = None\n","latest_step_sol = mngr_sol.latest_step()\n","if latest_step_sol is not None:\n"," # For NNX, we are restoring raw PyTrees, StandardRestore is suitable.\n"," loaded_bundle_sol = mngr_sol.restore(latest_step_sol,\n"," args=ocp.args.StandardRestore(save_bundle_sol))\n"," print(f\"\\nCheckpoint restored from step {latest_step_sol}\")\n"," print(f\"Loaded bundle contains keys: {loaded_bundle_sol.keys()}\")\n","else:\n"," raise ValueError(\"No checkpoint found to restore.\")\n","\n","# 6.c Apply loaded states\n","assert loaded_bundle_sol is not None, \"Loaded bundle is None\"\n","nnx.update(restored_model_sol, loaded_bundle_sol['model'])\n","print(\"Restored model state applied to 'restored_model_sol'.\")\n","\n","# Create new nnx.Optimizer with the restored_model and original optax_tx\n","restored_optimizer_sol = nnx.Optimizer(restored_model_sol, restored_optax_tx_sol)\n","# Now assign the loaded Optax state PyTree\n","restored_optimizer_sol.state = loaded_bundle_sol['optimizer']\n","print(\"Restored optimizer state applied to 'restored_optimizer_sol'.\")\n","\n","\n","# 7. Verify restoration\n","original_kernel_sol = save_bundle_sol['model']['dense_layer']['kernel']\n","_, restored_model_state = nnx.split(restored_model_sol)\n","kernel_after_restore_sol = restored_model_state['dense_layer']['kernel']\n","assert jnp.array_equal(original_kernel_sol.value, kernel_after_restore_sol.value), \\\n"," \"Model kernel parameters differ after restoration!\"\n","print(\"\\nModel parameters successfully restored and verified (kernel match).\")\n","assert not jnp.array_equal(kernel_before_restore_sol['dense_layer']['kernel'].value,\n"," kernel_after_restore_sol.value), \\\n"," \"Kernel should be different from pre-restore state.\"\n","print(\"Kernel is different from its pre-restoration random initialization, good.\")\n","\n","\n","# Verify optimizer state (e.g., Adam's 'count' and 'mu' for a specific parameter)\n","original_opt_state_adam_count = save_bundle_sol['optimizer'].opt_state[0].count.value\n","restored_opt_state_adam_count = restored_optimizer_sol.state.opt_state[0].count.value\n","assert jnp.array_equal(original_opt_state_adam_count, restored_opt_state_adam_count), \\\n"," \"Optimizer Adam count differs!\"\n","\n","original_opt_state_adam_mu_kernel = save_bundle_sol['optimizer'].opt_state[0].mu['dense_layer']['kernel'].value\n","restored_opt_state_adam_mu_kernel = restored_optimizer_sol.state.opt_state[0].mu['dense_layer']['kernel'].value\n","assert jnp.array_equal(original_opt_state_adam_mu_kernel, restored_opt_state_adam_mu_kernel), \\\n"," \"Optimizer Adam mu for kernel differs!\"\n","print(\"Optimizer state (Adam count and sample mu) successfully restored and verified.\")\n","\n","\n","# 8. Clean up\n","mngr_sol.close()\n","if os.path.exists(checkpoint_dir_sol):\n"," shutil.rmtree(checkpoint_dir_sol)\n"," print(f\"Cleaned up checkpoint directory: {checkpoint_dir_sol}\")"],"metadata":{"id":"2-Fk8aukEGVL"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Conclusion\n","\n","### Congratulations on completing the JAX AI Stack exercises!\n","\n","You've now had a hands-on introduction to:\n","\n","- Core JAX: jax.numpy, functional programming, jax.jit, jax.grad, jax.vmap.\n","- Flax NNX: Defining and instantiating Pythonic neural network models.\n","- Optax: Creating and using composable optimizers with Flax NNX.\n","- Training Loop: Implementing an end-to-end training step in Flax NNX.\n","- Orbax: Saving and restoring model and optimizer states.\n","\n","This forms a strong foundation for developing high-performance machine learning models with the JAX ecosystem.\n","\n","For further learning, refer to the official documentation:\n","- JAX AI Stack: https://jaxstack.ai\n","- JAX: https://jax.dev\n","- Flax NNX: https://flax.readthedocs.io\n","- Optax: https://optax.readthedocs.io\n","- Orbax: https://orbax.readthedocs.io\n","\n","Don't forget to provide feedback on the training session:\n","https://goo.gle/jax-training-feedback"],"metadata":{"id":"9kotBqE7Qhiv"}},{"cell_type":"code","source":[],"metadata":{"id":"TdQIp5G9QqwR"},"execution_count":null,"outputs":[]}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/10 - Sharding & Parallelism.ipynb b/docs/learning_jax/code-exercises/10 - Sharding & Parallelism.ipynb new file mode 100644 index 0000000..3417342 --- /dev/null +++ b/docs/learning_jax/code-exercises/10 - Sharding & Parallelism.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1pPtRlg8z8UZb9s041S2SweacIZ-kpBP2","timestamp":1755114078390}],"toc_visible":true,"authorship_tag":"ABX9TyNQT2JyPg+rAuQbdE6mm9yv"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Sharding and Parallelism with JAX & Flax NNX\n","\n","## Introduction\n","Welcome to the practical exercises for \"Scaling Up: Sharding and Parallelism with JAX and Flax NNX\"!\n","\n","This Colab notebook is designed to provide you with hands-on experience of the concepts covered in the lecture. You'll work through examples of defining hardware meshes, specifying sharding for your data and model parameters, initializing large models in a distributed fashion, and setting up a sharded training step using JAX's powerful SPMD (Single Program, Multiple Data) capabilities with Flax NNX.\n","\n","Let's dive into scaling your JAX models!"],"metadata":{"id":"z4HKhLIj00Fo"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"me5v-lQZ0xJv"},"outputs":[],"source":["# @title Setup and Imports\n","# Install necessary libraries\n","!pip install -Uq flax optax chex\n","\n","import jax\n","import jax.numpy as jnp\n","import numpy as np\n","import flax.nnx as nnx\n","from flax.nnx import spmd # For sharding utilities like get_partition_spec\n","import chex\n","from jax.sharding import Mesh, PartitionSpec as P, NamedSharding\n","import optax # For the training loop example\n","\n","# --- IMPORTANT: Simulate a multi-device environment ---\n","# This allows us to run sharding examples in Colab which typically has only one accelerator.\n","# We'll simulate 8 CPU devices.\n","try:\n"," chex.set_n_cpu_devices(8)\n","except RuntimeError as e:\n"," print(f\"Note: {e}. This is expected if devices are already set or on a real multi-device TPU/GPU setup.\")\n","\n","# Verify the number of devices JAX sees\n","print(f\"JAX visible devices: {jax.devices()}\")\n","print(f\"Number of JAX devices: {jax.device_count()}\")\n","\n","# Helper function to create RNG keys for NNX\n","PRNGKey = jax.random.PRNGKey\n","# For NNX module parameter initialization\n","def new_key():\n"," return nnx.make_rng('params')\n","\n","# For general JAX operations\n","MAIN_KEY = PRNGKey(0)\n","\n","# Silence some NNX warnings for cleaner output in the notebook\n","import logging\n","logging.getLogger('flax.experimental.nnx.nnx.graph').setLevel(logging.ERROR)"]},{"cell_type":"markdown","source":["# Exercise 1: JAX Sharding Primitives - Mesh, PartitionSpec, NamedSharding, and device_put\n","\n","The foundation of explicit parallelism in JAX involves:\n","\n","- `Mesh`: A logical grid representing your physical accelerator devices.\n","- `PartitionSpec` (often aliased as `P`): A tuple describing how a tensor's dimensions map to `Mesh` axes.\n","- `NamedSharding`: Combines a `Mesh` and a `PartitionSpec` into a reusable sharding strategy.\n","- `jax.device_put`: Explicitly places data onto devices with a specific `NamedSharding`.\n","\n","**Instructions**:\n","\n","1. Create a Mesh for our 8 simulated devices, arranging them in a 2x4 grid. Name the mesh axes 'data' and 'model'.\n","2. Define three PartitionSpecs:\n"," - pspec_data_parallel: Shard dimension 0 along 'data', replicate dimension 1. (Typical for input batches [batch, features])\n"," - pspec_model_parallel_dim1: Replicate dimension 0, shard dimension 1 along 'model'. (Typical for a weight matrix [in_features, out_features] in some forms of tensor parallelism).\n"," - pspec_replicated: Fully replicate the tensor on all devices in the mesh.\n","3. Create a NamedSharding object for pspec_data_parallel using your mesh.\n","4. Create a sample NumPy array of shape (16, 128).\n","5. Use jax.device_put to shard this NumPy array according to the NamedSharding you created.\n","6. Print the mesh and the .sharding attribute of the sharded JAX array to verify."],"metadata":{"id":"uAb2MXm-2jMX"}},{"cell_type":"code","source":["# Exercise 1: Instructions Cell\n","\n","# TODO: 1. Create a Mesh for 8 devices in a 2x4 grid with axes 'data' and 'model'.\n","# Tip: jax.devices() gives a list of devices. Reshape a numpy array of these devices.\n","mesh_devices = np.array(jax.devices()).reshape((2, 4)) # Devices for 'data' axis, then 'model' axis\n","mesh = Mesh(mesh_devices, axis_names=('data', 'model'))\n","\n","# TODO: 2. Define PartitionSpecs\n","pspec_data_parallel = P('data', None) # Shard batch dim, replicate feature dim\n","pspec_model_parallel_dim1 = P(None, 'model') # Replicate in_feature dim, shard out_feature dim\n","pspec_replicated = P() # Fully replicated\n","\n","# TODO: 3. Create NamedSharding for data parallelism\n","data_named_sharding = NamedSharding(mesh, pspec_data_parallel)\n","\n","# TODO: 4. Create a sample NumPy array\n","numpy_array = np.arange(16 * 128, dtype=np.float32).reshape((16, 128))\n","\n","# TODO: 5. Shard the array using jax.device_put\n","sharded_array = jax.device_put(numpy_array, data_named_sharding)\n","\n","# TODO: 6. Print mesh and sharded array's sharding\n","print(\"Mesh:\", mesh)\n","print(\"\\nPartitionSpec for data parallel:\", pspec_data_parallel)\n","print(\"PartitionSpec for model parallel (dim 1):\", pspec_model_parallel_dim1)\n","print(\"PartitionSpec for replicated:\", pspec_replicated)\n","print(\"\\nNamedSharding for data:\", data_named_sharding)\n","print(\"\\nOriginal NumPy array shape:\", numpy_array.shape)\n","print(\"Sharded JAX array sharding:\", sharded_array.sharding)\n","print(\"Sharded JAX array device buffers (should show multiple devices):\")\n","for buffer in sharded_array.addressable_shards:\n"," print(f\"Buffer {i}: Device={db.device}, Shape={db.data.shape}\")\n","\n","# Verify that each device in the 'data' axis gets a slice\n","# For a (2,4) mesh with P('data', None) on a (16,128) array:\n","# Data axis has 2 devices. 16 / 2 = 8.\n","# Each of the 4 devices along 'model' axis within a data-slice group will get a replica of that (8,128) slice.\n","# So, each of the 8 devices will hold a slice of shape (8, 128) effectively,\n","# but logically the sharding is over the 'data' axis.\n","# The `shard_shape` method reflects the shape on one device.\n","# Total elements: 16*128 = 2048. Each shard has 8*128 = 1024 elements.\n","# 2048 / 8 devices = 256 elements per device if fully sharded.\n","# Here, with P('data', None), it's 16/2 = 8. So shape on each device is (8, 128).\n","\n","jax.debug.visualize_array_sharding(sharded_array)"],"metadata":{"id":"WVhTJWmr1x1u"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 1: Solution"],"metadata":{"id":"yYxsyR0R3jbe"}},{"cell_type":"code","source":["# Exercise 1: Solution Cell\n","\n","# 1. Create a Mesh for 8 devices in a 2x4 grid with axes 'data' and 'model'.\n","# The 'data' axis will have 2 devices, and the 'model' axis will have 4 devices.\n","mesh_devices = np.array(jax.devices()).reshape((2, 4))\n","mesh = Mesh(mesh_devices, axis_names=('data', 'model'))\n","\n","# 2. Define PartitionSpecs\n","pspec_data_parallel = P('data', None) # Shard batch (dim 0) along 'data', replicate features (dim 1)\n","pspec_model_parallel_dim1 = P(None, 'model') # Replicate dim 0, shard dim 1 along 'model'\n","pspec_replicated = P() # Fully replicated across all devices in the mesh\n","\n","# 3. Create NamedSharding for data parallelism\n","data_named_sharding = NamedSharding(mesh, pspec_data_parallel)\n","\n","# 4. Create a sample NumPy array\n","numpy_array = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) # (batch_size=16, features=128)\n","\n","# 5. Shard the array using jax.device_put\n","# This places the array onto the devices defined by the mesh, according to the NamedSharding.\n","sharded_array = jax.device_put(numpy_array, data_named_sharding)\n","\n","# 6. Print mesh and sharded array's sharding\n","print(\"Mesh:\", mesh)\n","print(\"\\nPartitionSpec for data parallel:\", pspec_data_parallel)\n","print(\"PartitionSpec for model parallel (dim 1):\", pspec_model_parallel_dim1)\n","print(\"PartitionSpec for replicated:\", pspec_replicated)\n","print(\"\\nNamedSharding for data:\", data_named_sharding)\n","print(\"\\nOriginal NumPy array shape:\", numpy_array.shape)\n","print(\"Sharded JAX array object:\", sharded_array)\n","print(\"Sharded JAX array sharding:\", sharded_array.sharding)\n","\n","print(\"\\nInspecting device buffers for the sharded array:\")\n","# For a (16, 128) array sharded with P('data', None) on a ('data':2, 'model':4) mesh:\n","# The 'data' axis (size 2) shards the first dimension (16). So, 16/2 = 8.\n","# The second dimension (128) is replicated (None).\n","# Each device will hold a piece of shape (8, 128).\n","for i, db in enumerate(sharded_array.addressable_shards):\n"," # Access shape from the data attribute of the Shard object\n"," print(f\"Buffer {i}: Device={db.device}, Shape={db.data.shape}\")\n","\n","# You should see that the array is split across devices.\n","# With P('data', None) on a 2x4 mesh, the first dimension (16) is split over the 'data' axis (size 2).\n","# So, devices (0,0),(0,1),(0,2),(0,3) will get the first half of the data (rows 0-7), replicated.\n","# And devices (1,0),(1,1),(1,2),(1,3) will get the second half (rows 8-15), replicated.\n","# Each device buffer will have shape (8, 128).\n","\n","jax.debug.visualize_array_sharding(sharded_array)"],"metadata":{"id":"utojQYFz3oPo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Exercise 2: Annotating Sharding in a Simple NNX Module\n","\n","Flax NNX modules can store sharding hints (PartitionSpec tuples) directly within their parameter metadata. This is crucial for guiding the JAX compiler when performing sharded initialization and distributed training.\n","\n","**Instructions:**\n","\n","1. Define a simple NNX Linear module.\n","2. In its `__init__`, initialize kernel (a 2D weight matrix) and bias (a 1D vector) as nnx.Param.\n","3. Use nnx.with_metadata to attach PartitionSpec sharding hints during initialization:\n"," - For kernel (e.g., shape [in_features, out_features]), shard its second dimension (output features) along a mesh axis named 'model'. Replicate the first dimension. So, P(None, 'model').\n"," - For bias (e.g., shape [out_features]), shard it along the 'model' mesh axis. So, P('model').\n","4. Instantiate the module (without a mesh context for now; these are just metadata annotations).\n","5. Access the State of a parameter (e.g., module.kernel.state) and print its .sharding attribute to verify the PartitionSpec was stored."],"metadata":{"id":"L_OnFhW87Wr0"}},{"cell_type":"code","source":["# Exercise 2: Instructions Cell\n","\n","class SimpleLinear(nnx.Module):\n"," def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):\n"," key = rngs.params()\n"," # TODO: 3. Initialize kernel and bias with sharding metadata\n"," # Kernel sharding: P(None, 'model')\n"," # Bias sharding: P('model')\n"," # Use nnx.initializers.lecun_normal() for kernel and nnx.initializers.zeros for bias\n","\n"," # Example of using nnx.with_metadata:\n"," # self.my_param = nnx.Param(\n"," # nnx.with_metadata(\n"," # nnx.initializers.zeros,\n"," # sharding=P(...) # Your PartitionSpec tuple\n"," # )(key, param_shape)\n"," # )\n"," # Or, more directly if nnx.Param supports 'sharding' kwarg for its value:\n"," # self.my_param = nnx.Param(\n"," # nnx.initializers.zeros(key, param_shape),\n"," # sharding=P(...)\n"," # )\n"," # The slides show both `nnx.with_metadata` and direct `sharding=` to nnx.Param.\n"," # Let's use nnx.with_metadata for clarity as it's explicitly for metadata.\n","\n"," self.kernel = nnx.Param(\n"," nnx.with_metadata(\n"," nnx.initializers.lecun_normal(),\n"," sharding=P(None, 'model') # Shard out_features along 'model'\n"," )(key, (in_features, out_features))\n"," )\n"," self.bias = nnx.Param(\n"," nnx.with_metadata(\n"," nnx.initializers.zeros,\n"," sharding=P('model') # Shard bias along 'model'\n"," )(key, (out_features,))\n"," )\n"," self.in_features = in_features\n"," self.out_features = out_features\n","\n"," def __call__(self, x: jax.Array):\n"," # This part is not the focus of sharding annotation, but good to have\n"," return x @ self.kernel.value + self.bias.value\n","\n","# TODO: 4. Instantiate the module\n","rngs_init = nnx.Rngs(params=jax.random.key(0)) # Create Rngs for NNX module\n","linear_module_annotated = SimpleLinear(in_features=128, out_features=256, rngs=rngs_init)\n","\n","# TODO: 5. Print the .sharding metadata from kernel and bias states\n","print(\"Kernel State:\", linear_module_annotated.kernel.state) # The state object itself\n","print(\"Kernel sharding metadata:\", linear_module_annotated.kernel.state.sharding)\n","print(\"Bias State:\", linear_module_annotated.bias.state)\n","print(\"Bias sharding metadata:\", linear_module_annotated.bias.state.sharding)"],"metadata":{"id":"V04cp_Rm3tV4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 2: Solution"],"metadata":{"id":"mUGzxHGf7w6u"}},{"cell_type":"code","source":["# Exercise 2: Solution Cell\n","\n","class SimpleLinear(nnx.Module):\n"," def __init__(self, in_features: int, out_features: int, *, rngs: nnx.Rngs):\n"," key = rngs.params() # Get a JAX PRNGKey for parameter initialization\n","\n"," # 3. Initialize kernel and bias with sharding metadata\n"," # Kernel: shape (in_features, out_features), shard out_features along 'model'\n"," self.kernel = nnx.Param(\n"," nnx.with_metadata(\n"," nnx.initializers.lecun_normal(), # Initializer function\n"," sharding=P(None, 'model') # PartitionSpec tuple for metadata\n"," )(key, (in_features, out_features)) # Call initializer with key and shape\n"," )\n","\n"," # Bias: shape (out_features,), shard along 'model'\n"," self.bias = nnx.Param(\n"," nnx.with_metadata(\n"," nnx.initializers.zeros, # Initializer function\n"," sharding=P('model') # PartitionSpec tuple for metadata\n"," )(key, (out_features,)) # Call initializer with key and shape\n"," )\n"," self.in_features = in_features\n"," self.out_features = out_features\n","\n"," def __call__(self, x: jax.Array):\n"," return x @ self.kernel.value + self.bias.value\n","\n","# 4. Instantiate the module\n","# We need an Rngs object for NNX modules, even if just for 'params'\n","rngs_for_module_init = nnx.Rngs(params=jax.random.key(0))\n","linear_module_annotated = SimpleLinear(in_features=128, out_features=256, rngs=rngs_for_module_init)\n","\n","# 5. Print the .sharding metadata from kernel and bias states\n","# The sharding information is stored as metadata on the *value* within the nnx.Variable's state.\n","print(f\"Type of kernel: {type(linear_module_annotated.kernel)}\")\n","print(f\"Type of kernel's state: {type(linear_module_annotated.kernel.state)}\") # State object\n","print(f\"Type of kernel's value: {type(linear_module_annotated.kernel.value)}\") # JAX array with metadata\n","\n","# Access the sharding from the value\n","print(\"\\nKernel sharding metadata:\", linear_module_annotated.kernel.value.sharding)\n","print(\"Bias sharding metadata:\", linear_module_annotated.bias.value.sharding)\n","\n","# Verify output:\n","# Kernel sharding metadata: PartitionSpec(None, 'model')\n","# Bias sharding metadata: PartitionSpec('model',)"],"metadata":{"id":"dr4SyyF470LJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Exercise 3: Sharded Initialization of an NNX Module\n","\n","Initializing a very large model directly can cause Out-Of-Memory (OOM) errors if all parameters are created on a single default device. The solution is to perform initialization and apply sharding constraints inside a JIT-compiled function executed within a Mesh context.\n","\n","**Workflow**:\n","\n","1. Instantiate the unsharded NNX module (with sharding metadata already defined).\n","2. Extract its functional State PyTree: state = nnx.state(model).\n","3. Extract the PartitionSpec PyTree from metadata: pspecs = nnx.spmd.get_partition_spec(state).\n","4. Apply sharding constraints to the State: sharded_state = jax.lax.with_sharding_constraint(state, pspecs).\n","5. Update the original module object with the sharded state: nnx.update(model, sharded_state).\n","6. Return the model.\n","7. Execute this entire JITted function within a jax.sharding.Mesh context.\n","\n","**Instructions**:\n","\n","1. Reuse the SimpleLinear module definition from Exercise 2.\n","2. Define a Mesh (e.g., 1x4, with axes ('data', 'model') or just 'model' if only model parallelism is intended for this part). Let's use (1, 8) with ('data', 'model') to make it clear we are sharding across the 'model' axis which has 8 devices.\n","3. Implement a function create_sharded_linear_model(rngs, in_features, out_features) that performs steps 1-6 above. Decorate this function with @nnx.jit.\n","4. Call this create_sharded_linear_model function within the Mesh context (using with mesh:).\n","5. Verify that the parameters (kernel and bias) of the returned model are now physically sharded JAX arrays by printing their .sharding attribute (this time from the JAX array value, not the metadata)."],"metadata":{"id":"6TpuQd7E8rT3"}},{"cell_type":"code","source":["# Exercise 3: Instructions Cell\n","\n","# 1. Reuse SimpleLinear (already defined above)\n","\n","# 2. Define a Mesh. For this exercise, let's assume we want to shard\n","# the 'model' dimension of SimpleLinear across all 8 devices.\n","# A (1, 8) mesh with axes ('data', 'model') would mean 'data' axis has size 1 (no data parallelism here)\n","# and 'model' axis has size 8.\n","mesh_devices_ex3 = np.array(jax.devices()).reshape((1, 8))\n","mesh_ex3 = Mesh(mesh_devices_ex3, axis_names=('data', 'model')) # or just ('model',) if using 1D mesh\n","\n","# 3. Implement the sharded initialization function\n","@nnx.jit # nnx.jit handles split/merge of NNX state for JAX transformations\n","def create_sharded_linear_model(rngs_for_creation, in_f, out_f):\n"," # Step 1: Instantiate the NNX module (parameters are created here, typically on default device initially)\n"," model = SimpleLinear(in_features=in_f, out_features=out_f, rngs=rngs_for_creation)\n","\n"," # Step 2: Extract the functional State PyTree\n"," state = nnx.state(model)\n","\n"," # Step 3: Extract the PartitionSpec PyTree from metadata\n"," # This uses the .sharding attributes we defined in SimpleLinear's __init__\n"," pspecs = nnx.spmd.get_partition_spec(state)\n","\n"," # Step 4: Apply sharding constraints to the State\n"," # This tells the JAX compiler the desired final layout for the parameters.\n"," # The actual resharding happens when this JITted function is executed.\n"," sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n","\n"," # Step 5: Update the original module object with the now sharded state\n"," nnx.update(model, sharded_state)\n","\n"," # Step 6: Return the model (which now contains sharded parameters)\n"," return model\n","\n","# 4. Call the function within the Mesh context\n","rngs_for_sharded_init = nnx.Rngs(params=jax.random.key(1)) # Use a different key\n","# TODO: Call create_sharded_linear_model within the mesh_ex3 context\n","with mesh_ex3:\n"," sharded_linear_model = create_sharded_linear_model(rngs_for_sharded_init, 128, 256)\n","\n","\n","# 5. Verify sharding of the actual JAX arrays\n","# The .value of an nnx.Param is the JAX array\n","print(\"\\n--- Verification after sharded initialization ---\")\n","print(\"Sharded Kernel's JAX array sharding:\", sharded_linear_model.kernel.value.sharding)\n","print(\"Sharded Bias's JAX array sharding:\", sharded_linear_model.bias.value.sharding)\n","\n","# Expected output for kernel: NamedSharding(mesh=..., spec=PartitionSpec(None, 'model'))\n","# Expected output for bias: NamedSharding(mesh=..., spec=PartitionSpec('model',))\n","# The mesh in NamedSharding should match mesh_ex3."],"metadata":{"id":"wwUwiQQ6758-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 3: Solution"],"metadata":{"id":"D-oiZqRX9OsX"}},{"cell_type":"code","source":["# Exercise 3: Solution Cell\n","\n","# 1. Reuse SimpleLinear (already defined above)\n","\n","# 2. Define a Mesh.\n","# We want to shard the 'model' dimension. Let's use a 1x8 mesh,\n","# dedicating all 8 devices to the 'model' axis for this example.\n","# 'data' axis size 1 means parameters are replicated along it (which is trivial here).\n","mesh_devices_ex3 = np.array(jax.devices()).reshape((1, 8))\n","mesh_ex3 = Mesh(mesh_devices_ex3, axis_names=('data', 'model'))\n","\n","# 3. Implement the sharded initialization function\n","# Tell JAX that in_f and out_f are static arguments\n","@nnx.jit(static_argnums=(1, 2))\n","def create_sharded_linear_model(rngs_for_creation, in_f, out_f):\n"," # Step 1: Instantiate the NNX module. Params are created with metadata hints.\n"," # At this point inside JIT, they might be on a default device or abstract.\n"," print(f\"Inside JIT: Instantiating SimpleLinear({in_f}, {out_f})\")\n"," model = SimpleLinear(in_features=in_f, out_features=out_f, rngs=rngs_for_creation)\n","\n"," # Step 2: Extract the functional State PyTree. This is JAX-compatible.\n"," state = nnx.state(model)\n"," # print(f\"Inside JIT: Extracted state - Kernel sharding metadata: {state['kernel'].sharding}\")\n","\n","\n"," # Step 3: Extract the PartitionSpec PyTree from metadata.\n"," pspecs = nnx.spmd.get_partition_spec(state)\n"," # print(f\"Inside JIT: Extracted PartitionSpecs - Kernel PSpec: {pspecs['kernel']}\")\n","\n","\n"," # Step 4: Apply sharding constraints to the State.\n"," # This is a hint to XLA; the actual sharding occurs when data is materialized on devices.\n"," sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n"," # print(f\"Inside JIT: Applied sharding constraint. Kernel value sharding (if available): {getattr(sharded_state['kernel'].value, 'sharding', 'Not yet concrete')}\")\n","\n","\n"," # Step 5: Update the original module object with the (conceptually) sharded state.\n"," nnx.update(model, sharded_state)\n","\n"," # Step 6: Return the model.\n"," return model\n","\n","# 4. Call the function within the Mesh context\n","# The 'with mesh:' block provides the context for JAX to fulfill the sharding.\n","rngs_for_sharded_init = nnx.Rngs(params=jax.random.key(1))\n","print(f\"Executing create_sharded_linear_model within mesh: {mesh_ex3}\")\n","with mesh_ex3:\n"," sharded_linear_model = create_sharded_linear_model(rngs_for_sharded_init, 128, 256)\n","print(\"Sharded model created.\")\n","\n","# 5. Verify sharding of the actual JAX arrays\n","# After execution, the .value of nnx.Param should be a sharded JAX GlobalDeviceArray (GDA).\n","print(\"\\n--- Verification after sharded initialization ---\")\n","print(\"Sharded Kernel's JAX array (.value) sharding:\", sharded_linear_model.kernel.value.sharding)\n","print(\"Sharded Kernel's JAX array shape:\", sharded_linear_model.kernel.value.shape)\n","print(\"Sharded Bias's JAX array (.value) sharding:\", sharded_linear_model.bias.value.sharding)\n","print(\"Sharded Bias's JAX array shape:\", sharded_linear_model.bias.value.shape)\n","\n","# For kernel (128, 256) with P(None, 'model') on a ('data':1, 'model':8) mesh:\n","# The 'model' axis (size 8) shards the second dimension (256). So, 256/8 = 32.\n","# Each device buffer for the kernel should have shape (128, 32).\n","print(\"\\nKernel device buffers:\")\n","for i, db in enumerate(sharded_linear_model.kernel.value.addressable_shards):\n"," print(f\" Buffer {i}: Device={db.device}, Shape={db.data.shape}\")\n","jax.debug.visualize_array_sharding(sharded_linear_model.kernel.value)\n","\n","# For bias (256,) with P('model') on a ('data':1, 'model':8) mesh:\n","# The 'model' axis (size 8) shards the first dimension (256). So, 256/8 = 32.\n","# Each device buffer for the bias should have shape (32,).\n","print(\"\\nBias device buffers:\")\n","for i, db in enumerate(sharded_linear_model.bias.value.addressable_shards):\n"," print(f\" Buffer {i}: Device={db.device}, Shape={db.data.shape}\")\n","jax.debug.visualize_array_sharding(sharded_linear_model.bias.value)"],"metadata":{"id":"lkiTtdSf9SH3"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Exercise 4: Sharding a Mini FeedForward Block\n","\n","Let's apply these concepts to a slightly more complex block, akin to a part of a Transformer's FeedForward network: LayerNorm -> Linear1 -> GELU -> Linear2. We'll focus on model parallelism for the Linear layers and LayerNorm parameters.\n","\n","**Instructions**:\n","\n","1. Define an NNXFeedForward module containing:\n"," - nnx.LayerNorm: Shard its scale and bias parameters along the 'model' axis (P('model')).\n"," - nnx.Linear (linear1): Kernel P(None, 'model'), bias P('model').\n"," - nnx.Linear (linear2): Kernel P('model', None) (Note the change for variety, sharding input features), bias P(None) (replicated, or P() if not sharding bias). Let's use P(None, 'model') for kernel2 as well to be consistent with typical FFN output sharding, and bias P('model').\n","2. Use the sharded initialization workflow (create_sharded_ffn_model function decorated with @nnx.jit) to initialize this NNXFeedForward module.\n","3. Define a 2D mesh, e.g., (2, 4) with axes ('data', 'model'). The 'model' axis will be used for sharding the parameters as defined.\n","4. Instantiate and verify the sharding of parameters within this NNXFeedForward model."],"metadata":{"id":"7HlJv7hX6Qjv"}},{"cell_type":"code","source":["# Exercise 4: Instructions Cell\n","\n","class NNXFeedForward(nnx.Module):\n"," def __init__(self, embed_dim: int, ff_dim: int, *, rngs: nnx.Rngs):\n"," key_param, key_dropout = rngs.fork_key('params'), rngs.fork_key('dropout') # Example if using dropout\n","\n"," # TODO: 1. Define LayerNorm and Linear layers with sharding metadata\n"," # LayerNorm: scale P('model'), bias P('model')\n"," # Linear1 (embed_dim -> ff_dim): kernel P(None, 'model'), bias P('model')\n"," # Linear2 (ff_dim -> embed_dim): kernel P(None, 'model'), bias P('model')\n"," # (Note: For Linear2 kernel P('model', None) would shard along input dim,\n"," # P(None, 'model') shards along output dim. Let's be consistent with typical model sharding.)\n","\n"," self.layernorm = nnx.LayerNorm(\n"," num_features=embed_dim,\n"," epsilon=1e-6,\n"," scale_init=nnx.with_metadata(nnx.initializers.ones, sharding=P('model')),\n"," bias_init=nnx.with_metadata(nnx.initializers.zeros, sharding=P('model')),\n"," rngs=rngs.fork('params') # LayerNorm takes rngs for its own init\n"," )\n","\n"," self.linear1 = SimpleLinear( # Reusing SimpleLinear for convenience\n"," in_features=embed_dim,\n"," out_features=ff_dim,\n"," rngs=rngs.fork('params') # Pass down Rngs\n"," )\n"," # Ensure SimpleLinear's sharding annotations are: kernel P(None, 'model'), bias P('model')\n","\n"," self.linear2 = SimpleLinear(\n"," in_features=ff_dim,\n"," out_features=embed_dim,\n"," rngs=rngs.fork('params')\n"," )\n"," # Ensure SimpleLinear's sharding annotations are: kernel P(None, 'model'), bias P('model')\n"," # If we wanted linear2 kernel to be P('model', None), we'd need to modify SimpleLinear\n"," # or create a new Linear variant. For now, let's assume SimpleLinear is P(None, 'model') for kernel.\n","\n","\n"," def __call__(self, x: jax.Array, training: bool = False):\n"," x_norm = self.layernorm(x)\n"," x_ff = nnx.gelu(self.linear1(x_norm))\n"," output = self.linear2(x_ff)\n"," return output\n","\n","# TODO: 2. Implement the sharded initialization function for NNXFeedForward\n","@nnx.jit(static_argnums=(1, 2))\n","def create_sharded_ffn_model(rngs_for_creation, embed_dim, ff_dim):\n"," model = NNXFeedForward(embed_dim=embed_dim, ff_dim=ff_dim, rngs=rngs_for_creation)\n"," state = nnx.state(model)\n"," pspecs = nnx.spmd.get_partition_spec(state)\n"," sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n"," nnx.update(model, sharded_state)\n"," return model\n","\n","# TODO: 3. Define a 2D mesh ('data', 'model'), e.g., (2, 4)\n","mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))\n","mesh_ex4 = Mesh(mesh_devices_ex4, axis_names=('data', 'model')) # 'model' axis has size 4\n","\n","# TODO: 4. Instantiate and verify sharding\n","rngs_ffn_init = nnx.Rngs(0, params=jax.random.key(2)) # Explicitly create 'params' stream\n","with mesh_ex4:\n"," sharded_ffn_model = create_sharded_ffn_model(rngs_ffn_init, embed_dim=128, ff_dim=512)\n","\n","print(\"\\n--- FFN Model Parameter Sharding Verification ---\")\n","print(\"LayerNorm scale sharding:\", sharded_ffn_model.layernorm.scale.value.sharding)\n","print(\"LayerNorm bias sharding:\", sharded_ffn_model.layernorm.bias.value.sharding)\n","print(\"Linear1 kernel sharding:\", sharded_ffn_model.linear1.kernel.value.sharding)\n","print(\"Linear2 kernel sharding:\", sharded_ffn_model.linear2.kernel.value.sharding)\n","\n","# Example: LayerNorm scale is (embed_dim=128,). Sharded with P('model')\n","# Mesh 'model' axis has size 4. So, 128 / 4 = 32 elements per device.\n","print(f\"\\nLayerNorm scale ({sharded_ffn_model.layernorm.scale.value.shape}) device buffers:\")\n","for i, db in enumerate(sharded_ffn_model.layernorm.scale.value.addressable_shards):\n"," print(f\" Buffer {i}: Device={db.device}, Shape={db.data.shape}\") # Expected shape (32,)\n","\n","jax.debug.visualize_array_sharding(sharded_ffn_model.layernorm.scale.value)\n","\n","# Example: Linear1 kernel is (embed_dim=128, ff_dim=512). Sharded with P(None, 'model')\n","# Mesh 'model' axis has size 4. ff_dim (512) / 4 = 128.\n","# Expected shape on device (128, 128).\n","print(f\"\\nLinear1 kernel ({sharded_ffn_model.linear1.kernel.value.shape}) device buffers:\")\n","for i, db in enumerate(sharded_ffn_model.linear1.kernel.value.addressable_shards):\n"," print(f\" Buffer {i}: Device={db.device}, Shape={db.data.shape}\")\n","\n","jax.debug.visualize_array_sharding(sharded_ffn_model.linear1.kernel.value)"],"metadata":{"id":"Ge5lVThk9VoO"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 4: Solution"],"metadata":{"id":"ijQ8pGtT6ySy"}},{"cell_type":"code","source":["# Exercise 4: Solution Cell\n","\n","# We need to ensure SimpleLinear used inside NNXFeedForward has the correct sharding.\n","# Let's redefine it or make it flexible if needed.\n","# The SimpleLinear from Ex2 already has:\n","# kernel: P(None, 'model'), bias: P('model')\n","# This is suitable for our FFN.\n","\n","class NNXFeedForward(nnx.Module):\n"," def __init__(self, embed_dim: int, ff_dim: int, *, rngs: nnx.Rngs):\n"," # 1. Define LayerNorm and Linear layers with sharding metadata\n"," # LayerNorm takes an Rngs object directly if it has its own params/dropout.\n"," # For nnx.LayerNorm, scale_init/bias_init are callables that take a key.\n"," # We can use nnx.with_metadata with these initializers.\n"," self.layernorm = nnx.LayerNorm(\n"," num_features=embed_dim,\n"," epsilon=1e-6,\n"," # nnx.LayerNorm will call these initializers with a key from its rngs\n"," scale_init=nnx.with_metadata(nnx.initializers.ones, sharding=P('model')),\n"," bias_init=nnx.with_metadata(nnx.initializers.zeros, sharding=P('model')),\n"," rngs=rngs # Pass the Rngs for LayerNorm to use for its initializers\n"," )\n","\n"," # For SimpleLinear, we pass the Rngs object, and it extracts the 'params' key.\n"," self.linear1 = SimpleLinear(\n"," in_features=embed_dim,\n"," out_features=ff_dim,\n"," rngs=rngs # Each submodule gets its own Rngs\n"," )\n"," # SimpleLinear is defined with: kernel P(None, 'model'), bias P('model')\n","\n"," self.linear2 = SimpleLinear(\n"," in_features=ff_dim,\n"," out_features=embed_dim,\n"," rngs=rngs\n"," )\n"," # SimpleLinear is defined with: kernel P(None, 'model'), bias P('model')\n","\n"," def __call__(self, x: jax.Array, training: bool = False):\n"," x_norm = self.layernorm(x)\n"," x_ff = nnx.gelu(self.linear1(x_norm)) # linear1.__call__\n"," output = self.linear2(x_ff) # linear2.__call__\n"," return output\n","\n","# 2. Implement the sharded initialization function for NNXFeedForward\n","@nnx.jit(static_argnums=(1, 2))\n","def create_sharded_ffn_model(rngs_for_creation, embed_dim, ff_dim):\n"," print(f\"Inside JIT (FFN): Instantiating NNXFeedForward({embed_dim}, {ff_dim})\")\n"," model = NNXFeedForward(embed_dim=embed_dim, ff_dim=ff_dim, rngs=rngs_for_creation)\n"," state = nnx.state(model)\n"," pspecs = nnx.spmd.get_partition_spec(state)\n"," # print(f\"Inside JIT (FFN): PSPECS = {pspecs}\")\n"," sharded_state = jax.lax.with_sharding_constraint(state, pspecs)\n"," nnx.update(model, sharded_state)\n"," return model\n","\n","# 3. Define a 2D mesh ('data', 'model'), e.g., (2, 4)\n","# 'data' axis size 2, 'model' axis size 4.\n","mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))\n","mesh_ex4 = Mesh(mesh_devices_ex4, axis_names=('data', 'model'))\n","\n","# 4. Instantiate and verify sharding\n","# Top-level Rngs for the FFN model creation\n","rngs_ffn_init = nnx.Rngs(0, params=jax.random.key(2)) # Explicitly create 'params' stream\n","\n","print(f\"Executing create_sharded_ffn_model within mesh: {mesh_ex4}\")\n","with mesh_ex4:\n"," sharded_ffn_model = create_sharded_ffn_model(rngs_ffn_init, embed_dim=128, ff_dim=512)\n","print(\"Sharded FFN model created.\")\n","\n","print(\"\\n--- FFN Model Parameter Sharding Verification ---\")\n","# LayerNorm parameters are typically 'scale' and 'bias'\n","print(\"LayerNorm scale sharding:\", sharded_ffn_model.layernorm.scale.value.sharding)\n","print(\"LayerNorm bias sharding:\", sharded_ffn_model.layernorm.bias.value.sharding)\n","print(\"Linear1 kernel sharding:\", sharded_ffn_model.linear1.kernel.value.sharding)\n","print(\"Linear1 bias sharding:\", sharded_ffn_model.linear1.bias.value.sharding)\n","print(\"Linear2 kernel sharding:\", sharded_ffn_model.linear2.kernel.value.sharding)\n","print(\"Linear2 bias sharding:\", sharded_ffn_model.linear2.bias.value.sharding)\n","\n","\n","# Example: LayerNorm scale is (embed_dim=128,). Sharded with P('model')\n","# Mesh 'model' axis has size 4. So, 128 / 4 = 32 elements per device.\n","print(f\"\\nLayerNorm scale ({sharded_ffn_model.layernorm.scale.value.shape}) device buffers:\")\n","for i, db in enumerate(sharded_ffn_model.layernorm.scale.value.addressable_shards):\n"," print(f\" Buffer {i}: Device={db.device}, Shape={db.data.shape}\") # Expected shape (32,)\n","\n","jax.debug.visualize_array_sharding(sharded_ffn_model.layernorm.scale.value)\n","\n","# Example: Linear1 kernel is (embed_dim=128, ff_dim=512). Sharded with P(None, 'model')\n","# Mesh 'model' axis has size 4. ff_dim (512) / 4 = 128.\n","# Expected shape on device (128, 128).\n","print(f\"\\nLinear1 kernel ({sharded_ffn_model.linear1.kernel.value.shape}) device buffers:\")\n","for i, db in enumerate(sharded_ffn_model.linear1.kernel.value.addressable_shards):\n"," print(f\" Buffer {i}: Device={db.device}, Shape={db.data.shape}\")\n","\n","jax.debug.visualize_array_sharding(sharded_ffn_model.linear1.kernel.value)"],"metadata":{"id":"K-75p3BqJpfM"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Exercise 5: Sharding Input Data and a Mock Training Step\n","\n","A distributed training loop involves:\n","\n","* Sharding input data batches.\n","* A JIT-compiled training step that operates on the sharded model and sharded data.\n","* Gradients are computed, and optimizer updates are applied in a distributed manner.\n","\n","**Instructions**:\n","\n","1. Use the sharded_ffn_model and mesh_ex4 (2x4, ('data', 'model')) from Exercise 4.\n","2. Create a dummy input batch (e.g., NumPy array (batch_size=16, embed_dim=128)).\n","3. Define a NamedSharding to shard this batch along the 'data' axis of mesh_ex4 (i.e., P('data', None)).\n","4. Use jax.device_put to shard the input batch. Also create sharded dummy labels P('data').\n","5. Define a train_step function decorated with @nnx.jit. This function should:\n"," - Take the sharded model, an optimizer (e.g., nnx.Optimizer), sharded batch, and sharded labels as input.\n"," - Define a loss_fn that takes the stateful model, performs a forward pass, and computes a simple mean loss (e.g., using optax.softmax_cross_entropy_with_integer_labels).\n"," - Use nnx.value_and_grad(loss_fn)(model) to get loss and gradients.\n"," - Call optimizer.update(model, grads) to apply gradients.\n"," - Return the loss.\n","6. Create a simple nnx.Optimizer (e.g., Adam) for the sharded_ffn_model. Crucially, the optimizer state should also be sharded consistently with the parameters it optimizes. NNX's Optimizer typically handles this if created from a sharded model state.\n","7. Execute the train_step once with the sharded inputs and model within the mesh_ex4 context. Print the resulting loss. (The actual numerics of the loss are not important, focus on the setup)."],"metadata":{"id":"J-YgwZPtLfUy"}},{"cell_type":"code","source":["# Exercise 5: Instructions Cell\n","\n","# 1. Use sharded_ffn_model and mesh_ex4 from previous exercise.\n","# Ensure they are available in this cell's scope.\n","# If not, you might need to re-run parts of Ex 4 or redefine them here.\n","# For simplicity, let's assume sharded_ffn_model and mesh_ex4 are accessible.\n","if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():\n"," print(\"Please re-run Exercise 4 to define sharded_ffn_model and mesh_ex4.\")\n"," # As a fallback for running this cell independently:\n"," _mesh_devices_ex4 = np.array(jax.devices()).reshape((2, 4))\n"," mesh_ex4 = Mesh(_mesh_devices_ex4, axis_names=('data', 'model'))\n"," _rngs_ffn_init = nnx.Rngs(params=jax.random.key(2), layernorm_params=jax.random.key(3),\n"," linear1_params=jax.random.key(4), linear2_params=jax.random.key(5))\n"," with mesh_ex4:\n"," sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init, embed_dim=128, ff_dim=512)\n","\n","\n","# 2. Create a dummy input batch and labels\n","BATCH_SIZE = 16\n","EMBED_DIM = 128 # Must match sharded_ffn_model.layernorm.num_features\n","NUM_CLASSES = 10 # For dummy labels\n","\n","numpy_batch = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)\n","numpy_labels = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)\n","\n","# TODO: 3. Define NamedSharding for input batch (shard along 'data') and labels\n","# Batch: P('data', None)\n","# Labels: P('data')\n","batch_input_sharding = NamedSharding(mesh_ex4, P('data', None))\n","label_input_sharding = NamedSharding(mesh_ex4, P('data'))\n","\n","\n","# TODO: 4. Shard the input batch and labels using jax.device_put (within mesh context is best)\n","# This should ideally be done inside the loop or just before calling train_step,\n","# and within the mesh context if the jax.device_put itself needs that context for device assignment.\n","# For jax.device_put, the mesh context isn't strictly necessary if NamedSharding already has the mesh.\n","with mesh_ex4: # Good practice to do device_put within mesh context\n"," sharded_batch = jax.device_put(numpy_batch, batch_input_sharding)\n"," sharded_labels = jax.device_put(numpy_labels, label_input_sharding)\n","\n","print(\"Sharded batch sharding:\", sharded_batch.sharding)\n","print(\"Sharded labels sharding:\", sharded_labels.sharding)\n","\n","\n","# TODO: 5. Define the train_step function\n","@nnx.jit\n","def train_step(model: NNXFeedForward, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):\n"," # Define loss_fn for nnx.value_and_grad\n"," # It operates on the stateful NNX model directly\n"," def loss_fn(mdl_stateful: NNXFeedForward):\n"," logits = mdl_stateful(batch, training=True) # Forward pass\n"," # For FFN, output is (BATCH_SIZE, EMBED_DIM). For classification, it would go to (BATCH_SIZE, NUM_CLASSES)\n"," # Let's assume our FFN output is used as logits for simplicity, though dimensions might not match num_classes.\n"," # We'll average logits to NUM_CLASSES for a dummy loss.\n"," # A real scenario would have a final classification layer.\n"," if logits.shape[-1] != NUM_CLASSES:\n"," # Crude way to make shapes match for dummy loss: average features to NUM_CLASSES channels\n"," logits_for_loss = jnp.mean(logits.reshape(logits.shape[0], -1, NUM_CLASSES), axis=1)\n"," if logits_for_loss.shape[0] == 0: # Handle BATCH_SIZE / num_data_devices = 0 case\n"," logits_for_loss = jnp.zeros((logits.shape[0], NUM_CLASSES))\n","\n"," else:\n"," logits_for_loss = logits\n","\n"," loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(logits_for_loss, labels))\n"," return loss\n","\n"," # nnx.value_and_grad handles model state splitting/merging\n"," loss_value, grads = nnx.value_and_grad(loss_fn)(model)\n","\n"," # Optimizer updates model parameters (and its own state) in-place (conceptually for NNX)\n"," # The actual update happens via nnx.update internally by the optimizer on the model's state.\n"," optimizer.update(model, grads) # This will update sharded_ffn_model's parameters\n","\n"," return loss_value\n","\n","\n","# TODO: 6. Create an nnx.Optimizer for the sharded_ffn_model\n","# The optimizer will manage the sharded parameters and its own sharded state.\n","# When optimizer is created from a model with sharded state, its own state (e.g. momentum)\n","# should also be sharded appropriately.\n","# For nnx.Optimizer, we pass the model itself.\n","optimizer = nnx.Optimizer(sharded_ffn_model, optax.adam(learning_rate=1e-3), wrt=nnx.Param)\n","\n","\n","# TODO: 7. Execute the train_step once within the mesh context.\n","with mesh_ex4:\n"," loss = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)\n","\n","print(f\"\\nComputed loss: {loss}\")\n","# Also verify that parameters in sharded_ffn_model have been updated (not easily visible by value change without running more steps)\n","# but the optimizer.update call should have functioned on sharded grads and params."],"metadata":{"id":"nZFO7B9bBx3b"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 5: Solution"],"metadata":{"id":"2uwI0YmHMILj"}},{"cell_type":"code","source":["# Exercise 5: Solution Cell\n","\n","# 1. Use sharded_ffn_model and mesh_ex4 from previous exercise.\n","# (Assuming they are available from Exercise 4 execution)\n","if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():\n"," print(\"Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 5.\")\n"," _mesh_devices_ex4_sol = np.array(jax.devices()).reshape((2, 4))\n"," mesh_ex4 = Mesh(_mesh_devices_ex4_sol, axis_names=('data', 'model'))\n"," _rngs_ffn_init_sol = nnx.Rngs(\n"," params=jax.random.key(2),\n"," layernorm_params=jax.random.key(30),\n"," linear1_params=jax.random.key(40),\n"," linear2_params=jax.random.key(50)\n"," )\n"," with mesh_ex4:\n"," sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_sol, embed_dim=128, ff_dim=512)\n"," print(\"Fallback definitions complete.\")\n","\n","\n","# 2. Create a dummy input batch and labels\n","BATCH_SIZE = 16 # Global batch size\n","EMBED_DIM = 128 # Must match sharded_ffn_model's input embed_dim\n","# Output of FFN is also EMBED_DIM. For classification, a final Linear layer to NUM_CLASSES is needed.\n","# For this exercise, we'll make a dummy adjustment if NUM_CLASSES doesn't match.\n","NUM_CLASSES = 10 # Example number of classes for dummy loss\n","\n","numpy_batch = np.random.rand(BATCH_SIZE, EMBED_DIM).astype(np.float32)\n","# Labels for classification, shape (BATCH_SIZE,)\n","numpy_labels = np.random.randint(0, NUM_CLASSES, size=(BATCH_SIZE,)).astype(np.int32)\n","\n","# 3. Define NamedSharding for input batch and labels\n","# Batch shape (BATCH_SIZE, EMBED_DIM), sharded P('data', None) -> ('data' shards BATCH_SIZE)\n","batch_input_sharding = NamedSharding(mesh_ex4, P('data', None))\n","# Labels shape (BATCH_SIZE,), sharded P('data') -> ('data' shards BATCH_SIZE)\n","label_input_sharding = NamedSharding(mesh_ex4, P('data'))\n","\n","# 4. Shard the input batch and labels using jax.device_put\n","# This is typically done inside the training loop for each new batch.\n","# Performing it within the mesh context ensures devices align if mesh is complex.\n","with mesh_ex4:\n"," sharded_batch = jax.device_put(numpy_batch, batch_input_sharding)\n"," sharded_labels = jax.device_put(numpy_labels, label_input_sharding)\n","\n","print(\"Sharded batch object:\", sharded_batch)\n","print(\"Sharded batch sharding:\", sharded_batch.sharding)\n","# With ('data':2, 'model':4) mesh and P('data', None) for (16, 128) batch:\n","# 'data' axis (size 2) shards dim 0 (16). 16/2 = 8.\n","# Each device buffer will have shape (8, 128).\n","print(f\"Sharded batch per-device shape: {sharded_batch.addressable_shards[0].data.shape}\")\n","\n","print(\"\\nSharded labels object:\", sharded_labels)\n","print(\"Sharded labels sharding:\", sharded_labels.sharding)\n","# With ('data':2, 'model':4) mesh and P('data') for (16,) labels:\n","# 'data' axis (size 2) shards dim 0 (16). 16/2 = 8.\n","# Each device buffer will have shape (8,).\n","print(f\"Sharded labels per-device shape: {sharded_labels.addressable_shards[0].data.shape}\")\n","\n","\n","# 5. Define the train_step function\n","@nnx.jit\n","def train_step(model: NNXFeedForward, optimizer: nnx.Optimizer, batch: jax.Array, labels: jax.Array):\n"," # This loss_fn is defined inside train_step to capture 'batch' and 'labels'\n"," # It takes the stateful NNX model as its argument.\n"," def loss_fn(mdl_stateful: NNXFeedForward):\n"," # Forward pass through the model\n"," logits = mdl_stateful(batch, training=True) # model's __call__\n","\n"," # The FFN output is (BATCH_SIZE_PER_DEVICE, EMBED_DIM).\n"," # For a typical classification loss, we'd need (BATCH_SIZE_PER_DEVICE, NUM_CLASSES).\n"," # This is a placeholder to make the loss calculation work.\n"," # In a real model, you'd have a final nnx.Linear layer projecting to NUM_CLASSES.\n"," current_out_features = logits.shape[-1]\n"," if current_out_features != NUM_CLASSES:\n"," # Simple (and somewhat arbitrary) projection for the sake of the exercise\n"," # This ensures the logits match the label dimensions for softmax_cross_entropy\n"," # A more realistic approach for a final layer would be needed in practice.\n"," # We are on a per-device shard of the batch here.\n"," # print(f\"Logits shape before adjustment: {logits.shape}\") # For debugging JIT prints\n"," # This projection is not well-posed for learning but allows loss computation.\n"," if logits.shape[0] > 0 : # check if batch per device is not zero\n"," # Create a dummy projection matrix on the fly (not trained)\n"," # This is just to make the shapes work for the loss function.\n"," # This is NOT how you would typically do a projection in a real model.\n"," dummy_projection_key = jax.random.key(99) # Fixed key for reproducibility inside JIT\n"," projection_matrix = jax.random.normal(dummy_projection_key, (current_out_features, NUM_CLASSES))\n"," projected_logits = logits @ projection_matrix\n"," else: # if batch per device is zero, create zero logits\n"," projected_logits = jnp.zeros((logits.shape[0], NUM_CLASSES), dtype=logits.dtype)\n","\n"," else:\n"," projected_logits = logits\n","\n"," # Compute loss. JAX automatically handles SPMD for this calculation\n"," # if inputs (logits, labels) are sharded. Gradients will also be sharded,\n"," # and all-reduce for gradients across data-parallel dimension is inserted by JAX.\n"," loss = jnp.mean(optax.softmax_cross_entropy_with_integer_labels(projected_logits, labels))\n"," return loss\n","\n"," # Compute loss and gradients. nnx.value_and_grad handles splitting/merging model state.\n"," loss_value, grads = nnx.value_and_grad(loss_fn)(model)\n","\n"," # Apply gradients. The optimizer updates the model's parameters in-place (conceptually).\n"," # If model parameters are sharded, optimizer state (e.g., momentum) is also sharded,\n"," # and updates are applied in a distributed manner.\n"," optimizer.update(model, grads)\n","\n"," return loss_value\n","\n","# 6. Create an nnx.Optimizer for the sharded_ffn_model\n","# Pass the sharded model to the optimizer and specify what to differentiate wrt (nnx.Param).\n","# NNX ensures that the optimizer's state (like Adam's moments) is initialized\n","# with the same sharding as the parameters.\n","optimizer = nnx.Optimizer(sharded_ffn_model, optax.adam(learning_rate=1e-3), wrt=nnx.Param)\n","print(f\"\\nOptimizer created. Optimizer state sharding should match param sharding.\")\n","# You can inspect optimizer.state to see its structure and (if concrete) sharding.\n","\n","\n","# 7. Execute the train_step once within the mesh context.\n","# All inputs (model params via optimizer, batch, labels) are sharded.\n","# JAX/XLA compiles the train_step for SPMD execution.\n","print(f\"\\nExecuting train_step within mesh: {mesh_ex4}\")\n","with mesh_ex4:\n"," # JIT compilation happens on the first call\n"," loss = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)\n"," # Subsequent calls would be faster\n"," # loss_2 = train_step(sharded_ffn_model, optimizer, sharded_batch, sharded_labels)\n","\n","\n","print(f\"\\nComputed loss from sharded train_step: {loss}\")\n","# print(f\"Computed loss (2nd step): {loss_2}\")\n","# Note: The actual loss value isn't meaningful due to dummy projection and random data.\n","# The key is that the distributed computation ran."],"metadata":{"id":"sOfkAONbMLYJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Exercise 6: Preparing for Sharded Checkpointing\n","\n","To save/load huge sharded models without OOM, libraries like Orbax are used. Orbax needs to know the target NamedSharding for each parameter to restore it correctly. nnx.spmd.get_named_sharding helps generate this.\n","\n","**Instructions**:\n","\n","1. Use the sharded_ffn_model and mesh_ex4 from previous exercises.\n","2. Get the state structure of the model using nnx.state(sharded_ffn_model).\n","3. Use nnx.spmd.get_named_sharding(state_struct, mesh_ex4) to generate the PyTree of NamedSharding objects.\n","4. Print the generated NamedSharding PyTree for a few parameters (e.g., LayerNorm scale, Linear1 kernel) to inspect them. This output is what you'd pass to Orbax for restoration."],"metadata":{"id":"kScSd-tPNCST"}},{"cell_type":"code","source":["# Exercise 6: Instructions Cell\n","\n","# 1. Use sharded_ffn_model and mesh_ex4\n","if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():\n"," print(\"Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 6.\")\n"," _mesh_devices_ex6 = np.array(jax.devices()).reshape((2, 4))\n"," mesh_ex4 = Mesh(_mesh_devices_ex6, axis_names=('data', 'model')) # Re-assign to global mesh_ex4 if needed\n"," _rngs_ffn_init_ex6 = nnx.Rngs(\n"," params=jax.random.key(20),\n"," layernorm_params=jax.random.key(300),\n"," linear1_params=jax.random.key(400),\n"," linear2_params=jax.random.key(500)\n"," )\n"," with mesh_ex4: # Ensure mesh_ex4 is correctly used\n"," sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_ex6, embed_dim=128, ff_dim=512)\n","\n","\n","# TODO: 2. Get the state structure of the model\n","# This can be from the concrete sharded model, or an abstract model from nnx.eval_shape\n","model_state_structure = nnx.state(sharded_ffn_model)\n","\n","# TODO: 3. Generate the target NamedSharding PyTree\n","# This uses the .sharding PartitionSpec metadata stored in the model's state\n","# and combines it with the provided mesh.\n","target_named_shardings_tree = nnx.spmd.get_named_sharding(model_state_structure, mesh_ex4)\n","\n","\n","# TODO: 4. Print some of the generated NamedSharding objects\n","print(\"\\n--- Target NamedShardings for Checkpointing ---\")\n","print(\"NamedSharding for LayerNorm scale:\")\n","nnx.display(target_named_shardings_tree['layernorm']['scale'])\n","\n","print(\"\\nNamedSharding for Linear1 kernel:\")\n","nnx.display(target_named_shardings_tree['linear1']['kernel'])\n","\n","print(\"\\nNamedSharding for Linear2 bias:\")\n","nnx.display(target_named_shardings_tree['linear2']['bias'])\n","\n","# These NamedSharding objects tell a checkpointing library (like Orbax)\n","# exactly how each parameter should be laid out across the devices upon restoration."],"metadata":{"id":"iFu-boNVMQi1"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 6: Solution"],"metadata":{"id":"UEjHsc7ROTkr"}},{"cell_type":"code","source":["# Exercise 6: Solution Cell\n","\n","# 1. Use sharded_ffn_model and mesh_ex4\n","if 'sharded_ffn_model' not in globals() or 'mesh_ex4' not in globals():\n"," print(\"Fallback: Re-defining sharded_ffn_model and mesh_ex4 for Exercise 6.\")\n"," # Ensure mesh_ex4 is the one associated with sharded_ffn_model\n"," _mesh_devices_ex6_sol = np.array(jax.devices()).reshape((2, 4))\n"," mesh_ex4 = Mesh(_mesh_devices_ex6_sol, axis_names=('data', 'model')) # Make sure this is the correct mesh\n"," _rngs_ffn_init_ex6_sol = nnx.Rngs(\n"," params=jax.random.key(201),\n"," layernorm_params=jax.random.key(301),\n"," linear1_params=jax.random.key(401),\n"," linear2_params=jax.random.key(501)\n"," )\n"," with mesh_ex4: # Use the correct mesh_ex4\n"," sharded_ffn_model = create_sharded_ffn_model(_rngs_ffn_init_ex6_sol, embed_dim=128, ff_dim=512)\n"," print(\"Fallback definitions complete for Ex6.\")\n","\n","\n","# 2. Get the state structure of the model\n","# This PyTree has the same structure as the model's parameters,\n","# and each leaf (parameter state) contains the .sharding PartitionSpec metadata.\n","model_state_structure = nnx.state(sharded_ffn_model)\n","\n","# 3. Generate the target NamedSharding PyTree\n","# nnx.spmd.get_named_sharding combines the PartitionSpec from metadata\n","# with the provided 'mesh_ex4' to create a full NamedSharding object for each parameter.\n","target_named_shardings_tree = nnx.spmd.get_named_sharding(model_state_structure, mesh_ex4)\n","\n","# 4. Print some of the generated NamedSharding objects\n","print(\"\\n--- Target NamedShardings for Checkpointing (from nnx.spmd.get_named_sharding) ---\")\n","\n","# For LayerNorm's scale parameter\n","# Original sharding metadata (PartitionSpec): P('model')\n","# Mesh: ('data':2, 'model':4)\n","# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P('model'))\n","print(\"\\nNamedSharding for LayerNorm scale:\")\n","# The path to scale might be model_state_structure['layernorm']['scale'].sharding\n","# So target_named_shardings_tree should have a similar path.\n","nnx.display(target_named_shardings_tree['layernorm']['scale'])\n","\n","\n","# For Linear1's kernel parameter\n","# Original sharding metadata (PartitionSpec): P(None, 'model')\n","# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P(None, 'model'))\n","print(\"\\nNamedSharding for Linear1 kernel:\")\n","nnx.display(target_named_shardings_tree['linear1']['kernel'])\n","\n","\n","# For Linear2's bias parameter\n","# Original sharding metadata (PartitionSpec): P('model')\n","# Expected NamedSharding: NamedSharding(mesh=mesh_ex4, spec=P('model'))\n","print(\"\\nNamedSharding for Linear2 bias:\")\n","nnx.display(target_named_shardings_tree['linear2']['bias'])\n","\n","# This target_named_shardings_tree would be passed to something like\n","# orbax.checkpoint.StandardRestore(target_named_shardings_tree)\n","# to tell Orbax how to reconstruct the sharded arrays when loading from a checkpoint."],"metadata":{"id":"LIjPJE2kOXBv"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Conclusion & Feedback\n","\n","**Congratulations on completing the exercises!**\n","\n","You've now practiced:\n","\n","- Using JAX sharding primitives (Mesh, PartitionSpec, NamedSharding, device_put).\n","- Annotating Flax NNX modules with sharding metadata.\n","- The critical sharded initialization workflow to avoid OOM errors.\n","- Applying these concepts to a multi-layer NNX module.\n","- Sharding input data for a distributed training step.\n","- Preparing sharding information for distributed checkpointing.\n","\n","These are foundational skills for scaling up your JAX and Flax NNX models.\n","\n","Further Learning:\n","\n","- JAX Documentation: https://jax.readthedocs.io/\n","- Flax NNX Documentation: https://flax.readthedocs.io/en/latest/nnx/index.html\n","- JAX SPMD Guide: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html\n","- Orbax (for checkpointing): https://orbax.readthedocs.io/\n","\n","Please send us feedback at https://goo.gle/jax-training-feedback"],"metadata":{"id":"Ju7fgC0ZO4iJ"}},{"cell_type":"code","source":[],"metadata":{"id":"6Z14x9B-ObBK"},"execution_count":null,"outputs":[]}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/11 - Optax optimizers.ipynb b/docs/learning_jax/code-exercises/11 - Optax optimizers.ipynb new file mode 100644 index 0000000..83917fe --- /dev/null +++ b/docs/learning_jax/code-exercises/11 - Optax optimizers.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1iMaqn7Cm7EveArzxC3ZfZ14eMHjxvMw6","timestamp":1755114116822}],"toc_visible":true,"authorship_tag":"ABX9TyOQ9kz0yLzxKYJMmWBUFGp7"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Optax with Flax NNX: Exercises for PyTorch Users\n","\n","This Colab notebook contains a series of exercises designed to help you, a PyTorch user, get hands-on experience with Optax, the primary optimization library in the JAX ecosystem, specifically for training Flax NNX models.\n","We will cover everything from the basics of setting up an optimizer to advanced techniques like learning rate scheduling, per-parameter optimization, and sharding for distributed training.\n","\n","## Setup\n","First, let's install the necessary libraries and set up a simulated multi-device environment. We'll use chex to simulate having 8 CPU devices, which will allow us to explore distributed training concepts without needing multiple physical GPUs/TPUs."],"metadata":{"id":"57CCWeofdudt"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"y6RKMYI3dsYy"},"outputs":[],"source":["!pip install -Uq flax jax optax chex\n","\n","import jax\n","import jax.numpy as jnp\n","from jax.sharding import Mesh, PartitionSpec, NamedSharding\n","import chex\n","import optax\n","from flax import nnx\n","\n","# Simulate an environment with 8 CPU devices for sharding exercises\n","try:\n"," chex.set_n_cpu_devices(8)\n","except RuntimeError as e:\n"," print(f\"Could not set n_cpu_devices: {e}\")\n"," print(\"Sharding exercises may not work as intended. Continuing anyway.\")\n","\n","# Helper to check available devices\n","print(f\"JAX is running on: {jax.default_backend()}\")\n","print(f\"Number of available devices: {jax.device_count()}\")\n","print(f\"Device details: {jax.devices()}\")"]},{"cell_type":"markdown","source":["## Exercise 1: The Basic Training Loop\n","\n","**Concept:** This exercise covers the fundamental workflow of training a Flax NNX model with Optax. You will:\n","\n","1. Define a simple MLP model using flax.nnx.Module.\n","2. Instantiate the model and a basic optax.adam optimizer using flax.nnx.Optimizer.\n","3. Write a Mean Squared Error (MSE) loss function.\n","4. Create a complete, JIT-compiled training step function that takes the model and optimizer as arguments, calculates the loss, computes gradients using flax.nnx.value_and_grad, and updates the model parameters using optimizer.update(model, grads).\n","\n","This process mirrors the standard \"instantiate, calculate loss, backpropagate, step\" cycle in PyTorch but introduces the JAX/Optax equivalents: nnx.Optimizer, nnx.value_and_grad, and optimizer.update().\n","\n","### Instructions\n","\n","Complete the TODO sections in the following code cell to implement the basic training loop."],"metadata":{"id":"bjBKyrSyeiTC"}},{"cell_type":"code","source":["# @title Exercise 1: Implement the Basic Training Loop\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","from typing import Sequence\n","\n","# 1. Define the Model\n","class SimpleMLP(nnx.Module):\n"," \"\"\"A simple Multi-Layer Perceptron.\"\"\"\n"," def __init__(self, features: Sequence[int], *, rngs: nnx.Rngs):\n"," self.layers = []\n"," for i in range(len(features) - 1):\n"," self.layers.append(nnx.Linear(features[i], features[i+1], rngs=rngs))\n"," if i < len(features) - 2:\n"," self.layers.append(nnx.relu)\n","\n"," def __call__(self, x: jax.Array):\n"," for layer in self.layers:\n"," x = layer(x)\n"," return x\n","\n","# 2. Define the Loss Function\n","def mse_loss(model: SimpleMLP, x_batch: jax.Array, y_batch: jax.Array) -> jax.Array:\n"," \"\"\"Calculates the Mean Squared Error loss.\"\"\"\n"," # TODO: Get predictions from the model and calculate the MSE.\n"," # Hint: The model is callable, e.g., model(x_batch).\n"," # YOUR CODE HERE\n"," return loss\n","\n","# 3. Define the Training Step\n","@nnx.jit\n","def train_step(model: SimpleMLP, optimizer: nnx.Optimizer, x_batch: jax.Array, y_batch: jax.Array):\n"," \"\"\"Performs a single training step.\"\"\"\n"," # TODO: Use nnx.value_and_grad to get both the loss and the gradients.\n"," # You'll need a loss function closure that takes only the model as an argument.\n"," def loss_fn_for_grad(model_to_train):\n"," return mse_loss(model_to_train, x_batch, y_batch)\n","\n"," loss_val, grads = # YOUR CODE HERE\n","\n"," # TODO: Update the optimizer with the gradients.\n"," # YOUR CODE HERE\n","\n"," # The optimizer's state is modified in-place by update(), but under jit,\n"," # we must return it to get the new state out.\n"," return model, optimizer, loss_val\n","\n","# --- Boilerplate for running the exercise ---\n","# Create dummy data\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","# Instantiate model and optimizer\n","model = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","opt = optax.adam(learning_rate=1e-3)\n","optimizer = nnx.Optimizer(model, opt, wrt=nnx.Param)\n","\n","# Training Loop\n","print(\"Starting basic training loop...\")\n","for i in range(101):\n"," optimizer, loss = train_step(optimizer, x_dummy, y_dummy)\n"," if i % 20 == 0:\n"," # The .value attribute is used to get the raw value from a State variable\n"," print(f\"Step {optimizer.step.value}, Loss: {loss:.4f}\")\n","print(\"Basic training loop finished.\")\n","# Verify the model parameters have been updated\n","assert optimizer.step.value == 101"],"metadata":{"id":"JgWREm3-eT_B"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 1\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","from typing import Sequence\n","\n","# 1. Define the Model\n","class SimpleMLP(nnx.Module):\n"," \"\"\"A simple Multi-Layer Perceptron.\"\"\"\n"," def __init__(self, features: Sequence[int], *, rngs: nnx.Rngs):\n"," self.layers = []\n"," for i in range(len(features) - 1):\n"," self.layers.append(nnx.Linear(features[i], features[i+1], rngs=rngs))\n"," if i < len(features) - 2:\n"," self.layers.append(nnx.relu)\n","\n"," def __call__(self, x: jax.Array):\n"," for layer in self.layers:\n"," x = layer(x)\n"," return x\n","\n","# 2. Define the Loss Function\n","def mse_loss(model: SimpleMLP, x_batch: jax.Array, y_batch: jax.Array) -> jax.Array:\n"," \"\"\"Calculates the Mean Squared Error loss.\"\"\"\n"," predictions = model(x_batch)\n"," loss = jnp.mean((predictions - y_batch) ** 2)\n"," return loss\n","\n","# 3. Define the Training Step\n","@nnx.jit\n","def train_step(model: SimpleMLP, optimizer: nnx.Optimizer, x_batch: jax.Array, y_batch: jax.Array):\n"," \"\"\"Performs a single training step.\"\"\"\n"," # A closure to capture the current batch of data\n"," def loss_fn_for_grad(model_to_train: SimpleMLP):\n"," return mse_loss(model_to_train, x_batch, y_batch)\n","\n"," # Compute loss and gradients\n"," loss_val, grads = nnx.value_and_grad(loss_fn_for_grad)(model)\n","\n"," # Update the optimizer's state and model parameters\n"," optimizer.update(model, grads)\n","\n"," return model, optimizer, loss_val\n","\n","# --- Boilerplate for running the exercise ---\n","# Create dummy data\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","# Instantiate model and optimizer\n","model = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","opt = optax.adam(learning_rate=1e-3)\n","optimizer = nnx.Optimizer(model, opt, wrt=nnx.Param)\n","\n","# Training Loop\n","print(\"Starting basic training loop...\")\n","for i in range(101):\n"," optimizer, loss = train_step(optimizer, x_dummy, y_dummy)\n"," if i % 20 == 0:\n"," print(f\"Step {optimizer.step.value}, Loss: {loss:.4f}\")\n","print(\"Basic training loop finished.\")\n","assert optimizer.step.value == 101"],"metadata":{"id":"KZ0wslm9fSW5"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 2: Composing Gradient Transformations\n","\n","**Concept:** A core philosophy of Optax is composability. Instead of monolithic optimizers, Optax provides small, chainable \"gradient transformations.\" This exercise demonstrates how to build a custom optimization pipeline by chaining multiple transformations together.\n","\n","You will add gradient clipping and weight decay to the Adam optimizer, creating a more robust optimization rule. This is analogous to combining features that might be built-in flags in a PyTorch optimizer, but here you explicitly build the chain.\n","\n","### Instructions\n","\n","Complete the TODO section to create a chained Optax transformation."],"metadata":{"id":"exY9yGJ5faqc"}},{"cell_type":"code","source":["# @title Exercise 2: Build a Chained Optimizer\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","\n","# --- Using the same model and data setup from Exercise 1 ---\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","model_chained = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","\n","# Define hyperparameters\n","learning_rate = 1e-3\n","max_grad_norm = 1.0\n","weight_decay = 1e-4\n","\n","# TODO: Create a chained Optax transformation.\n","# The desired order is:\n","# 1. Clip gradients by their global norm (optax.clip_by_global_norm).\n","# 2. Add weight decay (optax.add_decayed_weights).\n","# 3. Apply Adam optimizer updates (optax.adam).\n","# Hint: Use optax.chain([...])\n","opt_chained = optax.chain(\n"," # YOUR CODE HERE\n",")\n","\n","\n","# --- Boilerplate for running the exercise ---\n","# The train_step and mse_loss from Exercise 1 can be reused directly!\n","optimizer_chained = nnx.Optimizer(model_chained, opt_chained, wrt=nnx.Param)\n","\n","print(\"Starting training with chained optimizer...\")\n","for i in range(101):\n"," model_chained, optimizer_chained, loss = train_step(model_chained, optimizer_chained, x_dummy, y_dummy)\n"," if i % 20 == 0:\n"," print(f\"Step {optimizer_chained.step.value}, Loss: {loss:.4f}\")\n","print(\"Chained optimizer training finished.\")"],"metadata":{"id":"UkMUe4wMfW8a"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 2\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","\n","# --- Using the same model and data setup from Exercise 1 ---\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","model_chained = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","\n","# Define hyperparameters\n","max_grad_norm = 1.0\n","weight_decay = 1e-4\n","learning_rate = 1e-3\n","\n","# Create a chained Optax transformation\n","opt_chained = optax.chain(\n"," optax.clip_by_global_norm(max_grad_norm),\n"," optax.add_decayed_weights(weight_decay),\n"," optax.adam(learning_rate)\n",")\n","\n","# --- Boilerplate for running the exercise ---\n","# The train_step and mse_loss from Exercise 1 can be reused directly!\n","optimizer_chained = nnx.Optimizer(model_chained, opt_chained, wrt=nnx.Param)\n","\n","print(\"Starting training with chained optimizer...\")\n","for i in range(101):\n"," model_chained, optimizer_chained, loss = train_step(model_chained, optimizer_chained, x_dummy, y_dummy)\n"," if i % 20 == 0:\n"," print(f\"Step {optimizer_chained.step.value}, Loss: {loss:.4f}\")\n","print(\"Chained optimizer training finished.\")"],"metadata":{"id":"hKTMXfwNf6Bi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 3: Learning Rate Scheduling\n","\n","**Concept:** Dynamically adjusting the learning rate during training is a crucial technique. In Optax, you don't use an external scheduler.step() like in PyTorch. Instead, the schedule is baked directly into the optimizer definition.\n","\n","This exercise asks you to create a learning rate schedule and pass it to your optimizer. Optax will handle the updates automatically at each step. You will implement a warmup-cosine-decay schedule, a very common and effective schedule.\n","\n","### Instructions\n","Complete the TODO sections to define a learning rate schedule and use it in an Adam optimizer."],"metadata":{"id":"Vub5qv9Ll1qk"}},{"cell_type":"code","source":["# @title Exercise 3: Implement a Learning Rate Schedule\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","import matplotlib.pyplot as plt\n","\n","# --- Using the same model and data setup from Exercise 1 ---\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","model_scheduled = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","\n","# --- Scheduling Hyperparameters ---\n","total_training_steps = 500\n","warmup_fraction = 0.1\n","peak_lr = 1e-3\n","final_lr = 1e-5\n","\n","# TODO: Define a warmup-cosine-decay learning rate schedule.\n","# Hint: Use optax.warmup_cosine_decay_schedule.\n","# It needs an initial value, a peak value, warmup steps, and decay steps.\n","warmup_steps = # YOUR CODE HERE\n","decay_steps = # YOUR CODE HERE\n","\n","lr_schedule_fn = optax.warmup_cosine_decay_schedule(\n"," init_value=0.0,\n"," peak_value=peak_lr,\n"," warmup_steps=warmup_steps,\n"," decay_steps=decay_steps,\n"," end_value=final_lr\n",")\n","\n","\n","# TODO: Create an Adam optimizer that uses this schedule.\n","# Hint: Simply pass the schedule function as the `learning_rate` argument.\n","opt_scheduled = # YOUR CODE HERE\n","\n","\n","# --- Boilerplate for running the exercise ---\n","optimizer_scheduled = nnx.Optimizer(model_scheduled, opt_scheduled, wrt=nnx.Param)\n","\n","# Training Loop\n","print(\"Starting training with scheduled LR...\")\n","lrs = []\n","for i in range(total_training_steps):\n"," # The LR is updated automatically inside train_step\n"," model_scheduled, optimizer_scheduled, loss = train_step(model_scheduled, optimizer_scheduled, x_dummy, y_dummy)\n"," # We can extract the current LR for plotting\n"," # Note: This requires the optimizer state to be on the host.\n"," # In a real scenario, you might not check this every step.\n"," current_lr = lr_schedule_fn(optimizer_scheduled.step.value)\n"," lrs.append(current_lr)\n"," if i % 50 == 0:\n"," print(f\"Step {optimizer_scheduled.step.value}, Loss: {loss:.5f}, LR: {current_lr:.6f}\")\n","print(\"Scheduled LR training finished.\")\n","\n","# Plot the learning rate over time\n","plt.figure(figsize=(10, 4))\n","plt.plot(lrs)\n","plt.title(\"Learning Rate Schedule\")\n","plt.xlabel(\"Training Step\")\n","plt.ylabel(\"Learning Rate\")\n","plt.grid(True)\n","plt.show()"],"metadata":{"id":"w55m9SrYf-hF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 3\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","import matplotlib.pyplot as plt\n","\n","# --- Using the same model and data setup from Exercise 1 ---\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","model_scheduled = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","\n","# --- Scheduling Hyperparameters ---\n","total_training_steps = 500\n","warmup_fraction = 0.1\n","peak_lr = 1e-3\n","final_lr = 1e-5\n","\n","# Define a warmup-cosine-decay learning rate schedule\n","warmup_steps = int(total_training_steps * warmup_fraction)\n","decay_steps = total_training_steps - warmup_steps\n","\n","lr_schedule_fn = optax.warmup_cosine_decay_schedule(\n"," init_value=0.0,\n"," peak_value=peak_lr,\n"," warmup_steps=warmup_steps,\n"," decay_steps=decay_steps,\n"," end_value=final_lr\n",")\n","\n","# Create an Adam optimizer that uses this schedule\n","opt_scheduled = optax.adam(learning_rate=lr_schedule_fn)\n","\n","# --- Boilerplate for running the exercise ---\n","optimizer_scheduled = nnx.Optimizer(model_scheduled, opt_scheduled, wrt=nnx.Param)\n","\n","# Training Loop\n","print(\"Starting training with scheduled LR...\")\n","lrs = []\n","for i in range(total_training_steps):\n"," model_scheduled, optimizer_scheduled, loss = train_step(model_scheduled, optimizer_scheduled, x_dummy, y_dummy)\n"," current_lr = lr_schedule_fn(optimizer_scheduled.step.value)\n"," lrs.append(current_lr)\n"," if i % 50 == 0:\n"," print(f\"Step {optimizer_scheduled.step.value}, Loss: {loss:.5f}, LR: {current_lr:.6f}\")\n","print(\"Scheduled LR training finished.\")\n","\n","# Plot the learning rate over time\n","plt.figure(figsize=(10, 4))\n","plt.plot(lrs)\n","plt.title(\"Learning Rate Schedule\")\n","plt.xlabel(\"Training Step\")\n","plt.ylabel(\"Learning Rate\")\n","plt.grid(True)\n","plt.show()"],"metadata":{"id":"oYZFC9udmXxp"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 4: Per-Parameter Optimization\n","\n","**Concept:** It's often beneficial to apply different optimization rules to different model parameters. For example, you might not want to apply weight decay to bias parameters or normalization layer scales. In PyTorch, this is handled with \"parameter groups.\" In Optax, the equivalent is `optax.partition`.\n","\n","This exercise will guide you through:\n","\n","1. Writing a \"labeling function\" that assigns a string label ('bias', 'kernel', or 'other') to each parameter in your model based on its name.\n","2. Using optax.partition to create a composite optimizer that applies different learning rates to biases and kernels.\n","\n","### Instructions\n","Complete the TODO sections to implement per-parameter optimization."],"metadata":{"id":"FFrIYmXEmiBP"}},{"cell_type":"code","source":["# @title Exercise 4: Implement Per-Parameter Optimization\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","\n","# --- Using the same model and data setup from Exercise 1 ---\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","model_partitioned = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","\n","# 1. Create the parameter labels PyTree\n","# Get a PyTree of the model's parameters to generate labels for\n","params_pytree = nnx.state(model_partitioned, nnx.Param)\n","\n","# TODO: Implement the labeling function.\n","# It should inspect the path to a parameter and return a string label.\n","def label_fn(path, leaf):\n"," \"\"\"Assigns a label to a parameter based on its path.\"\"\"\n"," # The path is a tuple of keys\n"," # We can check the name of the last attribute in the path.\n"," # YOUR CODE HERE\n","\n","# Use tree_map_with_path to apply the labeling function\n","param_labels = jax.tree.map_with_path(label_fn, params_pytree)\n","\n","print(\"Generated Parameter Labels PyTree:\")\n","nnx.display(param_labels)\n","\n","# 2. TODO: Define the partitioned optimizer.\n","# Use optax.partition, providing a dictionary mapping your labels\n","# to different Optax transformations.\n","# - Use Adam with LR 1e-3 for 'kernel'\n","# - Use SGD with LR 5e-3 for 'bias'\n","# - Use Adam with LR 1e-4 for 'other' (a default)\n","partitioned_opt = optax.partition(\n"," transforms={\n"," # YOUR CODE HERE\n"," },\n"," param_labels=param_labels\n",")\n","\n","\n","# --- Boilerplate for running the exercise ---\n","optimizer_partitioned = nnx.Optimizer(model_partitioned, partitioned_opt, wrt=nnx.Param)\n","\n","print(\"\\nStarting training with partitioned optimizer...\")\n","for i in range(101):\n"," model_partitioned, optimizer_partitioned, loss = train_step(model_partitioned, optimizer_partitioned, x_dummy, y_dummy)\n"," if i % 20 == 0:\n"," print(f\"Step {optimizer_partitioned.step.value}, Loss: {loss:.4f}\")\n","print(\"Partitioned optimizer training finished.\")\n","\n","# Verify the optimizer state structure\n","opt_state_structure = jax.tree_util.tree_map(\n"," lambda x: x.__class__.__name__, optimizer_partitioned.state.opt_state\n",")\n","print(\"\\nStructure of the partitioned optimizer's state:\")\n","nnx.display(opt_state_structure)\n","assert 'PartitionState' in str(opt_state_structure)"],"metadata":{"id":"3yZt1Cl6mctx"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 4\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","\n","# --- Using the same model and data setup from Exercise 1 ---\n","key = jax.random.key(42)\n","key_model, key_data = jax.random.split(key)\n","din, dmid, dout = 10, 20, 5\n","x_dummy = jax.random.normal(key_data, (32, din))\n","y_dummy = jax.random.normal(key_data, (32, dout))\n","\n","model_partitioned = SimpleMLP(features=[din, dmid, dout], rngs=nnx.Rngs(key_model))\n","\n","# 1. Create the parameter labels PyTree\n","params_pytree = nnx.state(model_partitioned, nnx.Param)\n","\n","def label_fn(path, leaf):\n"," \"\"\"Assigns a label to a parameter based on its path.\"\"\"\n"," param_name = path[-1].name\n"," if 'bias' in param_name:\n"," return 'bias'\n"," elif 'kernel' in param_name:\n"," return 'kernel'\n"," return 'other'\n","\n","param_labels = jax.tree.map_with_path(label_fn, params_pytree)\n","\n","print(\"Generated Parameter Labels PyTree:\")\n","nnx.display(param_labels)\n","\n","# 2. Define the partitioned optimizer\n","partitioned_opt = optax.partition(\n"," transforms={\n"," 'kernel': optax.adam(learning_rate=1e-3),\n"," 'bias': optax.sgd(learning_rate=5e-3),\n"," 'other': optax.adam(learning_rate=1e-4),\n"," },\n"," param_labels=param_labels\n",")\n","\n","# --- Boilerplate for running the exercise ---\n","optimizer_partitioned = nnx.Optimizer(model_partitioned, partitioned_opt, wrt=nnx.Param)\n","\n","print(\"\\nStarting training with partitioned optimizer...\")\n","for i in range(101):\n"," model_partitioned, optimizer_partitioned, loss = train_step(model_partitioned, optimizer_partitioned, x_dummy, y_dummy)\n"," if i % 20 == 0:\n"," print(f\"Step {optimizer_partitioned.step.value}, Loss: {loss:.4f}\")\n","print(\"Partitioned optimizer training finished.\")\n","\n","# Verify the optimizer state structure\n","opt_state_structure = jax.tree_util.tree_map(\n"," lambda x: x.__class__.__name__, optimizer_partitioned.opt_state\n",")\n","print(\"\\nStructure of the partitioned optimizer's state:\")\n","nnx.display(opt_state_structure)\n","assert 'PartitionState' in str(opt_state_structure)"],"metadata":{"id":"Hi03anmEm9B1"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 5: Sharding the Model and Optimizer State\n","\n","**Concept:** JAX provides fine-grained control over how data and model parameters are distributed across devices. This is done by explicitly annotating PyTrees (like model parameters or optimizer state) with sharding information.\n","\n","In this exercise, you will:\n","\n","1. Create a 2D device Mesh from our 8 simulated CPUs.\n","2. Define a sharded MLP where the kernel of a linear layer is sharded across the 'model' axis of the mesh (Model Parallelism).\n","3. Create a sharded optimizer whose state (e.g., Adam's momentum and variance vectors) automatically inherits the same sharding as the corresponding model parameters.\n","\n","### Instructions\n","Complete the TODO sections to shard your model and optimizer."],"metadata":{"id":"P3VGU-0Rq0zo"}},{"cell_type":"code","source":["# @title Exercise 5: Sharding Model and Optimizer\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","from jax.sharding import Mesh, PartitionSpec as P\n","import numpy as np\n","\n","# Ensure we have our 8 simulated devices\n","if jax.device_count() != 8:\n"," print(\"Warning: This exercise expects 8 devices. Sharding may not behave as expected.\")\n","\n","# 1. Create a device mesh\n","# We'll create a 2x4 mesh, with a 'data' axis for data parallelism\n","# and a 'model' axis for model parallelism.\n","devices = np.array(jax.devices()).reshape(2, 4)\n","mesh = Mesh(devices, axis_names=('data', 'model'))\n","print(\"Created 2x4 device mesh:\")\n","print(mesh)\n","\n","# 2. Define a sharded model\n","class ShardedMLP(nnx.Module):\n"," def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):\n"," # TODO: Shard the kernel of the second linear layer.\n"," # The goal is to split the kernel's columns across the 'model' axis.\n"," # This is a form of model parallelism.\n"," # - The first dimension (input features) should be replicated.\n"," # - The second dimension (output features) should be sharded.\n"," # - The bias should also be sharded along the 'model' axis.\n"," # - All other parameters can be replicated (the default).\n"," self.linear1 = nnx.Linear(din, dmid, rngs=rngs)\n"," self.relu = nnx.relu\n"," self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)\n","\n"," # Shard linear1 fully (replicated)\n"," self.linear1.kernel.sharding = P(None, None) # or just P()\n"," self.linear1.bias.sharding = P(None) # or just P()\n","\n"," # Shard linear2 for model parallelism\n"," # YOUR CODE HERE - Replicate rows, shard columns\n"," # YOUR CODE HERE - Shard the bias vector\n","\n"," def __call__(self, x):\n"," x = self.linear1(x)\n"," x = self.relu(x)\n"," x = self.linear2(x)\n"," return x\n","\n","# 3. Create sharded model and optimizer within the mesh context\n","@nnx.jit\n","def create_sharded_model_and_optimizer():\n"," key = jax.random.key(0)\n"," model = ShardedMLP(16, 32, 64, rngs=nnx.Rngs(key))\n"," optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n","\n"," # The sharding annotations on the model are automatically picked up.\n"," # Now, we need to ensure the optimizer state gets the same shardings.\n"," # nnx.Optimizer automatically infers this from the model's parameters!\n"," # We just need to use jax.lax.with_sharding_constraint to enforce it\n"," # during JIT compilation.\n","\n"," # Shard model state based on annotations\n"," model_state = nnx.state(model)\n"," model_shardings = nnx.spmd.get_partition_spec(model_state)\n"," sharded_model_state = jax.lax.with_sharding_constraint(model_state, model_shardings)\n"," nnx.update(model, sharded_model_state)\n","\n"," # TODO: Shard the optimizer state.\n"," # The process is identical to sharding the model, but you need to filter\n"," # for the optimizer's state using nnx.optimizer.OptState.\n"," # YOUR CODE HERE\n"," opt_shardings = nnx.spmd.get_partition_spec(opt_state_to_shard)\n"," sharded_opt_state = jax.lax.with_sharding_constraint(\n"," opt_state_to_shard, opt_shardings\n"," )\n"," nnx.update(optimizer, sharded_opt_state)\n","\n"," return model, optimizer\n","\n","# Run the creation function within the mesh context manager\n","with mesh:\n"," sharded_model, sharded_optimizer = create_sharded_model_and_optimizer()\n","\n","\n","# --- Verification ---\n","print(\"\\n--- Verifying Shardings ---\")\n","# Get the sharded state back from the JIT call\n","final_model_state = nnx.state(sharded_model)\n","final_opt_state = nnx.state(sharded_optimizer, nnx.optimizer.OptState)\n","\n","# Check the sharding of the second linear layer's kernel in the model\n","l2_kernel_sharding = final_model_state['layers']['1']['kernel'].sharding\n","print(f\"\\nModel's linear2.kernel sharding: {l2_kernel_sharding}\")\n","assert l2_kernel_sharding == NS(None, 'model')\n","\n","# Check the sharding of the corresponding momentum (m) in the optimizer state\n","# The optimizer state PyTree mirrors the parameter PyTree structure.\n","adam_state = final_opt_state['opt_state'][1] # (trace_state, adam_state)\n","l2_kernel_momentum_sharding = adam_state.m['layers']['1']['kernel'].sharding\n","print(f\"Optimizer's momentum for linear2.kernel sharding: {l2_kernel_momentum_sharding}\")\n","assert l2_kernel_momentum_sharding == NS(None, 'model')\n","\n","print(\"\\nSuccessfully verified that optimizer state sharding matches model parameter sharding.\")"],"metadata":{"id":"K9SE-quUnBhb"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Solution 5\n","import jax\n","import jax.numpy as jnp\n","import optax\n","from flax import nnx\n","from jax.sharding import Mesh, PartitionSpec as P\n","import numpy as np\n","\n","# Ensure we have our 8 simulated devices\n","if jax.device_count() != 8:\n"," print(\"Warning: This exercise expects 8 devices. Sharding may not behave as expected.\")\n","\n","# 1. Create a device mesh\n","devices = np.array(jax.devices()).reshape(2, 4)\n","mesh = Mesh(devices, axis_names=('data', 'model'))\n","print(\"Created 2x4 device mesh:\")\n","print(mesh)\n","\n","# 2. Define a sharded model\n","class ShardedMLP(nnx.Module):\n"," def __init__(self, din, dmid, dout, *, rngs: nnx.Rngs):\n"," self.linear1 = nnx.Linear(din, dmid, rngs=rngs)\n"," self.relu = nnx.relu\n"," self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)\n","\n"," # Shard linear1 fully (replicated) - this is often the default\n"," self.linear1.kernel.sharding = P() # Replicated on all axes\n"," self.linear1.bias.sharding = P() # Replicated on all axes\n","\n"," # Shard linear2 for model parallelism\n"," # Shard the output dimension of the kernel and the bias\n"," self.linear2.kernel.sharding = P(None, 'model') # Replicate rows, shard columns\n"," self.linear2.bias.sharding = P('model') # Shard the bias vector\n","\n"," def __call__(self, x):\n"," x = self.linear1(x)\n"," x = self.relu(x)\n"," x = self.linear2(x)\n"," return x\n","\n","# 3. Create sharded model and optimizer within the mesh context\n","@nnx.jit\n","def create_sharded_model_and_optimizer():\n"," key = jax.random.key(0)\n"," model = ShardedMLP(16, 32, 64, rngs=nnx.Rngs(key))\n"," optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n","\n"," # Shard model state based on annotations\n"," model_state = nnx.state(model)\n"," model_shardings = nnx.spmd.get_partition_spec(model_state)\n"," sharded_model_state = jax.lax.with_sharding_constraint(model_state, model_shardings)\n"," nnx.update(model, sharded_model_state)\n","\n"," # Shard the optimizer state\n"," # Filter for the optimizer's state (step and Optax's internal state)\n"," opt_state_to_shard = nnx.state(optimizer, nnx.optimizer.OptState)\n"," # Infer the sharding specification from the parameter shardings\n"," opt_shardings = nnx.spmd.get_partition_spec(opt_state_to_shard)\n"," # Apply the sharding constraint\n"," sharded_opt_state = jax.lax.with_sharding_constraint(\n"," opt_state_to_shard, opt_shardings\n"," )\n"," nnx.update(optimizer, sharded_opt_state)\n","\n"," return model, optimizer\n","\n","# Run the creation function within the mesh context manager\n","with mesh:\n"," sharded_model, sharded_optimizer = create_sharded_model_and_optimizer()\n","\n","\n","# --- Verification ---\n","print(\"\\n--- Verifying Shardings ---\")\n","# Get the sharded state back from the JIT call\n","final_model_state = nnx.state(sharded_model)\n","final_opt_state = nnx.state(sharded_optimizer, nnx.optimizer.OptState)\n","\n","# Check the sharding of the second linear layer's kernel in the model\n","l2_kernel_sharding = final_model_state['linear2']['kernel'].sharding\n","print(f\"\\nModel's linear2.kernel sharding: {l2_kernel_sharding}\")\n","assert l2_kernel_sharding == P(None, 'model')\n","\n","# Check the sharding of the corresponding momentum (m) in the optimizer state\n","# The optimizer state PyTree mirrors the parameter PyTree structure.\n","# For optax.adam, the state is a tuple of (trace_state, adam_state).\n","# We look inside the AdamState.\n","adam_state = final_opt_state['opt_state'][0]\n","l2_kernel_momentum_sharding = adam_state.mu['linear2']['kernel'].sharding\n","print(f\"Optimizer's momentum for linear2.kernel sharding: {l2_kernel_momentum_sharding}\")\n","assert l2_kernel_momentum_sharding == P(None, 'model')\n","\n","print(\"\\nSuccessfully verified that optimizer state sharding matches model parameter sharding.\")"],"metadata":{"id":"S1PHHOj1rMKP"},"execution_count":null,"outputs":[]}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/2 - NumPy and JAX NumPy.ipynb b/docs/learning_jax/code-exercises/2 - NumPy and JAX NumPy.ipynb new file mode 100644 index 0000000..2266c59 --- /dev/null +++ b/docs/learning_jax/code-exercises/2 - NumPy and JAX NumPy.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"markdown","metadata":{"id":"LQHmwePqryRU"},"source":["# How to think in JAX\n","\n","\n","\n","[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n","\n","JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively."]},{"cell_type":"markdown","metadata":{"id":"nayIExVUtsVD"},"source":["## JAX vs. NumPy\n","\n","**Key concepts:**\n","\n","- JAX provides a NumPy-inspired interface for convenience.\n","- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n","- Unlike NumPy arrays, JAX arrays are always immutable.\n","\n","NumPy provides a well-known, powerful API for working with numerical data. For convenience, JAX provides `jax.numpy` which closely mirrors the numpy API and provides easy entry into JAX. Almost anything that can be done with `numpy` can be done with `jax.numpy`:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kZaOXL7-uvUP"},"outputs":[],"source":["import matplotlib.pyplot as plt\n","import numpy as np\n","\n","x_np = np.linspace(0, 10, 1000)\n","y_np = 2 * np.sin(x_np) * np.cos(x_np)\n","plt.plot(x_np, y_np);"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"18XbGpRLuZlr"},"outputs":[],"source":["import jax.numpy as jnp\n","\n","x_jnp = jnp.linspace(0, 10, 1000)\n","y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)\n","plt.plot(x_jnp, y_jnp);"]},{"cell_type":"markdown","metadata":{"id":"kTZcsCJiuPG8"},"source":["The code blocks are identical aside from replacing `np` with `jnp`, and the results are the same. As we can see, JAX arrays can often be used directly in place of NumPy arrays for things like plotting.\n","\n","The arrays themselves are implemented as different Python types:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"PjFFunI7xNe8"},"outputs":[],"source":["type(x_np)"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"kpv5K7QYxQnX"},"outputs":[],"source":["type(x_jnp)"]},{"cell_type":"markdown","metadata":{"id":"Mx94Ri7euEZm"},"source":["Python's [duck-typing](https://en.wikipedia.org/wiki/Duck_typing) allows JAX arrays and NumPy arrays to be used interchangeably in many places.\n","\n","However, there is one important difference between JAX and NumPy arrays: JAX arrays are immutable, meaning that once created their contents cannot be changed.\n","\n","Here is an example of mutating an array in NumPy:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"fzp-y1ZVyGD4"},"outputs":[],"source":["# NumPy: mutable arrays\n","x = np.arange(10)\n","x[0] = 10\n","print(x)"]},{"cell_type":"markdown","metadata":{"id":"nQ-De0xcJ1lT"},"source":["The equivalent in JAX results in an error, as JAX arrays are immutable:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"l2AP0QERb0P7"},"outputs":[],"source":["%xmode minimal"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"pCPX0JR-yM4i","tags":["raises-exception"]},"outputs":[],"source":["# JAX: immutable arrays\n","x = jnp.arange(10)\n","x[0] = 10"]},{"cell_type":"markdown","metadata":{"id":"yRYF0YgO3F4H"},"source":["For updating individual elements, JAX provides an [indexed update syntax](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax-numpy-ndarray-at) that returns an updated copy:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"8zqPEAeP3UK5"},"outputs":[],"source":["y = x.at[0].set(10)\n","print(x)\n","print(y)"]},{"cell_type":"markdown","metadata":{"id":"886BGDPeyXCu"},"source":["## NumPy, lax & XLA: JAX API layering\n","\n","**Key concepts:**\n","\n","- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n","- `jax.lax` is a lower-level API that is stricter and often more powerful.\n","- All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) – the Accelerated Linear Algebra compiler."]},{"cell_type":"markdown","metadata":{"id":"BjE4m2sZy4hh"},"source":["If you look at the source of `jax.numpy`, you'll see that all the operations are eventually expressed in terms of functions defined in `jax.lax`. You can think of `jax.lax` as a stricter, but often more powerful, API for working with multi-dimensional arrays.\n","\n","For example, while `jax.numpy` will implicitly promote arguments to allow operations between mixed data types, `jax.lax` will not:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"c6EFPcj12mw0"},"outputs":[],"source":["import jax.numpy as jnp\n","jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types."]},{"cell_type":"code","execution_count":null,"metadata":{"id":"0VkqlcXL2qSp","tags":["raises-exception"]},"outputs":[],"source":["from jax import lax\n","lax.add(1, 1.0) # jax.lax API requires explicit type promotion."]},{"cell_type":"markdown","metadata":{"id":"aC9TkXaTEu7A"},"source":["If using `jax.lax` directly, you'll have to do type promotion explicitly in such cases:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"3PNQlieT81mi"},"outputs":[],"source":["lax.add(jnp.float32(1), 1.0)"]},{"cell_type":"markdown","metadata":{"id":"M3HDuM4x2eTL"},"source":["Along with this strictness, `jax.lax` also provides efficient APIs for some more general operations than are supported by NumPy.\n","\n","For example, consider a 1D convolution, which can be expressed in NumPy this way:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Bv-7XexyzVCN"},"outputs":[],"source":["x = jnp.array([1, 2, 1])\n","y = jnp.ones(10)\n","jnp.convolve(x, y)"]},{"cell_type":"markdown","metadata":{"id":"0GPqgT7S0q8r"},"source":["Under the hood, this NumPy operation is translated to a much more general convolution implemented by [`lax.conv_general_dilated`](https://docs.jax.dev/en/latest/_autosummary/jax.lax.conv_general_dilated.html):"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"pi4f6ikjzc3l"},"outputs":[],"source":["from jax import lax\n","result = lax.conv_general_dilated(\n"," x.reshape(1, 1, 3).astype(float), # note: explicit promotion\n"," y.reshape(1, 1, 10),\n"," window_strides=(1,),\n"," padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy\n","result[0, 0]"]},{"cell_type":"markdown","metadata":{"id":"7mdo6ycczlbd"},"source":["This is a batched convolution operation designed to be efficient for the types of convolutions often used in deep neural nets. It requires much more boilerplate, but is far more flexible and scalable than the convolution provided by NumPy (See [Convolutions in JAX](https://docs.jax.dev/en/latest/notebooks/convolutions.html) for more detail on JAX convolutions).\n","\n","At their heart, all `jax.lax` operations are Python wrappers for operations in XLA; here, for example, the convolution implementation is provided by [XLA:ConvWithGeneralPadding](https://www.tensorflow.org/xla/operation_semantics#convwithgeneralpadding_convolution).\n","Every JAX operation is eventually expressed in terms of these fundamental XLA operations, which is what enables just-in-time (JIT) compilation."]},{"cell_type":"markdown","metadata":{"id":"NJfWa2PktD5_"},"source":["## To JIT or not to JIT\n","\n","**Key concepts:**\n","\n","- By default JAX executes operations one at a time, in sequence.\n","- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n","- Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.\n","\n","The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.\n","\n","For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of `jax.numpy` operations:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"SQj_UKGc-7kQ"},"outputs":[],"source":["import jax.numpy as jnp\n","\n","def norm(X):\n"," X = X - X.mean(0)\n"," return X / X.std(0)"]},{"cell_type":"markdown","metadata":{"id":"0yVo_OKSAolW"},"source":["A just-in-time compiled version of the function can be created using the `jax.jit` transform:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oHLwGmhZAnCY"},"outputs":[],"source":["from jax import jit\n","norm_compiled = jit(norm)"]},{"cell_type":"markdown","metadata":{"id":"Q3H9ig5GA2Ms"},"source":["This function returns the same results as the original, up to standard floating-point accuracy:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"oz7zzyS3AwMc"},"outputs":[],"source":["np.random.seed(1701)\n","X = jnp.array(np.random.rand(10000, 10))\n","np.allclose(norm(X), norm_compiled(X), atol=1E-6)"]},{"cell_type":"markdown","metadata":{"id":"3GvisB-CA9M8"},"source":["But due to the compilation (which includes fusing of operations, avoidance of allocating temporary arrays, and a host of other tricks), execution times can be orders of magnitude faster in the JIT-compiled case (note the use of `block_until_ready()` to account for JAX's [asynchronous dispatch](https://docs.jax.dev/en/latest/async_dispatch.html)):"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"6mUB6VdDAEIY"},"outputs":[],"source":["%timeit norm(X).block_until_ready()\n","%timeit norm_compiled(X).block_until_ready()"]},{"cell_type":"markdown","metadata":{"id":"B1eGBGn0tMba"},"source":["That said, `jax.jit` does have limitations: in particular, it requires all arrays to have static shapes. That means that some JAX operations are incompatible with JIT compilation.\n","\n","For example, this operation can be executed in op-by-op mode:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"YfZd9mW7CSKM"},"outputs":[],"source":["def get_negatives(x):\n"," return x[x < 0]\n","\n","x = jnp.array(np.random.randn(10))\n","get_negatives(x)"]},{"cell_type":"markdown","metadata":{"id":"g6niKxoQC2mZ"},"source":["But it returns an error if you attempt to execute it in jit mode:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"yYWvE4rxCjPK","tags":["raises-exception"]},"outputs":[],"source":["jit(get_negatives)(x)"]},{"cell_type":"markdown","metadata":{"id":"vFL6DNpECfVz"},"source":["This is because the function generates an array whose shape is not known at compile time: the size of the output depends on the values of the input array, and so it is not compatible with JIT."]},{"cell_type":"markdown","metadata":{"id":"BzBnKbXwXjLV"},"source":["## JIT mechanics: tracing and static variables\n","\n","**Key concepts:**\n","\n","- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n","\n","- Variables that you don't want to be traced can be marked as *static*\n","\n","To use `jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"TfjVIVuD4gnc"},"outputs":[],"source":["@jit\n","def f(x, y):\n"," print(\"Running f():\")\n"," print(f\" x = {x}\")\n"," print(f\" y = {y}\")\n"," result = jnp.dot(x + 1, y + 1)\n"," print(f\" result = {result}\")\n"," return result\n","\n","x = np.random.randn(3, 4)\n","y = np.random.randn(4)\n","f(x, y)"]},{"cell_type":"markdown","metadata":{"id":"Ts1fP45A40QV"},"source":["Notice that the print statements execute, but rather than printing the data we passed to the function, though, it prints *tracer* objects that stand-in for them.\n","\n","These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.\n","\n","When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"xGntvzNH7skE"},"outputs":[],"source":["x2 = np.random.randn(3, 4)\n","y2 = np.random.randn(4)\n","f(x2, y2)"]},{"cell_type":"markdown","metadata":{"id":"9EB9WkRX7fm0"},"source":["The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the `jax.make_jaxpr` transformation:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"89TMp_Op5-JZ"},"outputs":[],"source":["from jax import make_jaxpr\n","\n","def f(x, y):\n"," return jnp.dot(x + 1, y + 1)\n","\n","make_jaxpr(f)(x, y)"]},{"cell_type":"markdown","metadata":{"id":"0Oq9S4MZ90TL"},"source":["Note one consequence of this: because JIT compilation is done *without* information on the content of the array, control flow statements in the function cannot depend on traced values. For example, this fails:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"A0rFdM95-Ix_","tags":["raises-exception"]},"outputs":[],"source":["@jit\n","def f(x, neg):\n"," return -x if neg else x\n","\n","f(1, True)"]},{"cell_type":"markdown","metadata":{"id":"DkTO9m8j-TYI"},"source":["If there are variables that you would not like to be traced, they can be marked as static for the purposes of JIT compilation:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"K1C7ZnVv-lbv"},"outputs":[],"source":["from functools import partial\n","\n","@partial(jit, static_argnums=(1,))\n","def f(x, neg):\n"," return -x if neg else x\n","\n","f(1, True)"]},{"cell_type":"markdown","metadata":{"id":"dD7p4LRsGzhx"},"source":["Note that calling a JIT-compiled function with a different static argument results in re-compilation, so the function still works as expected:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"sXqczBOrG7-w"},"outputs":[],"source":["f(1, False)"]},{"cell_type":"markdown","metadata":{"id":"ZESlrDngGVb1"},"source":["Understanding which values and operations will be static and which will be traced is a key part of using `jax.jit` effectively."]},{"cell_type":"markdown","metadata":{"id":"r-RCl_wD5lI7"},"source":["## Static vs traced operations\n","\n","**Key concepts:**\n","\n","- Just as values can be either static or traced, operations can be static or traced.\n","\n","- Static operations are evaluated at compile-time in Python; traced operations are compiled & evaluated at run-time in XLA.\n","\n","- Use `numpy` for operations that you want to be static; use `jax.numpy` for operations that you want to be traced.\n","\n","This distinction between static and traced values makes it important to think about how to keep a static value static. Consider this function:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"XJCQ7slcD4iU","tags":["raises-exception"]},"outputs":[],"source":["import jax.numpy as jnp\n","from jax import jit\n","\n","@jit\n","def f(x):\n"," return x.reshape(jnp.array(x.shape).prod())\n","\n","x = jnp.ones((2, 3))\n","f(x)"]},{"cell_type":"markdown","metadata":{"id":"ZO3GMGrHBZDS"},"source":["This fails with an error specifying that a tracer was found instead of a 1D sequence of concrete values of integer type. Let's add some print statements to the function to understand why this is happening:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"Cb4mbeVZEi_q"},"outputs":[],"source":["@jit\n","def f(x):\n"," print(f\"x = {x}\")\n"," print(f\"x.shape = {x.shape}\")\n"," print(f\"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}\")\n"," # comment this out to avoid the error:\n"," # return x.reshape(jnp.array(x.shape).prod())\n","\n","f(x)"]},{"cell_type":"markdown","metadata":{"id":"viSQPc3jEwJr"},"source":["Notice that although `x` is traced, `x.shape` is a static value. However, when we use `jnp.array` and `jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static).\n","\n","A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"GiovOOPcGJhg"},"outputs":[],"source":["from jax import jit\n","import jax.numpy as jnp\n","import numpy as np\n","\n","@jit\n","def f(x):\n"," return x.reshape((np.prod(x.shape),))\n","\n","f(x)"]},{"cell_type":"markdown","metadata":{"id":"C-QZ5d1DG-dv"},"source":["For this reason, a standard convention in JAX programs is to `import numpy as np` and `import jax.numpy as jnp` so that both interfaces are available for finer control over whether operations are performed in a static manner (with `numpy`, once at compile-time) or a traced manner (with `jax.numpy`, optimized at run-time)."]}],"metadata":{"colab":{"provenance":[{"file_id":"1TIoXWDq64CLViJWkDi8vnckYNEc8wtuo","timestamp":1755113622798},{"file_id":"1usaXKsNi10cUMxMBUKpzmXL0KUrAoNlW","timestamp":1755113587459},{"file_id":"https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb","timestamp":1748370227419}]},"jupytext":{"formats":"ipynb,md:myst"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.7.6"}},"nbformat":4,"nbformat_minor":0} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/3 - Intro to NNX for PyTorch Users.ipynb b/docs/learning_jax/code-exercises/3 - Intro to NNX for PyTorch Users.ipynb new file mode 100644 index 0000000..e9d12ea --- /dev/null +++ b/docs/learning_jax/code-exercises/3 - Intro to NNX for PyTorch Users.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1jbsvh_ZWvXFaK-0FGsVYjoDc4QcBDV10","timestamp":1755113770723}],"toc_visible":true},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Introduction to Flax NNX\n","\n","Welcome to the Flax NNX Colab Notebook! This notebook provides hands-on exercises designed to help PyTorch users transition to Flax NNX and the JAX ecosystem.\n","\n","We'll cover core concepts and build simple models."],"metadata":{"id":"wK90mE1fmGuk"}},{"cell_type":"code","source":["!pip install -Uq flax optax"],"metadata":{"id":"Pke20hU-A1iQ"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"C7OTedGZjhPz"},"outputs":[],"source":["# @title Exercise 1: Understanding Modules and Parameters (Coding Exercise)\n","\n","# Instructions:\n","# 1. Create a simple NNX Module called `MyLinearLayer`.\n","# 2. It should have an `nnx.Param` called `weight` (initialized randomly with shape [input_size, output_size]).\n","# 3. It should have an `nnx.Param` called `bias` (initialized with zeros with shape [output_size]).\n","# 4. The forward pass (`__call__` method) should perform a linear transformation: `x @ self.weight.value + self.bias.value`.\n","# 5. Instantiate the layer with `input_size=10` and `output_size=5`.\n","# 6. Print the shape of the `weight` and `bias` parameters.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","\n","class MyLinearLayer(nnx.Module):\n"," def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):\n","\n"," pass # FILL IN THIS PART\n","\n"," def __call__(self, x: jax.Array):\n"," pass # FILL IN THIS PART\n","\n","# Instantiate the layer\n","key = jax.random.PRNGKey(0)\n","linear_layer = MyLinearLayer(\n"," input_size='FILL IN THIS PART',\n"," output_size='FILL IN THIS PART',\n"," rngs=nnx.Rngs(key))\n","\n","# Print the shapes of the parameters\n","print(\"Weight shape:\", 'FILL IN THIS PART')\n","print(\"Bias shape:\", 'FILL IN THIS PART')\n","\n","# Example usage:\n","dummy_input = jnp.ones((1, 10))\n","output = linear_layer(dummy_input)\n","print(\"Output shape:\", output.shape)"]},{"cell_type":"code","source":["# @title Exercise 1 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","\n","# class MyLinearLayer(nnx.Module):\n","# def __init__(self, input_size: int, output_size: int, *, rngs: nnx.Rngs):\n","# self.weight = nnx.Param(jax.random.normal(rngs.params(), (input_size, output_size)))\n","# self.bias = nnx.Param(jnp.zeros((output_size,)))\n","\n","# def __call__(self, x: jax.Array):\n","# return x @ self.weight.value + self.bias.value\n","\n","# # Instantiate the layer\n","# key = jax.random.PRNGKey(0)\n","# linear_layer = MyLinearLayer(input_size=10, output_size=5, rngs=nnx.Rngs(key))\n","\n","# # Print the shapes of the parameters\n","# print(\"Weight shape:\", linear_layer.weight.value.shape)\n","# print(\"Bias shape:\", linear_layer.bias.value.shape)\n","\n","# # Example usage:\n","# dummy_input = jnp.ones((1, 10))\n","# output = linear_layer(dummy_input)\n","# print(\"Output shape:\", output.shape)"],"metadata":{"id":"QeaDLUu_lXMA","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 2: State Management (Coding Exercise)\n","# Instructions:\n","# 1. Create an NNX Module called `CounterModule`.\n","# 2. It should have a Python instance attribute called `count` initialized to 0.\n","# 3. The `__call__` method should increment the `count` by 1 and return the new value.\n","# 4. Instantiate the module.\n","# 5. Call the module multiple times and print the returned value.\n","# 6. Use `nnx.split` and `nnx.merge` to save and load the module's state. Verify that the counter resumes from where it left off.\n","\n","from flax import nnx\n","import jax.numpy as jnp\n","\n","class CounterModule(nnx.Module):\n"," def __init__(self):\n"," pass # FILL IN THIS PART\n","\n"," def __call__(self):\n"," pass # FILL IN THIS PART\n","\n","# Instantiate the module\n","pass # FILL IN THIS PART. Name it \"counter\"\n","\n","# Call the module and print the value\n","print(\"First call:\", counter())\n","print(\"Second call:\", counter())\n","\n","# Split the module into graphdef and state.\n","# Remember that state is an nnx.Variable\n","graphdef, state = # FILL IN THIS PART\n","\n","# Merge the graphdef and state to create a new module\n","new_counter = # FILL IN THIS PART\n","\n","# Call the new module and print the value\n","print(\"After split and merge, first call:\", new_counter())\n","print(\"After split and merge, second call:\", new_counter())"],"metadata":{"id":"Qa51jundpavu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 2 Solution\n","\n","# from flax import nnx\n","# import jax.numpy as jnp\n","\n","# class CounterModule(nnx.Module):\n","# def __init__(self):\n","# self.count = 0\n","\n","# def __call__(self):\n","# self.count += 1\n","# return self.count\n","\n","# # Instantiate the module\n","# counter = CounterModule()\n","\n","# # Call the module and print the value\n","# print(\"First call:\", counter())\n","# print(\"Second call:\", counter())\n","\n","# # Split the module into graphdef and state\n","# graphdef, state = nnx.split(counter, nnx.Variable)\n","\n","# # Merge the graphdef and state to create a new module\n","# new_counter = nnx.merge(graphdef, state)\n","\n","# # Call the new module and print the value\n","# print(\"After split and merge, first call:\", new_counter())\n","# print(\"After split and merge, second call:\", new_counter())"],"metadata":{"id":"jVh1M8fYppnC","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 3: Explicit Random Number Generation (Coding Exercise)\n","\n","# Instructions:\n","# 1. Create an NNX Module called `RandomNormalLayer`.\n","# 2. Its `__init__` method should receive a `size` argument defining the size of the random vector to generate.\n","# 3. The `__init__` method should receive a `rngs: nnx.Rngs` argument that is used to generate a random normal tensor\n","# using jax.random.normal and assign the tensor to `self.random_vector`.\n","# 4. The `__call__` method should return the value of `self.random_vector` (a new random normal tensor).\n","# 5. Instantiate the layer with a size of 10, passing in the rngs parameter with a jax.random.PRNGKey.\n","# 6. Call the module twice and observe that the returned values are different.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","\n","# CREATE RandomNormalLayer\n","\n","# Instantiate the module\n","key = # USE jax.random.PRNGKey to create a new key\n","random_layer = RandomNormalLayer(size='SIZE HERE', rngs=nnx.Rngs(key))\n","\n","# Call the module and print the value\n","print(\"First call:\", random_layer())\n","print(\"Second call:\", random_layer())"],"metadata":{"id":"QKKiri2rptgl"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 3 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","\n","# class RandomNormalLayer(nnx.Module):\n","# def __init__(self, size: int, *, rngs: nnx.Rngs):\n","# self.random_vector = nnx.Param(jax.random.normal(rngs.params(), (size,)))\n","\n","# def __call__(self):\n","# return self.random_vector.value\n","\n","# # Instantiate the module\n","# key = jax.random.PRNGKey(0)\n","# random_layer = RandomNormalLayer(size=10, rngs=nnx.Rngs(key))\n","\n","# # Call the module and print the value\n","# print(\"First call:\", random_layer())\n","# print(\"Second call:\", random_layer())"],"metadata":{"id":"RG420Ks9roNq","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 4: Building a Simple CNN (Coding Exercise)\n","\n","# Instructions:\n","# 1. Create an NNX Module representing a simple CNN with the following layers:\n","# - Convolutional layer (nnx.Conv) with 32 filters, kernel size 3, and stride 1.\n","# - ReLU activation.\n","# - Max pooling layer (nnx.max_pool) with window size 2 and stride 2.\n","# - Flatten layer (jax.numpy.reshape).\n","# - Linear layer (nnx.Linear) to map to 10 output classes.\n","# 2. Initialize the CNN with appropriate input and output shapes.\n","# 3. Perform a forward pass with a dummy input and print the output shape.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","import jax.lax\n","\n","class SimpleCNN(nnx.Module):\n"," def __init__(self, num_classes: int, *, rngs: nnx.Rngs):\n"," self.conv = nnx.Conv('STRIDE', 'FILTERS', kernel_size=('X, X'), rngs=rngs)\n"," self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)\n","\n"," def __call__(self, x: jax.Array):\n"," x = self.conv(x)\n"," print(f'{x.shape = }') # For debug\n"," x = nnx.relu(x)\n"," print(f'{x.shape = }') # For debug\n"," x = nnx.max_pool(x, window_shape=('X, X'), strides=('X, X'))\n"," print(f'{x.shape = }') # For debug\n"," x = x.reshape(x.shape[0], -1) # flatten\n"," print(f'{x.shape = }') # For debug\n"," x = self.linear(x)\n"," return x\n","\n","# Instantiate the CNN\n","key = jax.random.PRNGKey(0)\n","cnn = SimpleCNN(num_classes='OUTPUT CLASSES', rngs=nnx.Rngs(key))\n","\n","# Dummy input\n","dummy_input = jnp.ones((1, 28, 28, 1))\n","\n","# Forward pass\n","output = cnn(dummy_input)\n","print(\"Output shape:\", output.shape)"],"metadata":{"id":"zafWVwtE3xgF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 4 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","# import jax.lax\n","\n","# class SimpleCNN(nnx.Module):\n","# def __init__(self, num_classes: int, *, rngs: nnx.Rngs):\n","# self.conv = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n","# self.linear = nnx.Linear(in_features=6272, out_features=num_classes, rngs=rngs)\n","\n","# def __call__(self, x: jax.Array):\n","# x = self.conv(x)\n","# print(f'{x.shape = }')\n","# x = nnx.relu(x)\n","# print(f'{x.shape = }')\n","# x = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2))\n","# print(f'{x.shape = }')\n","# x = x.reshape(x.shape[0], -1) # flatten\n","# print(f'{x.shape = }')\n","# x = self.linear(x)\n","# return x\n","\n","# # Instantiate the CNN\n","# key = jax.random.PRNGKey(0)\n","# cnn = SimpleCNN(num_classes=10, rngs=nnx.Rngs(key))\n","\n","# # Dummy input\n","# dummy_input = jnp.ones((1, 28, 28, 1))\n","\n","# # Forward pass\n","# output = cnn(dummy_input)\n","# print(\"Output shape:\", output.shape)"],"metadata":{"id":"_XHKL8ZaYla4","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 5: Training Loop with Optax (Coding Exercise)\n","\n","# Instructions:\n","# 1. Define a simple model (e.g., a linear layer).\n","# 2. Create an nnx.Optimizer, making sure to specify which variable types to\n","# update using the now required wrt argument (e.g., wrt=nnx.Param).\n","# 3. Implement a training step function that:\n","# - Calculates the loss (e.g., mean squared error).\n","# - Computes gradients using `nnx.value_and_grad`.\n","# - Updates the model's state using `optimizer.update(model, grads)`.\n","# 4. Run the training loop for a few steps.\n","\n","from flax import nnx\n","import jax\n","import jax.numpy as jnp\n","import optax\n","\n","# Define a simple model\n","class LinearModel(nnx.Module):\n"," def __init__(self, *, rngs: nnx.Rngs):\n"," self.linear = 'LINEAR LAYER HERE'\n","\n"," def __call__(self, x: jax.Array):\n"," return self.linear(x)\n","\n","# Instantiate the model\n","key = jax.random.PRNGKey(0)\n","model = LinearModel(rngs=nnx.Rngs(key))\n","\n","# Create an Optax optimizer\n","tx = 'OPTAX SGD HERE'\n","optimizer = nnx.Optimizer('WRAP THE OPTIMIZER')\n","\n","# Dummy data\n","x = jnp.array([[2.0]])\n","y = jnp.array([[4.0]])\n","\n","# Training step function\n","@nnx.jit\n","def train_step(model, optimizer, x, y):\n"," def loss_fn(model):\n"," y_pred = model(x)\n"," return jnp.mean((y_pred - y) ** 2)\n","\n"," loss, grads = nnx.value_and_grad(loss_fn)(model)\n"," optimizer.update(model, grads)\n"," return loss, model\n","\n","# Training loop\n","num_steps = 10\n","for i in range(num_steps):\n"," loss, model = train_step(model, optimizer, x, y)\n"," print(f\"Step {i+1}, Loss: {loss}\")\n","\n","print(\"Trained model output:\", model(x))"],"metadata":{"id":"Sf4P1AEO3_Rp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 5 Solution\n","\n","# from flax import nnx\n","# import jax\n","# import jax.numpy as jnp\n","# import optax\n","\n","# # Define a simple model\n","# class LinearModel(nnx.Module):\n","# def __init__(self, *, rngs: nnx.Rngs):\n","# self.linear = nnx.Linear(in_features=1, out_features=1, rngs=rngs)\n","\n","# def __call__(self, x: jax.Array):\n","# return self.linear(x)\n","\n","# # Instantiate the model\n","# key = jax.random.PRNGKey(0)\n","# model = LinearModel(rngs=nnx.Rngs(key))\n","\n","# # Create an Optax optimizer\n","# tx = optax.sgd(learning_rate=0.01)\n","# optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)\n","\n","# # Dummy data\n","# x = jnp.array([[2.0]])\n","# y = jnp.array([[4.0]])\n","\n","# # Training step function\n","# @nnx.jit\n","# def train_step(model, optimizer, x, y):\n","# def loss_fn(model):\n","# y_pred = model(x)\n","# return jnp.mean((y_pred - y) ** 2)\n","\n","# loss, grads = nnx.value_and_grad(loss_fn)(model)\n","# optimizer.update(model, grads)\n","# return loss, model\n","\n","# # Training loop\n","# num_steps = 10\n","# for i in range(num_steps):\n","# loss, model = train_step(model, optimizer, x, y)\n","# print(f\"Step {i+1}, Loss: {loss}\")\n","\n","# print(\"Trained model output:\", model(x))"],"metadata":{"id":"CaLOsG6paLam"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Congratulations!\n","You've now worked through the fundamentals of Flax NNX!\n","\n","Remember to consult the official documentation for more in-depth details:\n","\n","* Flax NNX: (Part of the Flax documentation) https://flax.readthedocs.io\n","* JAX: https://jax.readthedocs.io\n","\n","Keep practicing, and happy JAXing!\n","\n","Please send us feedback at https://goo.gle/jax-training-feedback"],"metadata":{"id":"khX7Io6749dt"}},{"cell_type":"markdown","source":[],"metadata":{"id":"_S3rApFP3hum"}}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/4 - MNIST example.ipynb b/docs/learning_jax/code-exercises/4 - MNIST example.ipynb new file mode 100644 index 0000000..10b6dbf --- /dev/null +++ b/docs/learning_jax/code-exercises/4 - MNIST example.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"markdown","id":"0","metadata":{"id":"0"},"source":["[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb)\n","[![Open On GitHub](https://img.shields.io/badge/Open-on%20GitHub-blue?logo=GitHub)](https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb)\n","\n","# MNIST tutorial\n","\n","Welcome to Flax NNX! In this tutorial you will learn how to build and train a simple convolutional neural network (CNN) to classify handwritten digits on the MNIST dataset using the Flax NNX API.\n","\n","Flax NNX is a Python neural network library built upon [JAX](https://github.com/jax-ml/jax). If you have used the Flax Linen API before, check out [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html). You should have some knowledge of the main concepts of deep learning.\n","\n","Let’s get started!"]},{"cell_type":"markdown","id":"1","metadata":{"id":"1"},"source":["## 1. Install Flax and Optax"]},{"cell_type":"code","execution_count":null,"id":"2","metadata":{"tags":["skip-execution"],"id":"2"},"outputs":[],"source":["!pip install -Uq flax optax"]},{"cell_type":"markdown","id":"3","metadata":{"id":"3"},"source":["## 2. Load the MNIST dataset\n","\n","First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance."]},{"cell_type":"code","execution_count":null,"id":"4","metadata":{"id":"4"},"outputs":[],"source":["import tensorflow_datasets as tfds # TFDS to download MNIST.\n","import tensorflow as tf # TensorFlow / `tf.data` operations.\n","\n","tf.random.set_seed(0) # Set the random seed for reproducibility.\n","\n","train_steps = 1200\n","eval_every = 200\n","batch_size = 32\n","\n","train_ds: tf.data.Dataset = tfds.load('mnist', split='train')\n","test_ds: tf.data.Dataset = tfds.load('mnist', split='test')\n","\n","train_ds = train_ds.map(\n"," lambda sample: {\n"," 'image': tf.cast(sample['image'], tf.float32) / 255,\n"," 'label': sample['label'],\n"," }\n",") # Normalize train set\n","\n","test_ds = test_ds.map(\n"," lambda sample: {\n"," 'image': tf.cast(sample['image'], tf.float32) / 255,\n"," 'label': sample['label'],\n"," }\n",") # Normalize the test set\n","\n","# Create a shuffled dataset by allocating a buffer size of 1024 to randomly draw elements from.\n","train_ds = train_ds.repeat().shuffle(1024)\n","# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n","train_ds = train_ds.batch(batch_size, drop_remainder=True).take(train_steps).prefetch(1)\n","# Group into batches of `batch_size` and skip incomplete batches, prefetch the next sample to improve latency.\n","test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)"]},{"cell_type":"markdown","id":"5","metadata":{"id":"5"},"source":["## 3. Define the model with Flax NNX\n","\n","Create a CNN for classification with Flax NNX by subclassing `nnx.Module`:"]},{"cell_type":"code","execution_count":null,"id":"6","metadata":{"id":"6"},"outputs":[],"source":["from flax import nnx # The Flax NNX API.\n","from functools import partial\n","\n","class CNN(nnx.Module):\n"," \"\"\"A simple CNN model.\"\"\"\n","\n"," def __init__(self, *, rngs: nnx.Rngs):\n"," self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n"," self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n"," self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))\n"," self.linear1 = nnx.Linear(3136, 256, rngs=rngs)\n"," self.linear2 = nnx.Linear(256, 10, rngs=rngs)\n","\n"," def __call__(self, x):\n"," x = self.avg_pool(nnx.relu(self.conv1(x)))\n"," x = self.avg_pool(nnx.relu(self.conv2(x)))\n"," x = x.reshape(x.shape[0], -1) # flatten\n"," x = nnx.relu(self.linear1(x))\n"," x = self.linear2(x)\n"," return x\n","\n","# Instantiate the model.\n","model = CNN(rngs=nnx.Rngs(0))\n","# Visualize it.\n","nnx.display(model)"]},{"cell_type":"markdown","id":"7","metadata":{"id":"7"},"source":["### Run the model\n","\n","Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results."]},{"cell_type":"code","execution_count":null,"id":"8","metadata":{"id":"8"},"outputs":[],"source":["import jax.numpy as jnp # JAX NumPy\n","\n","y = model(jnp.ones((1, 28, 28, 1)))\n","y"]},{"cell_type":"markdown","id":"9","metadata":{"id":"9"},"source":["## 4. Create the optimizer and define some metrics\n","\n","In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. The `nnx.Optimizer` is initialized with the model to infer the structure of the optimizer state, and an Optax optimizer to define the update rules."]},{"cell_type":"code","execution_count":null,"id":"12","metadata":{"id":"12"},"outputs":[],"source":["import optax\n","\n","learning_rate = 0.005\n","momentum = 0.9\n","\n","optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum), wrt=nnx.Param)\n","metrics = nnx.MultiMetric(\n"," accuracy=nnx.metrics.Accuracy(),\n"," loss=nnx.metrics.Average('loss'),\n",")\n","\n","nnx.display(optimizer)"]},{"cell_type":"markdown","id":"13","metadata":{"id":"13"},"source":["## 5. Define training step functions\n","\n","In this section, you will define a loss function using the cross entropy loss ([`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that the CNN model will optimize over.\n","\n","In addition to the `loss`, during training and testing you will also get the `logits`, which will be used to calculate the accuracy metric.\n","\n","During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics."]},{"cell_type":"code","execution_count":null,"id":"14","metadata":{"id":"14"},"outputs":[],"source":["def loss_fn(model: CNN, batch):\n"," logits = model(batch['image'])\n"," loss = optax.softmax_cross_entropy_with_integer_labels(\n"," logits=logits, labels=batch['label']\n"," ).mean()\n"," return loss, logits\n","\n","@nnx.jit\n","def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):\n"," \"\"\"Train for a single step.\"\"\"\n"," grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)\n"," (loss, logits), grads = grad_fn(model, batch)\n"," metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.\n"," optimizer.update(model, grads) # In-place updates.\n","\n","@nnx.jit\n","def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):\n"," loss, logits = loss_fn(model, batch)\n"," metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates."]},{"cell_type":"markdown","id":"17","metadata":{"id":"17"},"source":["In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is a \"lifted\" version of the `jax.jit` transform that allows its function input and outputs to be Flax NNX objects. Similarly, `nnx.value_and_grad ` is a lifted version of `jax.value_and_grad `. Check out [the lifted transforms guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more.\n","\n","> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html)."]},{"cell_type":"markdown","id":"21","metadata":{"id":"21"},"source":["## 6. Train and evaluate the model\n","\n","Now, you can train the CNN model using batches of data for 10 epochs, evaluate the model’s performance\n","on the test set after each epoch, and log the training and testing metrics (the loss and\n","the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy."]},{"cell_type":"code","execution_count":null,"id":"22","metadata":{"id":"22"},"outputs":[],"source":["from IPython.display import clear_output\n","import matplotlib.pyplot as plt\n","\n","metrics_history = {\n"," 'train_loss': [],\n"," 'train_accuracy': [],\n"," 'test_loss': [],\n"," 'test_accuracy': [],\n","}\n","\n","for step, batch in enumerate(train_ds.as_numpy_iterator()):\n"," # Run the optimization for one step and make a stateful update to the following:\n"," # - The train state's model parameters\n"," # - The optimizer state\n"," # - The training loss and accuracy batch metrics\n"," train_step(model, optimizer, metrics, batch)\n","\n"," if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed.\n"," # Log the training metrics.\n"," for metric, value in metrics.compute().items(): # Compute the metrics.\n"," metrics_history[f'train_{metric}'].append(value) # Record the metrics.\n"," metrics.reset() # Reset the metrics for the test set.\n","\n"," # Compute the metrics on the test set after each training epoch.\n"," for test_batch in test_ds.as_numpy_iterator():\n"," eval_step(model, metrics, test_batch)\n","\n"," # Log the test metrics.\n"," for metric, value in metrics.compute().items():\n"," metrics_history[f'test_{metric}'].append(value)\n"," metrics.reset() # Reset the metrics for the next training epoch.\n","\n"," clear_output(wait=True)\n"," # Plot loss and accuracy in subplots\n"," fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n"," ax1.set_title('Loss')\n"," ax2.set_title('Accuracy')\n"," for dataset in ('train', 'test'):\n"," ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')\n"," ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n"," ax1.legend()\n"," ax2.legend()\n"," plt.show()"]},{"cell_type":"markdown","id":"25","metadata":{"id":"25"},"source":["## 7. Perform inference on the test set\n","\n","Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance."]},{"cell_type":"code","execution_count":null,"id":"26","metadata":{"id":"26"},"outputs":[],"source":["model.eval() # Switch to evaluation mode.\n","\n","@nnx.jit\n","def pred_step(model: CNN, batch):\n"," logits = model(batch['image'])\n"," return logits.argmax(axis=1)"]},{"cell_type":"markdown","id":"1d6cb81f","metadata":{"id":"1d6cb81f"},"source":["Note that we use `.eval()` to ensure that the model is in evaluation mode, even though we are not using `Dropout` or `BatchNorm` in this model, `.eval()` ensure that the outputs are deterministic."]},{"cell_type":"code","execution_count":null,"id":"27","metadata":{"id":"27"},"outputs":[],"source":["test_batch = test_ds.as_numpy_iterator().next()\n","pred = pred_step(model, test_batch)\n","\n","fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n","for i, ax in enumerate(axs.flatten()):\n"," ax.imshow(test_batch['image'][i, ..., 0], cmap='gray')\n"," ax.set_title(f'label={pred[i]}')\n"," ax.axis('off')"]},{"cell_type":"markdown","source":["## Exploring your model with Model Explorer\n","\n","To really dig into a model and understand the operations and connections, [Model Explorer](https://github.com/google-ai-edge/model-explorer/wiki/) is a great tool! Let's take a look now at our MNIST model. **Please feel free to poke around and explore the model!**"],"metadata":{"id":"N7n2PfE4LwMC"},"id":"N7n2PfE4LwMC"},{"cell_type":"code","source":["# Install Model Explorer\n","\n","!pip install --no-deps ai-edge-model-explorer-adapter ai-edge-model-explorer"],"metadata":{"id":"RLhM4CYkOaIh"},"id":"RLhM4CYkOaIh","execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Use some dummy input and write the model MLIR to a file\n","\n","import jax\n","dummy_input = jnp.ones((1, 28, 28, 1))\n","stablehlo_mlir = jax.jit(model).lower(dummy_input).as_text(debug_info=True)\n","mlir_file = open(\"stablehlo_mlir.mlir\", \"w\")\n","mlir_file.write(stablehlo_mlir)\n","mlir_file.close()"],"metadata":{"id":"wo0F_ytxMAEq"},"id":"wo0F_ytxMAEq","execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Import and run Model Explorer with the model\n","\n","import model_explorer\n","\n","model_explorer.visualize(\"stablehlo_mlir.mlir\")"],"metadata":{"id":"bbtTtGTmOmGX"},"id":"bbtTtGTmOmGX","execution_count":null,"outputs":[]},{"cell_type":"markdown","id":"28","metadata":{"id":"28"},"source":["Congratulations! You have learned how to use Flax NNX to build and train a simple classification model end-to-end on the MNIST dataset.\n","\n","Next, check out [Why Flax NNX?](https://flax.readthedocs.io/en/latest/why.html) and get started with a series of [Flax NNX Guides](https://flax.readthedocs.io/en/latest/guides/index.html)."]}],"metadata":{"jupytext":{"formats":"ipynb,md:myst","main_language":"python"},"language_info":{"name":"python"},"colab":{"provenance":[{"file_id":"14rZHWSa2Kw_vv-EMYr0eGqa21E4iald9","timestamp":1755113849998},{"file_id":"https://github.com/google/flax/blob/main/docs_nnx/mnist_tutorial.ipynb","timestamp":1748370481980}]},"kernelspec":{"name":"python3","display_name":"Python 3"}},"nbformat":4,"nbformat_minor":5} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/5 - Chex_ JAX & Flax NNX Reliability.ipynb b/docs/learning_jax/code-exercises/5 - Chex_ JAX & Flax NNX Reliability.ipynb new file mode 100644 index 0000000..3b89c79 --- /dev/null +++ b/docs/learning_jax/code-exercises/5 - Chex_ JAX & Flax NNX Reliability.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"19EP-EMOlRQ9LTWr8KpcGeiiXCKI76lWY","timestamp":1755113898559}],"toc_visible":true,"authorship_tag":"ABX9TyPYH8zwEBLMAY390LkhEEty"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":[" # Chex Exercises: Building Robust JAX & Flax NNX Applications\n","\n"," Welcome! This notebook contains exercises to help you practice using Chex\n"," with JAX and Flax NNX, based on the concepts covered in the lecture.\n","\n"," **Goal:** Solidify your understanding of how Chex enhances reliability and\n"," debuggability in JAX-based projects.\n","\n"," **Instructions:**\n"," 1. Read the problem description for each exercise.\n"," 2. Fill in the `TODO` sections with your code.\n"," 3. Run the cells to test your solutions.\n"," 4. Compare your results with the expected outcomes or hints provided.\n","\n","\n"," Let's get started!"],"metadata":{"id":"qrr2EgAVNLIL"}},{"cell_type":"code","source":["# Run this cell first to install and import necessary libraries.\n","!pip install -Uq flax chex\n","\n","import jax\n","import jax.numpy as jnp\n","import chex\n","import flax\n","from flax import nnx\n","import functools # For functools.partial\n","\n","# Helper to reset trace counter for assert_max_traces exercises\n","def reset_trace_counter():\n"," chex.clear_trace_counter()\n"," # For some JAX versions, a small trick might be needed to fully reset\n"," # internal JAX caches if you're re-running cells aggressively.\n"," # This is usually not needed for these exercises if cells are run in order.\n","\n","print(f\"JAX version: {jax.__version__}\")\n","print(f\"Chex version: {chex.__version__}\")\n","print(f\"Flax version: {flax.__version__}\")\n","print(f\"Running on: {jax.default_backend()}\")"],"metadata":{"id":"J9gmo6dRNGTV"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Section 1: Core Chex Assertions\n","Chex provides a suite of assertion functions to validate array properties.\n","Let's practice with the most common ones."],"metadata":{"id":"pJjjwML9Otv-"}},{"cell_type":"markdown","source":["### Exercise 1.1: `chex.assert_shape` and `chex.assert_type`\n","Complete the `process_data` function below.\n","- Add assertions to check if `input_array` has a shape of `(3, None)`\n"," (meaning 3 rows, any number of columns).\n","- Add an assertion to check if `input_array` has a `jnp.float32` dtype.\n","- Add an assertion to check if `output_array` has a shape of `(3, 1)`."],"metadata":{"id":"i_p5oqtqMGFG"}},{"cell_type":"code","source":["def process_data_v1(input_array: chex.Array) -> chex.Array:\n"," \"\"\"Processes an array, asserting shapes and types.\"\"\"\n"," # TODO: Assert input_array shape is (3, None)\n"," chex.assert_shape(input_array, )\n","\n"," # TODO: Assert input_array type is jnp.float32\n"," chex.assert_type()\n","\n"," # Simulate some processing that reduces the last dimension to 1\n"," output_array = input_array[:, :1] * 2.0\n","\n"," # TODO: Assert output_array shape is (3, 1)\n"," chex.assert_shape(output_array, (3, 1))\n","\n"," return output_array\n","\n","# Test cases\n","key = jax.random.PRNGKey(0)\n","valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)\n","print(\"Testing with valid input...\")\n","result = process_data_v1(valid_input)\n","print(f\"Successfully processed valid input. Output shape: {result.shape}\\n\")\n","\n","print(\"Testing with invalid shape input...\")\n","invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)\n","try:\n"," process_data_v1(invalid_shape_input)\n","except AssertionError as e:\n"," print(f\"Caught expected error for invalid shape:\\n{e}\\n\")\n","\n","print(\"Testing with invalid type input...\")\n","invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)\n","try:\n"," process_data_v1(invalid_type_input)\n","except AssertionError as e:\n"," print(f\"Caught expected error for invalid type: {e}\\n\")"],"metadata":{"id":"DGvjG5t7N3Ef"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 1.1 Solution"],"metadata":{"id":"PjTjZN_LL8Rx"}},{"cell_type":"code","source":["def process_data_v1(input_array: chex.Array) -> chex.Array:\n"," \"\"\"Processes an array, asserting shapes and types.\"\"\"\n"," # TODO: Assert input_array shape is (3, None)\n"," chex.assert_shape(input_array, (3, None))\n","\n"," # TODO: Assert input_array type is jnp.float32\n"," chex.assert_type(input_array, expected_types=jnp.float32)\n","\n"," # Simulate some processing that reduces the last dimension to 1\n"," output_array = input_array[:, :1] * 2.0\n","\n"," # TODO: Assert output_array shape is (3, 1)\n"," chex.assert_shape(output_array, (3, 1))\n","\n"," return output_array\n","\n","# Test cases\n","key = jax.random.PRNGKey(0)\n","valid_input = jax.random.normal(key, (3, 5), dtype=jnp.float32)\n","print(\"Testing with valid input...\")\n","result = process_data_v1(valid_input)\n","print(f\"Successfully processed valid input. Output shape: {result.shape}\\n\")\n","\n","print(\"Testing with invalid shape input...\")\n","invalid_shape_input = jax.random.normal(key, (4, 5), dtype=jnp.float32)\n","try:\n"," process_data_v1(invalid_shape_input)\n","except AssertionError as e:\n"," print(f\"Caught expected error for invalid shape:\\n{e}\\n\")\n","\n","print(\"Testing with invalid type input...\")\n","invalid_type_input = jnp.ones((3, 5), dtype=jnp.int32)\n","try:\n"," process_data_v1(invalid_type_input)\n","except AssertionError as e:\n"," print(f\"Caught expected error for invalid type: {e}\\n\")"],"metadata":{"id":"PsVXz6heLLOe"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 1.2: `chex.assert_rank` and `chex.assert_scalar`\n","Complete the `process_data_v2` function.\n","- Add an assertion to ensure `matrix_input` is a 2D array (rank 2).\n","- Add an assertion to ensure `scalar_input` is a scalar.\n","- Add an assertion to ensure the `result` is also a 2D array."],"metadata":{"id":"EH-SjocvR8Yc"}},{"cell_type":"code","source":["def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:\n"," \"\"\"Processes a matrix and a scalar.\"\"\"\n"," # TODO: Assert matrix_input has rank 2\n"," chex.assert_rank(matrix_input, )\n","\n"," # TODO: Assert scalar_input is a scalar\n"," chex.assert_scalar()\n","\n"," result = matrix_input * scalar_input + 1.0\n","\n"," # TODO: Assert result has rank 2\n"," chex.assert_rank(result, )\n"," return result\n","\n","# Test cases\n","matrix = jnp.ones((3, 4))\n","scalar = 5.0\n","not_a_scalar = jnp.array([5.0])\n","not_a_matrix = jnp.ones((3,4,1))\n","\n","print(\"Testing with valid rank/scalar inputs...\")\n","try:\n"," res_valid = process_data_v2(matrix, scalar)\n"," print(f\"Successfully processed valid rank/scalar. Result shape: {res_valid.shape}\\n\")\n","except AssertionError as e:\n"," print(f\"Caught unexpected error for valid rank/scalar:\\n{e}\\n\")\n","\n","print(\"Testing with invalid rank input...\")\n","try:\n"," process_data_v2(not_a_matrix, scalar)\n"," print(f\"Successfully processed invalid rank. Result shape: {res_valid.shape}\\n\")\n","except AssertionError as e:\n"," print(f\"Caught expected error for invalid rank:\\n{e}\\n\")\n","\n","print(\"Testing with non-scalar input...\")\n","try:\n"," process_data_v2(matrix, not_a_scalar)\n"," print(f\"Successfully processed non-scalar. Result shape: {res_valid.shape}\\n\")\n","except AssertionError as e:\n"," print(f\"Caught expected error for non-scalar:\\n{e}\\n\")"],"metadata":{"id":"YWljMNCWPgh9"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 1.2 Solution"],"metadata":{"id":"RUUveZp7MkKZ"}},{"cell_type":"code","source":["def process_data_v2(matrix_input: chex.Array, scalar_input: chex.Array) -> chex.Array:\n"," \"\"\"Processes a matrix and a scalar.\"\"\"\n"," # TODO: Assert matrix_input has rank 2\n"," chex.assert_rank(matrix_input, expected_ranks=2)\n","\n"," # TODO: Assert scalar_input is a scalar\n"," chex.assert_scalar(scalar_input)\n","\n"," result = matrix_input * scalar_input + 1.0\n","\n"," # TODO: Assert result has rank 2\n"," chex.assert_rank(result, expected_ranks=2)\n"," return result\n","\n","# Test cases\n","matrix = jnp.ones((3, 4))\n","scalar = 5.0\n","not_a_scalar = jnp.array([5.0])\n","not_a_matrix = jnp.ones((3,4,1))\n","\n","print(\"Testing with valid rank/scalar inputs...\")\n","try:\n"," res_valid = process_data_v2(matrix, scalar)\n"," print(f\"Successfully processed valid rank/scalar. Result shape: {res_valid.shape}\\n\")\n","except AssertionError as e:\n"," print(f\"Caught unexpected error for valid rank/scalar:\\n{e}\\n\")\n","\n","print(\"Testing with invalid rank input...\")\n","try:\n"," process_data_v2(not_a_matrix, scalar)\n"," print(f\"Successfully processed invalid rank. Result shape: {res_valid.shape}\\n\")\n","except AssertionError as e:\n"," print(f\"Caught expected error for invalid rank:\\n{e}\\n\")\n","\n","print(\"Testing with non-scalar input...\")\n","try:\n"," process_data_v2(matrix, not_a_scalar)\n"," print(f\"Successfully processed non-scalar. Result shape: {res_valid.shape}\\n\")\n","except AssertionError as e:\n"," print(f\"Caught expected error for non-scalar:\\n{e}\\n\")"],"metadata":{"id":"GLXFFI7XMpT0"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 1.3: PyTree Assertions (`assert_trees_all_close`, `assert_tree_all_finite`)\n","PyTrees (nested structures of arrays, like model parameters) are common in JAX.\n","Chex provides assertions for them.\n"],"metadata":{"id":"FHwEQLHvTjdj"}},{"cell_type":"code","source":["def process_pytree(tree1, tree2):\n"," \"\"\"\n"," Checks if two PyTrees are close and if the first tree is finite.\n"," Returns a new tree where elements are tree1 + tree2.\n"," \"\"\"\n"," # TODO: Assert tree1 and tree2 are (close to) equal. Use a small tolerance.\n"," chex.assert_trees_all_close( rtol=1e-5, atol=1e-8)\n","\n"," # TODO: Assert all elements in tree1 are finite (not NaN or Inf).\n"," chex.assert_tree_all_finite()\n","\n"," # Perform some operation\n"," return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)\n","\n","# Test cases\n","tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}\n","tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}\n","tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}\n","tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}\n","\n","print(\"Testing with close and finite PyTrees...\")\n","result_valid = process_pytree(tree_a, tree_b_close)\n","print(\"Successfully processed valid PyTrees.\\n\")\n","\n","print(\"Testing with non-close PyTrees...\")\n","try:\n"," process_pytree(tree_a, tree_c_not_close)\n","except AssertionError as e:\n"," print(f\"Caught expected error for non-close trees:\\n\\n{e}\\n\")\n","\n","print(\"Testing with non-finite PyTree...\")\n","try:\n"," process_pytree(tree_d_nan, tree_b_close) # tree_d_nan will be checked for finiteness\n","except AssertionError as e:\n"," print(f\"Caught expected error for non-finite tree:\\n\\n{e}\\n\")"],"metadata":{"id":"NWxlO6o6Sczm"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 1.3 Solution"],"metadata":{"id":"d0d-QNyBNDhO"}},{"cell_type":"code","source":["def process_pytree(tree1, tree2):\n"," \"\"\"\n"," Checks if two PyTrees are close and if the first tree is finite.\n"," Returns a new tree where elements are tree1 + tree2.\n"," \"\"\"\n"," # TODO: Assert tree1 and tree2 are (close to) equal. Use a small tolerance.\n"," chex.assert_trees_all_close(tree1, tree2, rtol=1e-5, atol=1e-8)\n","\n"," # TODO: Assert all elements in tree1 are finite (not NaN or Inf).\n"," chex.assert_tree_all_finite(tree1)\n","\n"," # Perform some operation\n"," return jax.tree_util.tree_map(lambda x, y: x + y, tree1, tree2)\n","\n","# Test cases\n","tree_a = {'params': {'w': jnp.array([1.0, 2.0]), 'b': jnp.array(0.5)}}\n","tree_b_close = {'params': {'w': jnp.array([1.000001, 2.000001]), 'b': jnp.array(0.500001)}}\n","tree_c_not_close = {'params': {'w': jnp.array([1.1, 2.1]), 'b': jnp.array(0.6)}}\n","tree_d_nan = {'params': {'w': jnp.array([1.0, jnp.nan]), 'b': jnp.array(0.5)}}\n","\n","print(\"Testing with close and finite PyTrees...\")\n","result_valid = process_pytree(tree_a, tree_b_close)\n","print(\"Successfully processed valid PyTrees.\\n\")\n","\n","print(\"Testing with non-close PyTrees...\")\n","try:\n"," process_pytree(tree_a, tree_c_not_close)\n","except AssertionError as e:\n"," print(f\"Caught expected error for non-close trees:\\n\\n{e}\\n\")\n","\n","print(\"Testing with non-finite PyTree...\")\n","try:\n"," process_pytree(tree_d_nan, tree_b_close) # tree_d_nan will be checked for finiteness\n","except AssertionError as e:\n"," print(f\"Caught expected error for non-finite tree:\\n\\n{e}\\n\")"],"metadata":{"id":"FjbrC5EENG_D"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Section 2: Chex Assertions with JAX Transformations\n","A key strength of Chex is that its assertions work correctly inside JAX transformations like `jax.jit` and `jax.vmap`.\n","\n","### Exercise 2.1: Assertions inside `@jax.jit`\n","- Take the `process_data_v1` function from Exercise 1.1.\n","- JIT-compile it and verify that the Chex assertions still work as expected."],"metadata":{"id":"2fRr4rjVVQx9"}},{"cell_type":"code","source":["@jax.jit\n","def process_data_jitted(input_array: chex.Array) -> chex.Array:\n"," \"\"\"JIT-compiled version of process_data_v1 with its Chex assertions.\"\"\"\n"," # (Assertions are inside process_data_v1, which we'll effectively re-use here)\n"," # For clarity, let's re-define it with assertions directly here.\n"," chex.assert_shape(input_array, (3, None))\n"," chex.assert_type(input_array, jnp.float32)\n"," output_array = input_array[:, :1] * 2.0\n"," chex.assert_shape(output_array, (3, 1))\n"," return output_array\n","\n","# Test cases for JIT version\n","key = jax.random.PRNGKey(1) # Use a different key for potentially different values\n","valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)\n","print(\"Testing JITted function with valid input...\")\n","\n","# First call will compile\n","result_jit = process_data_jitted()\n","print(f\"Successfully processed JITted valid input. Output shape: {result_jit.shape}\")\n","\n","# Second call uses cached compilation\n","result_jit_cached = process_data_jitted( * 2)\n","print(f\"Successfully processed JITted valid input (cached). Output shape: {result_jit_cached.shape}\\n\")\n","\n","print(\"Testing JITted function with invalid shape input...\")\n","invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)\n","try:\n"," process_data_jitted()\n","except AssertionError as e:\n"," print(f\"Caught expected JITted error for invalid shape:\\n\\n{e}\\n\")"],"metadata":{"id":"1u8_6x-wUJOi"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 2.1 Solution"],"metadata":{"id":"inyiNeL3N7VD"}},{"cell_type":"code","source":["@jax.jit\n","def process_data_jitted(input_array: chex.Array) -> chex.Array:\n"," \"\"\"JIT-compiled version of process_data_v1 with its Chex assertions.\"\"\"\n"," # (Assertions are inside process_data_v1, which we'll effectively re-use here)\n"," # For clarity, let's re-define it with assertions directly here.\n"," chex.assert_shape(input_array, (3, None))\n"," chex.assert_type(input_array, jnp.float32)\n"," output_array = input_array[:, :1] * 2.0\n"," chex.assert_shape(output_array, (3, 1))\n"," return output_array\n","\n","# Test cases for JIT version\n","key = jax.random.PRNGKey(1) # Use a different key for potentially different values\n","valid_input_jit = jax.random.normal(key, (3, 5), dtype=jnp.float32)\n","print(\"Testing JITted function with valid input...\")\n","\n","# First call will compile\n","result_jit = process_data_jitted(valid_input_jit)\n","print(f\"Successfully processed JITted valid input. Output shape: {result_jit.shape}\")\n","\n","# Second call uses cached compilation\n","result_jit_cached = process_data_jitted(valid_input_jit * 2)\n","print(f\"Successfully processed JITted valid input (cached). Output shape: {result_jit_cached.shape}\\n\")\n","\n","print(\"Testing JITted function with invalid shape input...\")\n","invalid_shape_input_jit = jax.random.normal(key, (4, 5), dtype=jnp.float32)\n","try:\n"," process_data_jitted(invalid_shape_input_jit)\n","except AssertionError as e:\n"," print(f\"Caught expected JITted error for invalid shape:\\n\\n{e}\\n\")"],"metadata":{"id":"5NluMK76OACI"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["**Observation:**\n","\n","Chex assertions work seamlessly within JITted functions, catching errors based on the concrete values passed during runtime, even though the checks are defined within the compiled code.\n","\n","---\n","### Exercise 2.2: Multi-Level Validation with `@jax.vmap`\n","We want to process a batch of items. Each item is a 1D array of shape `(10,)`.\n","\n","1. Define `process_single_item_vmap` that processes one item.\n"," - Inside this function, assert the `item` has shape `(10,)`.\n"," - The function should double the item's values.\n"," - Assert the `result` (output of `process_single_item_vmap`) also has shape `(10,)`.\n","2. Use `jax.vmap` to create `process_batch`.\n","3. Before calling `process_batch`, assert the `batch_input` has shape `(BATCH_SIZE, 10)`.\n","4. After calling `process_batch`, assert the `batch_output` has shape `(BATCH_SIZE, 10)`."],"metadata":{"id":"m_emU0mfWB4F"}},{"cell_type":"code","source":["BATCH_SIZE = 5\n","ITEM_SIZE = 10\n","\n","def process_single_item_vmap(item: chex.Array) -> chex.Array:\n"," \"\"\"Processes a single item, asserting its shape.\"\"\"\n"," # TODO: Assert shape of a SINGLE item is (ITEM_SIZE,)\n"," chex.assert_shape(item, )\n"," result = item * 2.0\n"," # TODO: Assert shape of single item output is (ITEM_SIZE,)\n"," chex.assert_shape(result, )\n"," return result\n","\n","# TODO: Vectorize the process_single_item_vmap function using jax.vmap\n","process_batch = jax.vmap(, in_axes=0, out_axes=0)\n","\n","# Test cases\n","key = jax.random.PRNGKey(2)\n","valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))\n","invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))\n","\n","print(\"Testing vmap with valid batch input...\")\n","# TODO: Assert shape of the full BATCHED input BEFORE vmap\n","chex.assert_shape(valid_batch_input, )\n","\n","batch_output = process_batch(valid_batch_input)\n","\n","# TODO: Assert shape of the full BATCHED output AFTER vmap\n","chex.assert_shape(batch_output, )\n","print(f\"Vmap assertion passed. Output shape: {batch_output.shape}\\n\")\n","\n","print(\"Testing vmap with invalid item shape in batch (error from inside vmap)...\")\n","try:\n"," # This will fail inside the vmapped function 'process_single_item_vmap'\n"," process_batch(invalid_batch_input_item_shape)\n","except AssertionError as e:\n"," print(f\"Caught expected vmap error (from inner function):\\n{e}\\n\")\n","\n","print(\"Testing vmap with invalid batch shape (error from outer assertion)...\")\n","invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))\n","try:\n"," # This will fail the assertion *before* calling process_batch\n"," chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE)) # This line will fail\n"," process_batch(invalid_batch_input_outer_shape)\n","except AssertionError as e:\n"," print(f\"Caught expected vmap error (from outer assertion):\\n{e}\\n\")"],"metadata":{"id":"NVbMomrdV4wZ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 2.2 Solution"],"metadata":{"id":"6Mja-kBpOoSc"}},{"cell_type":"code","source":["BATCH_SIZE = 5\n","ITEM_SIZE = 10\n","\n","def process_single_item_vmap(item: chex.Array) -> chex.Array:\n"," \"\"\"Processes a single item, asserting its shape.\"\"\"\n"," # TODO: Assert shape of a SINGLE item is (ITEM_SIZE,)\n"," chex.assert_shape(item, (ITEM_SIZE,))\n"," result = item * 2.0\n"," # TODO: Assert shape of single item output is (ITEM_SIZE,)\n"," chex.assert_shape(result, (ITEM_SIZE,))\n"," return result\n","\n","# TODO: Vectorize the function using jax.vmap\n","process_batch = jax.vmap(process_single_item_vmap, in_axes=0, out_axes=0)\n","\n","# Test cases\n","key = jax.random.PRNGKey(2)\n","valid_batch_input = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE))\n","invalid_batch_input_item_shape = jax.random.normal(key, (BATCH_SIZE, ITEM_SIZE + 1))\n","\n","print(\"Testing vmap with valid batch input...\")\n","# TODO: Assert shape of the full BATCHED input BEFORE vmap\n","chex.assert_shape(valid_batch_input, (BATCH_SIZE, ITEM_SIZE))\n","\n","batch_output = process_batch(valid_batch_input)\n","\n","# TODO: Assert shape of the full BATCHED output AFTER vmap\n","chex.assert_shape(batch_output, (BATCH_SIZE, ITEM_SIZE))\n","print(f\"Vmap assertion passed. Output shape: {batch_output.shape}\\n\")\n","\n","print(\"Testing vmap with invalid item shape in batch (error from inside vmap)...\")\n","try:\n"," # This will fail inside the vmapped function 'process_single_item_vmap'\n"," process_batch(invalid_batch_input_item_shape)\n","except AssertionError as e:\n"," print(f\"Caught expected vmap error (from inner function):\\n{e}\\n\")\n","\n","print(\"Testing vmap with invalid batch shape (error from outer assertion)...\")\n","invalid_batch_input_outer_shape = jax.random.normal(key, (BATCH_SIZE + 1, ITEM_SIZE))\n","try:\n"," # This will fail the assertion *before* calling process_batch\n"," chex.assert_shape(invalid_batch_input_outer_shape, (BATCH_SIZE, ITEM_SIZE)) # This line will fail\n"," process_batch(invalid_batch_input_outer_shape)\n","except AssertionError as e:\n"," print(f\"Caught expected vmap error (from outer assertion):\\n{e}\\n\")"],"metadata":{"id":"_ma-wKQgOs-a"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Section 3: Chex with Flax NNX\n","Neural networks are complex, making validation crucial. Chex integrates naturally into Flax NNX Modules, typically within the `__call__` method.\n","\n","### Exercise 3.1: Input/Output Validation in an NNX Module\n","Complete the `SimpleMLP` module:\n","- In `__call__`, validate the input `x`:\n"," - Must be 2D (`[batch, features]`).\n"," - The feature dimension (axis 1) must match `self.linear1.in_features`.\n"," - Type must be `jnp.float32`.\n","- In `__call__`, validate the output `x` before returning:\n"," - Must be 2D.\n"," - The feature dimension (axis 1) must match `self.linear2.out_features`."],"metadata":{"id":"cDJ1et86XGLe"}},{"cell_type":"code","source":["class SimpleMLP(nnx.Module):\n"," def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):\n"," self.linear1 = nnx.Linear(din, dmid, rngs=rngs)\n"," self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)\n","\n"," def __call__(self, x: chex.Array) -> chex.Array:\n"," # TODO: Validate input x\n"," # - Must be 2D ([batch, features])\n"," chex.assert_rank()\n"," # - Feature dimension (axis 1) must match self.linear1.in_features\n"," chex.assert_axis_dimension(x, 1, )\n"," # - Type must be jnp.float32\n"," chex.assert_type(x, )\n","\n"," # Forward pass\n"," x = self.linear1(x)\n"," x = nnx.relu(x)\n"," x = self.linear2(x)\n","\n"," # TODO: Validate output x before returning\n"," # - Must be 2D\n"," chex.assert_rank()\n"," # - Feature dimension (axis 1) must match self.linear2.out_features\n"," chex.assert_axis_dimension(x, 1, self.linear2.out_features)\n","\n"," return x\n","\n","# Test cases for SimpleMLP\n","key_nnx = nnx.Rngs(params=jax.random.key(0)) # NNX Rngs for stateful operations\n","din, dmid, dout = 10, 20, 5\n","batch_size_nnx = 4\n","\n","model = SimpleMLP(din, dmid, dout, rngs=key_nnx)\n","\n","print(\"Testing NNX Module with valid input:\")\n","x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)\n","output_nnx = model(x_valid_nnx)\n","print(f\"NNX I/O Check passed. Output shape: {output_nnx.shape}\\n\")\n","\n","\n","print(\"Testing NNX Module with invalid input rank:\")\n","x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)\n","try:\n"," model(x_invalid_rank_nnx)\n","except AssertionError as e:\n"," print(f\"Caught expected NNX error (invalid input rank):\\n{e}\\n\")\n","\n","print(\"Testing NNX Module with invalid input feature dimension:\")\n","x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)\n","try:\n"," model(x_invalid_feat_nnx)\n","except AssertionError as e:\n"," print(f\"Caught expected NNX error (invalid input features):\\n{e}\\n\")\n","\n","print(\"Testing NNX Module with invalid input type:\")\n","x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)\n","try:\n"," model(x_invalid_type_nnx)\n","except AssertionError as e:\n"," print(f\"Caught expected NNX error (invalid input type):\\n{e}\\n\")"],"metadata":{"id":"eu3llFxAWkLG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Exercise 3.1 Solution"],"metadata":{"id":"XrxRhSFAPWti"}},{"cell_type":"code","source":["class SimpleMLP(nnx.Module):\n"," def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):\n"," self.linear1 = nnx.Linear(din, dmid, rngs=rngs)\n"," self.linear2 = nnx.Linear(dmid, dout, rngs=rngs)\n","\n"," def __call__(self, x: chex.Array) -> chex.Array:\n"," # TODO: Validate input x\n"," # - Must be 2D ([batch, features])\n"," chex.assert_rank(x, 2)\n"," # - Feature dimension (axis 1) must match self.linear1.in_features\n"," chex.assert_axis_dimension(x, 1, self.linear1.in_features)\n"," # - Type must be jnp.float32\n"," chex.assert_type(x, jnp.float32)\n","\n"," # Forward pass\n"," x = self.linear1(x)\n"," x = nnx.relu(x)\n"," x = self.linear2(x)\n","\n"," # TODO: Validate output x before returning\n"," # - Must be 2D\n"," chex.assert_rank(x, 2)\n"," # - Feature dimension (axis 1) must match self.linear2.out_features\n"," chex.assert_axis_dimension(x, 1, self.linear2.out_features)\n","\n"," return x\n","\n","# Test cases for SimpleMLP\n","key_nnx = nnx.Rngs(params=jax.random.key(0)) # NNX Rngs for stateful operations\n","din, dmid, dout = 10, 20, 5\n","batch_size_nnx = 4\n","\n","model = SimpleMLP(din, dmid, dout, rngs=key_nnx)\n","\n","print(\"Testing NNX Module with valid input:\")\n","x_valid_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.float32)\n","output_nnx = model(x_valid_nnx)\n","print(f\"NNX I/O Check passed. Output shape: {output_nnx.shape}\\n\")\n","\n","\n","print(\"Testing NNX Module with invalid input rank:\")\n","x_invalid_rank_nnx = jnp.ones((batch_size_nnx, din, 1), dtype=jnp.float32)\n","try:\n"," model(x_invalid_rank_nnx)\n","except AssertionError as e:\n"," print(f\"Caught expected NNX error (invalid input rank):\\n{e}\\n\")\n","\n","print(\"Testing NNX Module with invalid input feature dimension:\")\n","x_invalid_feat_nnx = jnp.ones((batch_size_nnx, din + 1), dtype=jnp.float32)\n","try:\n"," model(x_invalid_feat_nnx)\n","except AssertionError as e:\n"," print(f\"Caught expected NNX error (invalid input features):\\n{e}\\n\")\n","\n","print(\"Testing NNX Module with invalid input type:\")\n","x_invalid_type_nnx = jnp.ones((batch_size_nnx, din), dtype=jnp.int32)\n","try:\n"," model(x_invalid_type_nnx)\n","except AssertionError as e:\n"," print(f\"Caught expected NNX error (invalid input type):\\n{e}\\n\")"],"metadata":{"id":"uKHHvhcuPcH2"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["**Self-reflection:**\n","\n","How would these assertions help you catch bugs early when composing multiple layers or changing model configurations? They act as contracts between layers and for the model's external API."],"metadata":{"id":"ZgmPRssPtHbB"}},{"cell_type":"markdown","source":["---\n","### 🏆 Congratulations!\n","You've completed the Chex exercises. You should now have a better understanding of:\n","- Using core Chex assertions for shapes, types, ranks, and PyTrees.\n","- How Chex assertions behave within `jax.jit` and `jax.vmap`.\n","- The purpose and usage of `@chex.chexify` (and its caveats).\n","- Detecting recompilation issues with `@chex.assert_max_traces`.\n","- Integrating Chex assertions into Flax NNX modules for robust model development.\n","\n","Using Chex consistently can significantly improve the reliability and\n","maintainability of your JAX projects.\n","\n","**Further Exploration (Optional):**\n","- Explore using `chex.chexify` outside of a Colab environment.\n","- Explore other Chex assertions not covered here (e.g., `chex.assert_devices_available`).\n","- Look into Chex testing utilities like `@chex.variants` if you write comprehensive test suites.\n","- Consider when and where to add Chex assertions in a typical training loop."],"metadata":{"id":"2zIbByO9tlqf"}}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/6 - Debugging JAX and Flax NNX.ipynb b/docs/learning_jax/code-exercises/6 - Debugging JAX and Flax NNX.ipynb new file mode 100644 index 0000000..fe25123 --- /dev/null +++ b/docs/learning_jax/code-exercises/6 - Debugging JAX and Flax NNX.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1VX3x1EduykqtaT1xhZkoaHZxNn3CqBrJ","timestamp":1755113930444}],"authorship_tag":"ABX9TyPmBiaJXUTc1jPQEsiFdPg8"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Colab Notebook: Debugging JAX & Flax NNX - Exercises\n","\n","Welcome! This notebook contains exercises to help you practice the JAX and Flax NNX debugging techniques discussed in the lecture. If you're a PyTorch user you'll find some concepts familiar, while others are specific to JAX's compiled nature. Remember to run the setup cells first!\n","\n","First, let's install the necessary libraries and import them."],"metadata":{"id":"yJcIGTl1Wqb8"}},{"cell_type":"code","source":["# Start by updating the protobuf version, which may require a restart\n","\n","!pip install -U protobuf"],"metadata":{"id":"B490gzJL8Xlp"},"execution_count":null,"outputs":[]},{"cell_type":"code","execution_count":null,"metadata":{"id":"1gdDkIB1WjSg"},"outputs":[],"source":["!pip install -Uq flax jax jaxlib chex optax"]},{"cell_type":"code","source":["import jax\n","import jax.numpy as jnp\n","from jax import jit, grad, vmap\n","import flax\n","from flax import nnx\n","import chex\n","import pdb # Python's built-in debugger\n","import functools # For functools.partial\n","import optax # For optimizers, though we won't train deeply\n","\n","chex.set_n_cpu_devices(8) # Fake an environment with 8 CPUs. This must be done before any JAX operations\n","print(f\"Fake devices: {jax.devices()}\")\n","\n","# NOTE for Flax v0.11+: The flax.nnx.Optimizer API has changed.\n","# It now requires a `wrt` argument at construction (e.g., wrt=nnx.Param)\n","# and the update call is now `optimizer.update(model, grads)` instead of `optimizer.update(grads)`.\n","\n","# Helper to clear chex trace counter for repeatable examples\n","chex.clear_trace_counter()\n","\n","print(f\"JAX version: {jax.__version__}\")\n","print(f\"Flax version: {flax.__version__}\") # NNX is part of flax\n","print(f\"Chex version: {chex.__version__}\")\n","print(f\"Device: {jax.devices()}\")"],"metadata":{"id":"ECwCGI6VXD7e"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 1. \"printf Debugging\" in JAX: jax.debug.print()\n","\n","JAX's JIT compilation means standard Python print() behaves differently inside JITted functions. It sees tracers during compilation, not runtime values. jax.debug.print() is the JAX-aware alternative."],"metadata":{"id":"ShqGVC2hXkYx"}},{"cell_type":"markdown","source":["### Exercise 1.1:\n","1. Uncomment and complete the line # YOUR CODE HERE in the compute_and_print function above.\n","2. Add a jax.debug.print() statement to display the runtime value of z.\n","3. Run the cell. Observe the outputs.\n"," - What does the standard print(y) show?\n"," - What do the jax.debug.print statements show for y and z? Why is this different?"],"metadata":{"id":"whTw53ilYNpJ"}},{"cell_type":"code","source":["@jit\n","def compute_and_print(x):\n"," y = x * 10\n"," print(\"Standard print (sees tracer):\", y)\n"," jax.debug.print(\"jax.debug.print (sees runtime value for y): {y_val}\", y_val=y, ordered=True)\n","\n"," z = y / 2\n"," # Exercise 1.1: Add another jax.debug.print here to see the runtime value of 'z'\n"," # Make sure to give it a descriptive message and use the ordered=True argument.\n"," # YOUR CODE HERE\n","\n"," return z\n","\n","input_val = jnp.array(5.0)\n","print(f\"Input value: {input_val}\\n\")\n","output_val = compute_and_print(input_val)\n","print(f\"\\nFinal output: {output_val}\")"],"metadata":{"id":"UAL9lJ_-XN99"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 1.1, after attempting):"],"metadata":{"id":"NLK-ncceYlhk"}},{"cell_type":"code","source":["# @jit\n","# def compute_and_print_solution(x):\n","# y = x * 10\n","# print(\"Standard print (sees tracer):\", y)\n","# jax.debug.print(\"jax.debug.print (sees runtime value for y): {y_val}\", y_val=y, ordered=True)\n","\n","# z = y / 2\n","# jax.debug.print(\"jax.debug.print (sees runtime value for z): {z_val}\", z_val=z, ordered=True) # SOLUTION\n","\n","# return z\n","\n","# input_val = jnp.array(5.0)\n","# print(f\"Input value: {input_val}\\n\")\n","# output_val = compute_and_print_solution(input_val)\n","# print(f\"\\nFinal output: {output_val}\")"],"metadata":{"id":"qY5aNK0uX7p7"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Standard print shows a tracer object (e.g., Tracedwith). This is because it executes during JAX's tracing phase. jax.debug.print shows the concrete numerical values (e.g., 50.0 for y, 25.0 for z) because it's embedded into the compiled computation graph and executes with runtime data."],"metadata":{"id":"kiK87sd1Y_8H"}},{"cell_type":"markdown","source":["## 2. Interactive Debugging in JIT: jax.debug.breakpoint()\n","jax.debug.breakpoint() is JAX's equivalent of pdb.set_trace() for use inside transformed functions. It pauses execution and gives you a (jaxdb) prompt."],"metadata":{"id":"zM0niKH1cbE0"}},{"cell_type":"markdown","source":["### Exercise 2.1:\n","1. Uncomment and complete the line # YOUR CODE HERE in the interact_with_values function above.\n","2. Add jax.debug.breakpoint() where indicated.\n","3. Run the cell.\n","4. When execution pauses at the (jaxdb) prompt:\n"," - Inspect the value of y by typing p y and pressing Enter.\n"," - Continue execution by typing c and pressing Enter.\n","5. Note that jaxdb has a subset of pdb commands (e.g., stepping n or s is not available)."],"metadata":{"id":"tqNY12kxc3Rb"}},{"cell_type":"code","source":["@jit\n","def interact_with_values(x):\n"," y = jnp.sin(x)\n"," jax.debug.print(\"Value of y before breakpoint: {y_val}\", y_val=y)\n","\n"," # Exercise 2.1: Place the breakpoint here.\n"," # YOUR CODE HERE\n","\n"," z = jnp.cos(y)\n"," jax.debug.print(\"Value of z after breakpoint: {z_val}\", z_val=z)\n"," return z\n","\n","input_angle = jnp.array(0.75)\n","print(\"Calling interact_with_values...\")\n","result = interact_with_values(input_angle)\n","print(f\"Result: {result}\")"],"metadata":{"id":"LerEFq1yYwPd"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 2.1, after attempting):"],"metadata":{"id":"TW-HX9bGd1iO"}},{"cell_type":"code","source":["# @jit\n","# def interact_with_values_solution(x):\n","# y = jnp.sin(x)\n","# jax.debug.print(\"Value of y before breakpoint: {y_val}\", y_val=y)\n","\n","# jax.debug.breakpoint() # SOLUTION\n","\n","# z = jnp.cos(y)\n","# jax.debug.print(\"Value of z after breakpoint: {z_val}\", z_val=z)\n","# return z\n","\n","# input_angle = jnp.array(0.75)\n","# print(\"Calling interact_with_values...\")\n","# result = interact_with_values_solution(input_angle)\n","# print(f\"Result: {result}\")"],"metadata":{"id":"LLKAPQAWciOs"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 3. Back to Basics: Temporarily Disabling JIT with jax.disable_jit()\n","Sometimes, you need the full power of standard Python debugging tools. jax.disable_jit() allows JAX functions to execute eagerly.\n","\n","### Exercise 3.1 & 3.2:\n","1. In `complex_calculation`, add pdb.set_trace() where indicated (# YOUR CODE HERE for 3.1).\n","2. First, try running the cell as is (with Scenario 1 uncommented and Scenario 2's call commented out). Observe what happens with pdb.set_trace() inside a JITted function.\n","3. Then, comment out Scenario 1.\n","4. In Scenario 2, within the with jax.disable_jit(): block, call `complex_calculation` (where # YOUR CODE HERE for 3.2 is) with value (try 0.1 first, then 5.0 to ensure the conditional is met) and threshold=0.5.\n","5. When pdb triggers:\n"," - Inspect a, b, and c.\n"," - Type c to continue.\n","6. Reflect: When would you use jax.disable_jit() over jax.debug.breakpoint()?"],"metadata":{"id":"HC5gk0aF4Fye"}},{"cell_type":"code","source":["@jit\n","def complex_calculation(x, threshold):\n"," a = x * 2.0\n"," b = jnp.log(a)\n"," c = b + x\n"," # Imagine 'c' sometimes becomes NaN, and it's hard to see why.\n"," # We want to inspect 'a', 'b', and 'c' using standard pdb.\n"," if c > threshold: # This condition might be tricky under JIT\n"," # Exercise 3.1: Add a pdb.set_trace() here.\n"," # It will only work if JIT is disabled for this function call.\n"," # YOUR CODE HERE\n"," print(\"Inside conditional pdb trace\") # This will print if pdb is hit\n"," d = jnp.sqrt(jnp.abs(c)) # abs to avoid NaNs from sqrt of negative\n"," return d\n","\n","value = jnp.array(0.1) # Try with 0.1 then with 5.0\n","\n","# Scenario 1: JIT enabled (pdb.set_trace() will be skipped or might error)\n","# print(\"--- Running WITH JIT (pdb will likely be skipped) ---\")\n","# try:\n","# result_jit = complex_calculation(value, threshold=0.5)\n","# print(f\"Result with JIT: {result_jit}\")\n","# except Exception as e:\n","# print(f\"Scenario 1 Error:\\n{e}\\n\")\n","\n","# Scenario 2: JIT disabled\n","print(\"\\n--- Running with JIT DISABLED for this block ---\")\n","with jax.disable_jit():\n"," # Exercise 3.2: Call complex_calculation here with value and threshold=0.5\n"," # so that your pdb.set_trace() (from Ex 3.1) gets triggered.\n"," # YOUR CODE HERE\n"," pass # remove this pass\n","\n","print(\"Finished disable_jit block.\")"],"metadata":{"id":"zBF2wvwe4Psb"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 3.1 & 3.2, after attempting):"],"metadata":{"id":"ta9o2YnI5CbE"}},{"cell_type":"code","source":["# @jit\n","# def complex_calculation_solution(x, threshold):\n","# a = x * 2.0\n","# b = jnp.log(a)\n","# c = b + x\n","# if c > threshold:\n","# pdb.set_trace() # SOLUTION 3.1\n","# print(\"Inside conditional pdb trace\")\n","# d = jnp.sqrt(jnp.abs(c))\n","# return d\n","\n","# value_for_pdb = jnp.array(5.0) # This value will trigger the condition c > threshold\n","\n","# # Scenario 1: JIT enabled (pdb.set_trace() will be skipped or might error)\n","# print(\"--- Running WITH JIT (pdb will likely be skipped) ---\")\n","# try:\n","# result_jit = complex_calculation_solution(value_for_pdb, threshold=0.5)\n","# print(f\"Result with JIT: {result_jit}\")\n","# except Exception as e:\n","# print(f\"Scenario 1 Error:\\n{e}\\n\")\n","\n","# print(\"\\n--- Running with JIT DISABLED for this block ---\")\n","# with jax.disable_jit():\n","# result_no_jit = complex_calculation_solution(value_for_pdb, threshold=0.5) # SOLUTION 3.2\n","# print(f\"Result with JIT disabled: {result_no_jit}\")\n","# print(\"Finished disable_jit block.\")"],"metadata":{"id":"wYGhziLN43D1"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["You'd use jax.disable_jit() when jax.debug.breakpoint() is insufficient, e.g., when you need the full pdb features (like stepping), want to use an IDE debugger, or when jax.debug.breakpoint() itself doesn't give enough context. The trade-off is performance loss."],"metadata":{"id":"Xc49FTqw61A1"}},{"cell_type":"markdown","source":["## 4. Automatic NaN Hunting: jax_debug_nans Flag\n","NaNs can be a nightmare. jax_debug_nans helps JAX pinpoint the exact operation causing them.\n","\n","### Exercise 4.1 & 4.2:\n","1. In Scenario 1, uncomment the example call or create your own call to problematic_function_for_nans that results in a NaN (e.g., x = jnp.array(-1.0), divisor = jnp.array(1.0) or x = jnp.array(1.0), divisor = jnp.array(0.0)). Run and observe the error.\n","2. In Scenario 2:\n"," - Uncomment the line jax.config.update(\"jax_debug_nans\", True).\n"," - Uncomment the example call or use the same NaN-causing inputs as in 4.1.\n"," - Run and observe the error message. Is it more helpful in pinpointing the source of the NaN?\n"," - Make sure the finally block runs to disable the flag.\n","3. Why is jax_debug_nans not enabled by default?"],"metadata":{"id":"Sr75lY78658N"}},{"cell_type":"code","source":["@jit\n","def problematic_function_for_nans(x, divisor):\n"," y = x * 100\n"," # This operation can cause NaN if divisor is 0 or x is negative and we take log\n"," z = jnp.log(y) / divisor # Potential NaN source\n"," return z + y\n","\n","# Scenario 1: Run without jax_debug_nans\n","print(\"--- Scenario 1: Running without jax_debug_nans ---\")\n","try:\n"," # Exercise 4.1: Call problematic_function_for_nans with inputs that cause a NaN\n"," # For example, x = jnp.array(-1.0), divisor = jnp.array(1.0)\n"," # OR x = jnp.array(1.0), divisor = jnp.array(0.0)\n"," # Observe the error. Is it specific?\n"," # YOUR CODE HERE\n"," # result1 = problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0))\n"," # print(f\"Result 1: {result1}\")\n"," pass # remove this\n","except Exception as e:\n"," print(f\"Caught exception (without jax_debug_nans): {e}\\n\")\n","\n","\n","# Scenario 2: Run WITH jax_debug_nans\n","print(\"--- Scenario 2: Running WITH jax_debug_nans ---\")\n","# jax.config.update(\"jax_debug_nans\", True) # Enable NaN debugging\n","\n","try:\n"," # Exercise 4.2: Call problematic_function_for_nans again with the SAME NaN-causing inputs.\n"," # Observe the error now. Is it more specific?\n"," # YOUR CODE HERE\n"," # result2 = problematic_function_for_nans(jnp.array(-1.0), jnp.array(1.0))\n"," # print(f\"Result 2: {result2}\")\n"," pass # remove this\n","except Exception as e:\n"," print(f\"Caught exception (WITH jax_debug_nans): {e}\\n\")\n","finally:\n"," # jax.config.update(\"jax_debug_nans\", False) # Disable after use\n"," print(\"jax_debug_nans has been disabled.\")"],"metadata":{"id":"fUEN-nY07RrM"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 4.1 & 4.2, after attempting):"],"metadata":{"id":"pvrr7nys7Y_5"}},{"cell_type":"code","source":["# @jit\n","# def problematic_function_for_nans_solution(x, divisor):\n","# y = x * 100\n","# z = jnp.log(y) / divisor\n","# return z + y\n","\n","# # Scenario 1: Run without jax_debug_nans\n","# print(\"--- Scenario 1: Running without jax_debug_nans ---\")\n","# try:\n","# # For example, x = jnp.array(-1.0), divisor = jnp.array(1.0)\n","# # OR x = jnp.array(1.0), divisor = jnp.array(0.0)\n","# result1 = problematic_function_for_nans_solution(jnp.array(-1.0), jnp.array(1.0)) # SOLUTION for 4.1\n","# print(f\"Result 1: {result1}\")\n","# except Exception as e:\n","# print(f\"Caught exception (without jax_debug_nans): {e}\\n\")\n","\n","\n","# # Scenario 2: Run WITH jax_debug_nans\n","# print(\"--- Scenario 2: Running WITH jax_debug_nans ---\")\n","# jax.config.update(\"jax_debug_nans\", True) # Enable NaN debugging\n","\n","# try:\n","# result2 = problematic_function_for_nans_solution(jnp.array(-1.0), jnp.array(1.0)) # SOLUTION for 4.2\n","# print(f\"Result 2: {result2}\")\n","# except Exception as e:\n","# print(f\"Caught exception (WITH jax_debug_nans): {e}\\n\")\n","# finally:\n","# jax.config.update(\"jax_debug_nans\", False) # Disable after use\n","# print(\"jax_debug_nans has been disabled.\")"],"metadata":{"id":"BykyxEHK7cj_"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Without jax_debug_nans, the error might be a generic NaN detection or occur later in the computation. With jax_debug_nans enabled, JAX re-runs the failing operations in eager mode and raises an error at the exact primitive operation that produced the NaN, making it much easier to find. It's not enabled by default because it adds overhead (checks and potential eager re-runs), significantly slowing down execution."],"metadata":{"id":"XcigqZTi7oZM"}},{"cell_type":"markdown","source":["## 5. Inspecting Flax NNX Models: nnx.display()\n","\n","**A Note on NNX Modules in Flax v0.11+:** In this version, `nnx.Module` and other NNX objects are now registered as JAX Pytrees. This means JAX transformations like `jax.jit` and `jax.vmap` can be used on them directly. However, if you use functions like `jax.tree.map` on a data structure containing NNX modules, they will be traversed by default. To treat them as leaves (the old behavior), you must use the `is_leaf` argument: `is_leaf=lambda x: isinstance(x, nnx.Pytree)`.\n","\n","`nnx.display()` provides a clear view of your NNX Module's structure, parameters, and state.\n","\n","### Exercise 5.1 - 5.4:\n","1. 5.1: In `SimpleNNXModel.__init__`, add a second `nnx.Linear` layer named `self.dense2` that maps from `dhidden` to `dout` features. Remember to provide `rngs`.\n","2. 5.2: In `SimpleNNXModel.__call__`, pass the intermediate `x` through `self.dense2` (if you added it).\n","3. 5.3: When instantiating `SimpleNNXModel`, ensure `din`, `dhidden`, and `dout` match your intended architecture (e.g., `dout=5` if your `dense2` outputs 5 features).\n","4. 5.4: Use `nnx.display(model)` to print the structure. Examine the output. Can you see both dense layers and their parameters (`kernel`, `bias`)? Can you see the `PReLU` parameters?"],"metadata":{"id":"h3xGIJ5V7wjA"}},{"cell_type":"code","source":["class SimpleNNXModel(nnx.Module):\n"," def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):\n"," key = rngs.params()\n"," self.dense1 = nnx.Linear(din, dhidden, rngs=rngs)\n"," # Exercise 5.1: Add another Linear layer called 'dense2' (dhidden -> dout)\n"," # YOUR CODE HERE\n"," self.activation = nnx.relu # Example of a layer with its own parameters\n","\n"," def __call__(self, x):\n"," x = self.dense1(x)\n"," x = nnx.relu(x)\n"," # Exercise 5.2: Pass x through 'dense2' if you added it.\n"," # YOUR CODE HERE\n"," x = self.activation(x)\n"," return x\n","\n","# Initialize RNGs for parameters\n","key = jax.random.key(0)\n","model_rngs = nnx.Rngs(params=key)\n","\n","# Instantiate the model\n","# Exercise 5.3: Update din, dhidden, dout if you changed the model structure\n","model = SimpleNNXModel(din=10, dhidden=20, dout=5, rngs=model_rngs)\n","\n","# Display the model structure\n","print(\"--- Model Structure using nnx.display() ---\")\n","# Exercise 5.4: Use nnx.display() to show the model's structure\n","# YOUR CODE HERE\n","\n","# If you have treescope installed and are in a compatible environment (like Colab default),\n","# nnx.display() will give an interactive tree. Otherwise, it falls back to print."],"metadata":{"id":"L_dyIQsX7gj2"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 5.1-5.4, after attempting):"],"metadata":{"id":"BOw60Gqz8LYD"}},{"cell_type":"code","source":["# class SimpleNNXModelSolution(nnx.Module):\n","# def __init__(self, din: int, dhidden: int, dout: int, *, rngs: nnx.Rngs):\n","# self.dense1 = nnx.Linear(din, dhidden, rngs=rngs)\n","# self.dense2 = nnx.Linear(dhidden, dout, rngs=rngs) # SOLUTION 5.1\n","# self.activation = nnx.relu\n","\n","# def __call__(self, x):\n","# x = self.dense1(x)\n","# x = nnx.relu(x)\n","# x = self.dense2(x) # SOLUTION 5.2\n","# x = self.activation(x)\n","# return x\n","\n","# # Initialize RNGs for parameters\n","# key = jax.random.key(0)\n","# model_rngs = nnx.Rngs(params=key)\n","\n","# # Instantiate the model\n","# model_solution = SimpleNNXModelSolution(din=10, dhidden=20, dout=5, rngs=model_rngs) # SOLUTION 5.3 (dout adjusted)\n","\n","# # Display the model structure\n","# print(\"--- Model Structure using nnx.display() ---\")\n","# nnx.display(model_solution) # SOLUTION 5.4"],"metadata":{"id":"HYV14gzU8Oia"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 6. Capturing Intermediate Values: nnx.sow()\n","Module.sow() allows you to \"plant\" intermediate values during the forward pass for later retrieval.\n","\n","### Exercise 6.1 & 6.2:\n","1. 6.1: In ModelWithSow.__call__, after x1_act is computed, use self.sow(nnx.Intermediate, 'activation_layer1', x1_act) to store it.\n","2. 6.2: After running the model, retrieve the sown value. It will be an attribute on sow_model named activation_layer1. Access its .value and print its shape.\n","3. What would happen if you called sow multiple times with the same name within one forward pass (e.g., inside a loop)?"],"metadata":{"id":"-oLfTi_R9h8x"}},{"cell_type":"code","source":["class ModelWithSow(nnx.Module):\n"," def __init__(self, *, rngs: nnx.Rngs):\n"," self.dense1 = nnx.Linear(5, 10, rngs=rngs)\n"," self.dense2 = nnx.Linear(10, 3, rngs=rngs)\n","\n"," def __call__(self, x):\n"," x1_act = self.dense1(x)\n"," x1_act = nnx.relu(x1_act)\n","\n"," # Exercise 6.1: Use self.sow() to store the value of x1_act.\n"," # Use nnx.Intermediate as the variable_type and 'activation_layer1' as the name.\n"," # YOUR CODE HERE\n","\n"," x2_out = self.dense2(x1_act)\n"," return x2_out\n","\n","# Setup\n","key = jax.random.key(1)\n","model_sow_rngs = nnx.Rngs(params=key)\n","sow_model = ModelWithSow(rngs=model_sow_rngs)\n","dummy_input = jnp.ones((1, 5))\n","\n","# Run the model\n","output = sow_model(dummy_input)\n","\n","# Retrieve the sown value\n","# Exercise 6.2: Retrieve the 'activation_layer1' value from the sow_model instance.\n","# Remember it's stored as an attribute, and the actual data is in its .value property.\n","# Print the shape of the retrieved value.\n","# YOUR CODE HERE\n","# retrieved_activation = ...\n","# print(f\"Shape of retrieved activation: {retrieved_activation.shape}\") # Adjust if it's a tuple"],"metadata":{"id":"hlqi9q0T8Rwf"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 6.1 & 6.2, after attempting):"],"metadata":{"id":"xcoeBO-n93rN"}},{"cell_type":"code","source":["# class ModelWithSowSolution(nnx.Module):\n","# def __init__(self, *, rngs: nnx.Rngs):\n","# self.dense1 = nnx.Linear(5, 10, rngs=rngs)\n","# self.dense2 = nnx.Linear(10, 3, rngs=rngs)\n","\n","# def __call__(self, x):\n","# x1_act = self.dense1(x)\n","# x1_act = nnx.relu(x1_act)\n","\n","# self.sow(nnx.Intermediate, 'activation_layer1', x1_act) # SOLUTION 6.1\n","\n","# x2_out = self.dense2(x1_act)\n","# return x2_out\n","\n","# # Setup\n","# key = jax.random.key(1)\n","# model_sow_rngs = nnx.Rngs(params=key)\n","# sow_model_solution = ModelWithSowSolution(rngs=model_sow_rngs)\n","# dummy_input = jnp.ones((1, 5))\n","\n","# # Run the model\n","# output = sow_model_solution(dummy_input)\n","\n","# # Retrieve the sown value\n","# retrieved_sown_value_obj = sow_model_solution.activation_layer1 # This is the Variable object\n","# retrieved_activation = retrieved_sown_value_obj.value # This is the actual data (often a tuple)\n","# print(f\"Retrieved activation (raw): {retrieved_activation}\")\n","# # By default, sow appends to a tuple. So value is likely ((1,10))\n","# print(f\"Shape of retrieved activation (first element): {retrieved_activation[0].shape}\") # SOLUTION 6.2"],"metadata":{"id":"atf7bn4g91cj"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["If sow is called multiple times with the same name in one forward pass, by default, it appends each new value to a tuple stored in the .value property of the sown attribute."],"metadata":{"id":"bKzh4Ibi-Dx-"}},{"cell_type":"markdown","source":["## 7. Robustness with Chex Assertions\n","Chex provides powerful assertions for JAX code.\n","\n","### Exercise 7.1.1:\n","1. Fill in the # YOUR CODE HERE section in process_image_data with the specified Chex static assertions.\n","2. Run the cell and observe how the assertions catch the errors for wrong_shape_data and wrong_type_data."],"metadata":{"id":"Hg5Z4qCR-Lsz"}},{"cell_type":"code","source":["@jit\n","def process_image_data(image_batch: chex.Array):\n"," # Exercise 7.1.1: Add Chex assertions to verify:\n"," # 1. image_batch has a rank of 4 (e.g., Batch, Height, Width, Channels).\n"," # 2. image_batch has a dtype of jnp.float32.\n"," # 3. image_batch has a specific shape, e.g., (32, 224, 224, 3).\n"," # You can use a placeholder for batch_size if needed: chex.assert_shape(image_batch, (None, 224, 224, 3))\n"," # YOUR CODE HERE\n","\n"," # Dummy computation\n"," processed = image_batch * 2.0 - 1.0\n"," return processed\n","\n","# Test cases\n","correct_data = jnp.ones((32, 224, 224, 3), dtype=jnp.float32)\n","wrong_shape_data = jnp.ones((32, 224, 3), dtype=jnp.float32) # Missing a dim\n","wrong_type_data = jnp.ones((32, 224, 224, 3), dtype=jnp.int32)\n","\n","print(\"--- Testing with correct data ---\")\n","try:\n"," _ = process_image_data(correct_data)\n"," print(\"Correct data processed successfully!\")\n","except Exception as e:\n"," print(f\"Error with correct data:\\n{e}\")\n","\n","print(\"\\n--- Testing with wrong shape data ---\")\n","try:\n"," _ = process_image_data(wrong_shape_data)\n"," print(\"Wrong shape data processed successfully (this shouldn't happen if assertions are correct).\")\n","except AssertionError as e:\n"," print(f\"Caught expected AssertionError for wrong shape:\\n{e}\")\n","\n","print(\"\\n--- Testing with wrong type data ---\")\n","try:\n"," _ = process_image_data(wrong_type_data)\n"," print(\"Wrong type data processed successfully (this shouldn't happen if assertions are correct).\")\n","except AssertionError as e:\n"," print(f\"Caught expected AssertionError for wrong type:\\n{e}\")"],"metadata":{"id":"f7z_sqG09-lP"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 7.1.1, after attempting):"],"metadata":{"id":"27e3MpqF-ksp"}},{"cell_type":"code","source":["# @jit\n","# def process_image_data_solution(image_batch: chex.Array):\n","# chex.assert_rank(image_batch, 4) # SOLUTION\n","# chex.assert_type(image_batch, jnp.float32) # SOLUTION\n","# chex.assert_shape(image_batch, (None, 224, 224, 3)) # SOLUTION (using None for batch)\n","\n","# processed = image_batch * 2.0 - 1.0\n","# return processed\n","\n","# # Test cases (same as above)\n","# correct_data_sol = jnp.ones((32, 224, 224, 3), dtype=jnp.float32)\n","# wrong_shape_data_sol = jnp.ones((32, 224, 3), dtype=jnp.float32)\n","# wrong_type_data_sol = jnp.ones((32, 224, 224, 3), dtype=jnp.int32)\n","\n","# print(\"--- SOLUTION: Testing with correct data ---\")\n","# try:\n","# _ = process_image_data_solution(correct_data_sol)\n","# print(\"Correct data processed successfully!\")\n","# except Exception as e:\n","# print(f\"Error with correct data:\\n{e}\")\n","\n","# print(\"\\n--- SOLUTION: Testing with wrong shape data ---\")\n","# try:\n","# _ = process_image_data_solution(wrong_shape_data_sol)\n","# except AssertionError as e:\n","# print(f\"Caught expected AssertionError for wrong shape:\\n{e}\")\n","\n","# print(\"\\n--- SOLUTION: Testing with wrong type data ---\")\n","# try:\n","# _ = process_image_data_solution(wrong_type_data_sol)\n","# except AssertionError as e:\n","# print(f\"Caught expected AssertionError for wrong type:\\n{e}\")"],"metadata":{"id":"4LbQjfxr-n_Z"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 7.2. Performance Debugging: @chex.assert_max_traces()\n","Unintended JIT recompilations kill performance. @chex.assert_max_traces(n=N) helps detect this.\n","\n","### Exercise 7.2.1 & 7.2.2:\n","1. 7.2.1: In process_dynamic_shape, add the @chex.assert_max_traces(n=1) decorator.\n","2. 7.2.2: Uncomment and complete the second call to process_dynamic_shape using an input array with a different shape (e.g., jnp.ones((3,3))).\n","3. Run the cell.\n"," - Observe that Scenario 1 (with static_argnums) passes because the shape information critical for compilation (shape_tuple) is static and doesn't change.\n"," - Observe that Scenario 2 should raise an AssertionError. Why does this happen?"],"metadata":{"id":"Fjpi1eVTAHHP"}},{"cell_type":"code","source":["chex.clear_trace_counter() # Reset counter for this specific example\n","\n","# Scenario 1: Function with static argument for shape\n","@functools.partial(jit, static_argnums=(1,)) # shape_tuple is static\n","@chex.assert_max_traces(n=1)\n","def process_fixed_shape_staticarg(x: chex.Array, shape_tuple: tuple):\n"," chex.assert_shape(x, shape_tuple) # Check the shape matches\n"," return x * 2.0\n","\n","print(\"--- Scenario 1: Static argnum, consistent shape tuple ---\")\n","fixed_shape = (3, 4)\n","input_data_s1_c1 = jnp.ones(fixed_shape)\n","input_data_s1_c2 = jnp.zeros(fixed_shape) # Same shape, different values\n","_ = process_fixed_shape_staticarg(input_data_s1_c1, fixed_shape) # First call, traces\n","print(\"First call to process_fixed_shape_staticarg successful (traces).\")\n","_ = process_fixed_shape_staticarg(input_data_s1_c2, fixed_shape) # Second call, reuses cache\n","print(\"Second call to process_fixed_shape_staticarg successful (reuses cache).\")\n","\n","\n","# Scenario 2: Function where input shape might vary, leading to retracing if not handled\n","chex.clear_trace_counter() # Reset for this scenario\n","\n","@jit\n","# Exercise 7.3.1: Add @chex.assert_max_traces(n=1) here\n","# YOUR CODE HERE\n","def process_dynamic_shape(x: chex.Array):\n"," # This function will be re-traced if 'x' shape changes between calls\n"," return x + jnp.sum(x) # Example op\n","\n","print(\"\\n--- Scenario 2: Varying input shapes ---\")\n","try:\n"," print(\"Calling process_dynamic_shape with (2, 2)...\")\n"," _ = process_dynamic_shape(jnp.ones((2, 2))) # First call, traces\n"," print(\"First call to process_dynamic_shape successful.\")\n","\n"," # Exercise 7.3.2: Call process_dynamic_shape with a DIFFERENT shape, e.g., (3,3).\n"," # This should trigger an AssertionError if assert_max_traces is working.\n"," print(\"Calling process_dynamic_shape with (3, 3)...\")\n"," # YOUR CODE HERE\n"," # _ = process_dynamic_shape(jnp.ones((3, 3)))\n"," print(\"Second call to process_dynamic_shape successful (UNEXPECTED if shapes differ and max_traces=1).\")\n","\n","except AssertionError as e:\n"," print(f\"\\nCaught EXPECTED AssertionError for too many traces:\\n{e}\")\n","except Exception as e:\n"," print(f\"Caught unexpected error: {e}\")"],"metadata":{"id":"6rp79x5XBW4E"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 7.2.1 & 7.2.2, after attempting):"],"metadata":{"id":"YY8xI2LMBcUN"}},{"cell_type":"code","source":["# chex.clear_trace_counter() # Reset counter for this specific example\n","\n","# # Scenario 1: Function with static argument for shape\n","# @functools.partial(jit, static_argnums=(1,))\n","# @chex.assert_max_traces(n=1)\n","# def process_fixed_shape_staticarg_sol(x: chex.Array, shape_tuple: tuple):\n","# chex.assert_shape(x, shape_tuple)\n","# return x * 2.0\n","\n","# print(\"--- SOLUTION: Scenario 1: Static argnum, consistent shape tuple ---\")\n","# fixed_shape_sol = (3, 4)\n","# input_data_s1_c1_sol = jnp.ones(fixed_shape_sol)\n","# input_data_s1_c2_sol = jnp.zeros(fixed_shape_sol)\n","# _ = process_fixed_shape_staticarg_sol(input_data_s1_c1_sol, fixed_shape_sol)\n","# print(\"First call to process_fixed_shape_staticarg_sol successful (traces).\")\n","# _ = process_fixed_shape_staticarg_sol(input_data_s1_c2_sol, fixed_shape_sol)\n","# print(\"Second call to process_fixed_shape_staticarg_sol successful (reuses cache).\")\n","\n","\n","# chex.clear_trace_counter() # Reset for this scenario\n","\n","# @jit\n","# @chex.assert_max_traces(n=1) # SOLUTION 7.3.1\n","# def process_dynamic_shape_sol(x: chex.Array):\n","# return x + jnp.sum(x)\n","\n","# print(\"\\n--- SOLUTION: Scenario 2: Varying input shapes ---\")\n","# try:\n","# print(\"Calling process_dynamic_shape_sol with (2, 2)...\")\n","# _ = process_dynamic_shape_sol(jnp.ones((2, 2)))\n","# print(\"First call to process_dynamic_shape_sol successful.\")\n","\n","# print(\"Calling process_dynamic_shape_sol with (3, 3)...\")\n","# _ = process_dynamic_shape_sol(jnp.ones((3, 3))) # SOLUTION 7.3.2\n","# print(\"Second call to process_dynamic_shape_sol successful (UNEXPECTED if shapes differ and max_traces=1).\")\n","\n","# except AssertionError as e:\n","# print(f\"\\nCaught EXPECTED AssertionError for too many traces:\\n{e}\")\n","# except Exception as e:\n","# print(f\"Caught unexpected error: {e}\")"],"metadata":{"id":"sT2sQNj6_950"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["In Scenario 2, the AssertionError happens because process_dynamic_shape is JIT-compiled based on the shape of its input x. When called the second time with a different shape, JAX needs to re-trace and re-compile the function for this new shape. @chex.assert_max_traces(n=1) detects this second trace and raises an error, alerting you to a potential performance issue due to recompilation."],"metadata":{"id":"P1vrg1D3AliP"}},{"cell_type":"markdown","source":["## 8. Monitoring with TensorBoard\n","TensorBoard is excellent for visualizing training metrics. The setup is similar to PyTorch.\n","\n","### Exercise 8.1 - 8.3:\n","1. 8.1: Create a tensorboardX.SummaryWriter instance, saving logs to LOG_DIR.\n","2. 8.2: Inside the loop, use writer.add_scalar() to log dummy_loss and dummy_accuracy. Crucially, convert them to Python scalars using .item().\n","3. 8.3: After the loop, close the writer using writer.close().\n","4. Run the cell.\n","5. If you are in Colab:\n"," - Uncomment the lines %load_ext tensorboard and %tensorboard --logdir {LOG_DIR} at the end of the cell.\n"," - Run the cell again. TensorBoard should appear in the output. Navigate to the SCALARS tab.\n","6. If running locally:\n"," - Open your terminal.\n"," - Navigate to the directory containing the logs folder (i.e., the parent of LOG_DIR).\n"," - Run tensorboard --logdir logs.\n"," - Open the URL (usually http://localhost:6006) in your browser.\n","7. Explore the TensorBoard and profiler (XProf) tools"],"metadata":{"id":"ui2JWgdKArS7"}},{"cell_type":"code","source":["# !pip install -Uq tensorboardX tensorboard tensorboard_plugin_profile\n","!pip install -Uq tensorboardX tensorboard_plugin_profile\n","!pip install -U protobuf"],"metadata":{"id":"52D9nw-20Zcd"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# For TensorBoard\n","from tensorboardX import SummaryWriter\n","import shutil # For cleaning up log directories"],"metadata":{"id":"z6S1ynM3094v"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Clean up previous logs if any\n","LOG_DIR = \"logs/jax_debug_run\"\n","if shutil.os.path.exists(LOG_DIR):\n"," shutil.rmtree(LOG_DIR)\n"," print(f\"Removed old log directory: {LOG_DIR}\")\n","\n","# Exercise 8.1: Create a SummaryWriter from tensorboardX\n","# Point it to the LOG_DIR defined above.\n","# YOUR CODE HERE\n","# writer = ...\n","\n","# Dummy training loop\n","print(\"\\nSimulating training loop...\")\n","jax.profiler.start_trace(LOG_DIR) # Capturing trace for xprof\n","\n","for epoch in range(10):\n"," # Simulate loss and accuracy (JAX arrays)\n"," dummy_loss = jnp.array(1.0 / (epoch + 1))\n"," dummy_accuracy = jnp.array(1.0 - dummy_loss)\n","\n"," # Exercise 8.2: Log dummy_loss as 'Loss/train' and dummy_accuracy as 'Accuracy/validation'\n"," # Remember to use .item() to convert JAX arrays to Python scalars before logging.\n"," # Use 'epoch' as the global_step.\n"," # YOUR CODE HERE\n","\n"," if (epoch + 1) % 2 == 0:\n"," print(f\"Epoch {epoch+1}: Loss = {dummy_loss.item():.4f}, Acc = {dummy_accuracy.item():.4f}\")\n","\n","jax.profiler.stop_trace()\n","# Exercise 8.3: Close the writer\n","# YOUR CODE HERE\n","\n","print(f\"\\nTensorBoard logs saved to: {LOG_DIR}\")\n","print(\"To view in TensorBoard, run the following in your terminal (if local):\")\n","print(f\"tensorboard --logdir={LOG_DIR.split('/')[0]}\") # Get base 'logs' dir\n","print(\"Or, if in Colab, you can use the %tensorboard magic:\")\n","# %load_ext tensorboard\n","# %tensorboard --logdir {LOG_DIR}"],"metadata":{"id":"RU_Yh-FuAcB6"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Solution (for Exercise 8.1-8.3, after attempting):"],"metadata":{"id":"CQAOsIstBqa1"}},{"cell_type":"code","source":["# # Clean up previous logs if any\n","# LOG_DIR_SOL = \"logs/jax_debug_run_solution\" # Use a different dir for solution\n","# if shutil.os.path.exists(LOG_DIR_SOL):\n","# shutil.rmtree(LOG_DIR_SOL)\n","# print(f\"Removed old log directory: {LOG_DIR_SOL}\")\n","\n","# writer = SummaryWriter(LOG_DIR_SOL) # SOLUTION 8.1\n","# print(f\"TensorBoard writer initialized. Logging to: {LOG_DIR_SOL}\")\n","\n","# # Dummy training loop\n","# print(\"\\nSimulating training loop...\")\n","# # Ensure the profiler plugin is included in the trace\n","# jax.profiler.start_trace(LOG_DIR_SOL, create_perfetto_link=False) # Capturing trace for xprof\n","\n","# for epoch in range(10):\n","# dummy_loss = jnp.array(1.0 / (epoch + 1))\n","# dummy_loss.block_until_ready() # Ensure the array is ready\n","# dummy_accuracy = jnp.array(1.0 - dummy_loss)\n","# dummy_accuracy.block_until_ready() # Ensure the array is ready\n","\n","# writer.add_scalar('Loss/train', dummy_loss.item(), global_step=epoch) # SOLUTION 8.2\n","# writer.add_scalar('Accuracy/validation', dummy_accuracy.item(), global_step=epoch) # SOLUTION 8.2\n","\n","# if (epoch + 1) % 2 == 0:\n","# print(f\"Epoch {epoch+1}: Loss = {dummy_loss.item():.4f}, Acc = {dummy_accuracy.item():.4f}\")\n","\n","# jax.profiler.stop_trace()\n","# writer.close() # SOLUTION 8.3\n","\n","# print(f\"\\nTensorBoard logs saved to: {LOG_DIR_SOL}\")\n","# print(\"To view in TensorBoard, run the following in your terminal (if local):\")\n","# print(f\"tensorboard --logdir={LOG_DIR_SOL.split('/')[0]}\")\n","# print(\"Or, if in Colab, you can use the %tensorboard magic:\")\n","# %load_ext tensorboard\n","# %tensorboard --logdir {LOG_DIR_SOL} # --general_plugin_dir \"{LOG_DIR_SOL}/plugins\""],"metadata":{"id":"wBgZvLJcBs57"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Profiling with XProf\n","\n","Profiling is also essential for understanding and improving your code. XProf is a great tool for profiling JAX and Flax NNX, and is compatible with TensorBoard. We've seen the XProf profiler with TensorBoard above, but let's look at a more interesting example. We'll download some profiling data from an MNIST model."],"metadata":{"id":"xFXlyJy8cghy"}},{"cell_type":"code","source":["# git clone the xprof repo so we have access to the demo data there\n","!git clone http://github.com/openxla/xprof\n","\n","# Launch TensorBoard and navigate to the Profile tab to view performance profile\n","%tensorboard --logdir=xprof/demo"],"metadata":{"id":"nrAHqdS2CJIl"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## 9. Visualizing Data Layout: jax.debug.visualize_array_sharding\n","\n","Understanding data sharding is crucial for multi-device training. `jax.debug.visualize_array_sharding` helps visualize this.\n","\n","Actually demonstrating this effectively requires a multi-device setup (e.g., multiple GPUs or TPUs and a Mesh). In a standard Colab CPU/single GPU environment, arrays won't be genuinely sharded across a mesh, but we can still see how the function works by faking a multi-device environment using `chex.set_n_cpu_devices`, which we did at the beginning of this Colab."],"metadata":{"id":"xvxjppTYCITm"}},{"cell_type":"markdown","source":["### Exercise 9.1:\n","1. Run the cell below.\n","2. Observe the output of `jax.debug.visualize_array_sharding`. Even on a single device, it will print information about the array's (lack of) sharding.\n","3. Think: If you had a Mesh of 4 devices arranged in a 2x2 grid (`Mesh(devices, ('dp', 'mp'))`) and an array arr of shape (8, 1024), how might you define a PartitionSpec to shard arr across data parallelism (dp) for the first dimension and model parallelism (mp) for the second? What would you expect `visualize_array_sharding(arr)` to show?"],"metadata":{"id":"7V5P8KTmGzgD"}},{"cell_type":"code","source":["from jax.sharding import Mesh, PartitionSpec, NamedSharding\n","from jax.experimental import mesh_utils\n","import jax.numpy as jnp\n","from jax import jit, grad, vmap\n","\n","try:\n"," if len(jax.devices()) >= 2:\n"," device_mesh = mesh_utils.create_device_mesh((len(jax.devices()),)) # Use all available devices\n"," mesh = Mesh(devices=device_mesh, axis_names=('data',))\n"," print(f\"Created a mesh with shape: {mesh.devices.shape} and names: {mesh.axis_names}\")\n"," else:\n"," print(\"Not enough devices to create a meaningful mesh for sharding demo. Will run on single device.\")\n"," mesh = None\n","except Exception as e:\n"," print(f\"Could not create mesh (likely on CPU Colab or single GPU): {e}\")\n"," mesh = None\n","\n","\n","@jit\n","def sharded_computation_demo(x_unsharded):\n"," # In a real scenario, x would be sharded before being passed or sharded inside\n"," # For this demo, we'll just visualize the unsharded array as if it were sharded\n","\n"," print(\"--- Input 'Sharding' (on single device, so not truly sharded) ---\")\n"," jax.debug.visualize_array_sharding(x_unsharded)\n","\n"," y = x_unsharded * 2.0\n","\n"," # If 'x_unsharded' had sharding, 'y' would typically inherit it or have a related one.\n"," print(\"--- Output 'Sharding' (on single device) ---\")\n"," jax.debug.visualize_array_sharding(y)\n"," return y\n","\n","an_array = jnp.arange(8.0)\n","\n","print(f\"Original array: {an_array}\")\n","\n","# If we had a mesh, we could try to shard it:\n","if mesh:\n"," # Shard along the first axis ('data')\n"," sharding_spec = NamedSharding(mesh, PartitionSpec('data',))\n"," an_array_sharded = jax.device_put(an_array, sharding_spec)\n"," print(f\"Array sharding: {an_array_sharded.sharding}\")\n"," output_sharded = sharded_computation_demo(an_array_sharded)\n","else:\n"," print(\"No mesh, running unsharded demo.\")\n"," output_unsharded = sharded_computation_demo(an_array) # Run with the original JIT\n","\n","# Simplified version for Colab (no actual sharding applied)\n","print(\"\\n--- Running visualization on a single device (no actual sharding) ---\")\n","output_unsharded = sharded_computation_demo(an_array)\n","print(f\"Output: {output_unsharded}\")"],"metadata":{"id":"wTT7usbDB10Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Answer (for Conceptual Exercise 9.1):\n","- You might define P = PartitionSpec('dp', 'mp').\n","- jax.debug.visualize_array_sharding(arr) would then print a diagram showing how the 8 rows are split over the 'dp' axis (e.g., 4 rows per device slice along 'dp') and the 1024 columns are split over the 'mp' axis (e.g., 512 columns per device slice along 'mp'). Each device in the 2x2 mesh would hold a (4, 512) slice of the original array."],"metadata":{"id":"JXGikTMmCmNb"}},{"cell_type":"markdown","source":["## Conclusion & Key Takeaways\n","You've now practiced with several key JAX and Flax NNX debugging tools!\n","- jax.debug.print() & jax.debug.breakpoint(): Your go-to tools for inspecting values inside JITted code.\n","- jax.disable_jit(): The \"escape hatch\" to use standard Python debuggers (pdb, IDEs) at the cost of performance.\n","- jax_debug_nans: Invaluable for automatically finding the source of NaNs.\n","- nnx.display(): Essential for understanding your NNX model's architecture and state.\n","- nnx.sow(): Useful for capturing intermediate activations without altering function signatures.\n","- Chex assertions (assert_shape, assert_tree_all_finite, assert_max_traces): Build robust and performant code by catching errors early and detecting recompilations.\n","- TensorBoard: Standard for monitoring training, works seamlessly with JAX.\n","Debugging in JAX's compiled world requires adapting your PyTorch habits, but with these tools, you're well-equipped to tackle issues effectively!\n","\n","Please send us feedback at https://goo.gle/jax-training-feedback"],"metadata":{"id":"7W9M6dYtC1qD"}}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/7 - Grain for data loading.ipynb b/docs/learning_jax/code-exercises/7 - Grain for data loading.ipynb new file mode 100644 index 0000000..a52f80c --- /dev/null +++ b/docs/learning_jax/code-exercises/7 - Grain for data loading.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1KeLLQiLOy14q9OIPkXjv1lNxpArdqqzf","timestamp":1755113967227},{"file_id":"11x7XXCgvJj33PxSP2289nrHnOdJwoU-c","timestamp":1750353821037}],"toc_visible":true,"authorship_tag":"ABX9TyM8EkHLzJGv8jZyn07u4GN1"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Efficient Data Loading with Grain: Exercises for JAX/Flax NNX\n","\n","Welcome! This Colab notebook contains exercises to help you learn Google's Grain library for efficient data loading in JAX. These exercises are designed for developers familiar with PyTorch who are now exploring the JAX ecosystem, including the new Flax NNX API.\n","\n","**Goals of this notebook:**\n","\n","* Understand the core components of Grain: DataSource, Sampler, and Operations.\n","* Learn how to use grain.DataLoader for both sequential and parallel data loading.\n","* Implement custom data transformations.\n","* Explore data sharding for distributed training scenarios.\n","* See how Grain integrates into a conceptual JAX/Flax NNX training loop.\n","* Learn about checkpointing data iterator state for reproducibility.\n","\n","**Simulated Multi-Device Environment:**\n","\n","To demonstrate parallelism and sharding concepts effectively in Colab (which typically provides a single CPU/GPU), this notebook starts by configuring JAX to simulate 8 CPU devices. This is achieved using XLA_FLAGS and chex.set_n_cpu_devices(8).\n","\n","**Let's get started! Please run the next cell to set up the environment.**"],"metadata":{"id":"QqFiEZqVxEA5"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"rMAKmTMJw79f"},"outputs":[],"source":["# Environment Setup\n","# This cell configures the environment to simulate multiple CPU devices\n","# and installs necessary libraries.\n","# IMPORTANT: RUN THIS CELL FIRST. If you encounter issues with JAX device\n","# counts later, try 'Runtime -> Restart runtime' in the Colab menu\n","# and run this cell again before any others.\n","import os\n","\n","# Configure JAX to see 8 virtual CPU devices.\n","# This must be done before JAX is imported for the first time in a session.\n","os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\"\n","\n","# Install libraries\n","# We use google-grain for the Grain library.\n","!pip install -Uq grain chex flax jax jaxlib numpy\n","print(\"Libraries installed.\")\n","\n","# Now, import chex and attempt to set_n_cpu_devices.\n","# This must be called after setting XLA_FLAGS and before JAX initializes its backends.\n","import chex\n","\n","try:\n"," chex.set_n_cpu_devices(8)\n"," print(\"chex.set_n_cpu_devices(8) called successfully.\")\n","except RuntimeError as e:\n"," print(f\"Note on chex.set_n_cpu_devices: {e}\")\n"," print(\"This usually means JAX was already initialized. The XLA_FLAGS environment variable should still apply.\")\n"," print(\"If you see issues with device counts, ensure you 'Restart runtime' and run this cell first.\")\n","\n","# Verify JAX device count\n","import jax\n","print(f\"JAX version: {jax.__version__}\")\n","print(f\"JAX found {jax.device_count()} devices.\")\n","print(f\"JAX devices: {jax.devices()}\")\n","\n","if jax.device_count() != 8:\n"," print(\"\\nWARNING: JAX does not see 8 devices. Parallelism/sharding exercises might not behave as expected.\")\n"," print(\"Please try 'Runtime -> Restart runtime' and run this setup cell again first.\")\n","\n","# Common imports for the exercises\n","import jax.numpy as jnp\n","import numpy as np\n","import grain.python as grain # Main Grain API\n","from flax import nnx # Flax NNX API\n","import time # For simulating work and observing performance\n","import copy # For checkpointing example\n","from typing import Dict, Any, List # For type hints\n","import dataclasses # For ShardOptions if needed manually\n","import functools # For functools.partial\n","print(\"Imports complete. Setup finished.\")"]},{"cell_type":"markdown","source":["## Introduction to Grain\n","\n","As highlighted in the lecture, JAX is incredibly fast for numerical computation, especially on accelerators. However, this speed can be bottlenecked by inefficient data loading. Standard Python data loading can struggle due to I/O limitations, CPU-bound transformations, and the Global Interpreter Lock (GIL).\n","\n","**Grain** is Google's solution for high-performance data loading in JAX. Its key goals are:\n","* **Speed:** Achieved through multiprocessing, shared memory, and prefetching.\n","* **Determinism:** Ensuring reproducibility in experiments.\n","* **Flexibility & Simplicity:** Declarative pipeline definition.\n","* **JAX Ecosystem Focus:** Integrates with concepts like distributed sharding.\n","\n","Conceptually, Grain's `DataLoader` is analogous to PyTorch's `torch.utils.data.DataLoader`. It orchestrates data reading, transformation, batching, and parallelization.\n","\n","**Core Components of `grain.DataLoader` API:**\n","1. **`DataSource`**: Provides access to individual raw data records (must implement `__len__` and `__getitem__`).\n","2. **`Sampler`**: Determines the order in which records are loaded and provides seeds for random operations, ensuring reproducibility.\n","3. **`Operations`**: A list of transformations (e.g., augmentation, filtering, batching) applied sequentially to the records.\n","\n","Let's dive into the exercises!"],"metadata":{"id":"XMV045L50R6_"}},{"cell_type":"markdown","source":["---\n","## Exercise 1: Building Your First `grain.DataLoader` (Sequential)\n","\n","**Goal:** Get familiar with the basic components: `DataSource`, `IndexSampler`, a simple `MapTransform`, and `grain.DataLoader` running in sequential mode (`worker_count=0`).\n","\n","**Instructions:**\n","1. Define `MySource`, a custom `RandomAccessDataSource`.\n"," * `__init__`: Store `num_records`.\n"," * `__len__`: Return `num_records`.\n"," * `__getitem__`: Given an `idx`, return a dictionary `{'image': image_array, 'label': label_int}`.\n"," * The `image_array` should be a NumPy array of shape `(32, 32, 3)` with `dtype=np.uint8`. Its values can depend on `idx` (e.g., `np.ones(...) * (idx % 255)`).\n"," * The `label_int` should be an integer (e.g., `idx % 10`).\n"," * Handle potential index wrap-around for multiple epochs: `idx = idx % self._num_records`.\n","2. Instantiate `MySource`.\n","3. Create an `IndexSampler` for shuffling, running for 1 epoch, with a fixed seed.\n","4. Define a list of `operations`:\n"," * A `ConvertToFloat` class inheriting from `grain.MapTransform` that converts the 'image' to `np.float32` and normalizes it to `[0, 1]`.\n"," * A `grain.Batch` operation to batch 64 items, dropping any remainder.\n","5. Instantiate `grain.DataLoader` with `worker_count=0` (for debugging/sequential mode).\n"," * Since `MySource` is in-memory, use `read_options=grain.ReadOptions(num_threads=0)` to disable Grain's internal read threads.\n","6. Iterate through the `DataLoader` to get the first batch and print its shape and the shape of its labels."],"metadata":{"id":"0-M_gW8d0a0u"}},{"cell_type":"code","source":["# @title Exercise 1: Student Code\n","# 1. Define MySource\n","class MySource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n"," def __len__(self) -> int:\n"," # TODO: Return the total number of records\n"," # YOUR CODE HERE\n"," return 0 # Replace this\n","\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," # TODO: Handle potential index wrap-around for multiple epochs\n"," # effective_idx = ...\n"," # YOUR CODE HERE\n"," effective_idx = idx # Replace this\n","\n"," # TODO: Simulate loading data: an image and a label\n"," # image = np.ones(...) * (effective_idx % 255)\n"," # label = effective_idx % 10\n"," # YOUR CODE HERE\n"," image = np.zeros((32,32,3), dtype=np.uint8) # Replace this\n"," label = 0 # Replace this\n"," return {'image': image, 'label': label}\n","\n","# 2. Instantiate MySource\n","# TODO: Create an instance of MySource\n","# source = ...\n","# YOUR CODE HERE\n","source = None # Replace this\n","\n","# 3. Create an IndexSampler\n","# TODO: Create an IndexSampler that shuffles, runs for 1 epoch, and uses seed 42.\n","# num_records should be len(source).\n","# index_sampler = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","index_sampler = None # Replace this\n","\n","# 4. Define Operations\n","# TODO: Define ConvertToFloat transform\n","class ConvertToFloat(grain.MapTransform):\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," # TODO: Convert 'image' to float32 and normalize to [0, 1].\n"," # Keep 'label' as is.\n"," # YOUR CODE HERE\n"," image = features['image'] # Replace this\n"," return {'image': image.astype(np.float32) / 255.0, 'label': features['label']}\n","\n","# TODO: Create a list of transformations: ConvertToFloat instance, then grain.Batch\n","batch_size = 64, drop_remainder = True\n","# transformations = [...]\n","# YOUR CODE HERE\n","transformations = [] # Replace this\n","\n","# 5. Instantiate DataLoader\n","# TODO: Create a DataLoader with worker_count=0 and appropriate read_options\n","# data_loader_sequential = grain.DataLoader(...)\n","# YOUR CODE HERE\n","data_loader_sequential = None # Replace this\n","\n","# 6. Iterate and print batch info\n","if data_loader_sequential:\n"," print(\"DataLoader configured sequentially.\")\n"," data_iterator_seq = iter(data_loader_sequential)\n"," try:\n"," first_batch_seq = next(data_iterator_seq)\n"," print(f\"Sequential - First batch image shape: {first_batch_seq['image'].shape}\")\n"," print(f\"Sequential - First batch label shape: {first_batch_seq['label'].shape}\")\n"," # Example: Check a value from the first image of the first batch\n"," print(f\"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}\")\n"," print(f\"Sequential - Example label value (first item): {first_batch_seq['label'][0]}\")\n"," except StopIteration:\n"," print(\"Sequential DataLoader is empty or exhausted.\")\n","else:\n"," print(\"Sequential DataLoader not configured yet.\")"],"metadata":{"id":"AgTRPoyOzi5E","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 1: Solution\n","# 1. Define MySource\n","class MySource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n","\n"," def __len__(self) -> int:\n"," return self._num_records\n","\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records # Handle wrap-around\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10\n"," return {'image': image, 'label': label}\n","\n","# 2. Instantiate MySource\n","source = MySource(num_records=1000)\n","print(f\"DataSource created with {len(source)} records.\")\n","\n","# 3. Create an IndexSampler\n","index_sampler = grain.IndexSampler(\n"," num_records=len(source),\n"," shard_options=grain.NoSharding(), # No sharding for this exercise\n"," shuffle=True,\n"," num_epochs=1, # Run for 1 epoch\n"," seed=42\n"," )\n","print(\"IndexSampler created.\")\n","\n","# 4. Define Operations\n","class ConvertToFloat(grain.MapTransform):\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," # Convert 'image' to float32 and normalize to [0, 1].\n"," # Keep 'label' as is.\n"," image = features['image'].astype(np.float32) / 255.0\n"," return {'image': image, 'label': features['label']}\n","\n","transformations = [\n"," ConvertToFloat(),\n"," grain.Batch(batch_size=64, drop_remainder=True)\n"," ]\n","print(\"Transformations defined.\")\n","\n","# 5. Instantiate DataLoader\n","data_loader_sequential = grain.DataLoader(\n"," data_source=source,\n"," operations=transformations,\n"," sampler=index_sampler,\n"," worker_count=0, # Sequential mode\n"," shard_options=grain.NoSharding(), # Explicitly no sharding for this loader instance\n"," read_options=grain.ReadOptions(num_threads=0) # Dataset is in-memory\n",")\n","\n","# 6. Iterate and print batch info\n","if data_loader_sequential:\n"," print(\"DataLoader configured sequentially.\")\n"," data_iterator_seq = iter(data_loader_sequential)\n"," try:\n"," first_batch_seq = next(data_iterator_seq)\n"," print(f\"Sequential - First batch image shape: {first_batch_seq['image'].shape}\") # Expected: (64, 32, 32, 3)\n"," print(f\"Sequential - First batch label shape: {first_batch_seq['label'].shape}\") # Expected: (64,)\n"," # Example: Check a value from the first image of the first batch\n"," print(f\"Sequential - Example image value (first item, [0,0,0]): {first_batch_seq['image'][0, 0, 0, 0]}\")\n"," print(f\"Sequential - Example label value (first item): {first_batch_seq['label'][0]}\")\n"," except StopIteration:\n"," print(\"Sequential DataLoader is empty or exhausted.\")\n","else:\n"," print(\"Sequential DataLoader not configured yet.\")"],"metadata":{"id":"1yqwAkGm2ihu"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["---\n","## Exercise 2: Enabling Parallelism with `worker_count`\n","\n","**Goal:** Understand how `worker_count > 0` enables multiprocessing for faster data loading.\n","\n","**Instructions:**\n","1. Reuse `MySource`, `IndexSampler` (or create a new one if you prefer, e.g., for indefinite epochs: `num_epochs=None`), and `transformations` from Exercise 1.\n","2. To better observe the potential benefits of parallelism, let's modify `MySource` slightly. Add a small `time.sleep(0.01)` (10 milliseconds) inside `__getitem__` to simulate some I/O or CPU work for each item.\n","3. Instantiate a new `grain.DataLoader` (e.g., `data_loader_parallel`). This time, set `worker_count` to a value greater than 0 (e.g., 2 or 4). Remember our environment is faking 8 CPUs.\n","4. Iterate to get the first batch and print its shape info.\n","5. (Optional) Time how long it takes to get, for example, 10 batches from the sequential loader vs. the parallel loader. You should see a speed-up with the parallel loader, especially with the added `time.sleep`.\n","\n","**A note on pickling:** When `worker_count > 0`, Grain uses multiprocessing. This means all components (DataSource, Sampler, Operations, and custom transform instances) must be picklable by Python's `pickle` module. Simple classes and functions are usually fine, but avoid complex closures or unpicklable objects in your transform logic."],"metadata":{"id":"PbUAwDXf3_K4"}},{"cell_type":"code","source":["# @title Exercise 2: Student Code\n","# 1. Reuse/Recreate components (DataSource with simulated work, Sampler, Operations)\n","# TODO: Define MySourceWithWork, adding time.sleep(0.01) in getitem\n","class MySourceWithWork(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n","\n"," def __len__(self) -> int:\n"," # YOUR CODE HERE\n"," return self._num_records\n","\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records\n"," # TODO: Add time.sleep(0.01) to simulate work\n"," # YOUR CODE HERE\n","\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10\n"," return {'image': image, 'label': label}\n","\n","# TODO: Instantiate MySourceWithWork\n","# source_with_work = ...\n","# YOUR CODE HERE\n","source_with_work = None # Replace this\n","\n","# TODO: Create a new IndexSampler (e.g., for indefinite epochs, num_epochs=None)\n","# Or reuse the one from Ex1 if you reset it or it's for multiple epochs.\n","# For simplicity, let's create one for indefinite epochs.\n","# parallel_sampler = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","parallel_sampler = None # Replace this\n","\n","# Transformations can be reused from Exercise 1\n","# transformations = [ConvertToFloat(), grain.Batch(batch_size=64, drop_remainder=True)]\n","# (Assuming ConvertToFloat is defined from Ex1 solution)\n","\n","# 2. Instantiate DataLoader with worker_count > 0\n","# TODO: Set num_workers (e.g., 4)\n","# num_workers = ...\n","# YOUR CODE HERE\n","num_workers = 0 # Replace this\n","\n","# TODO: Create data_loader_parallel\n","# data_loader_parallel = grain.DataLoader(...)\n","# YOUR CODE HERE\n","data_loader_parallel = None # Replace this\n","\n","# 3. Iterate and print batch info\n","if data_loader_parallel:\n"," print(f\"DataLoader configured with worker_count={num_workers}.\")\n"," data_iterator_parallel = iter(data_loader_parallel)\n"," try:\n"," first_batch_parallel = next(data_iterator_parallel)\n"," print(f\"Parallel - First batch image shape: {first_batch_parallel['image'].shape}\")\n"," print(f\"Parallel - First batch label shape: {first_batch_parallel['label'].shape}\")\n"," except StopIteration:\n"," print(\"Parallel DataLoader is empty or exhausted.\")\n","else:\n"," print(\"Parallel DataLoader not configured yet.\")\n","\n","# 4. (Optional) Timing comparison\n","# Re-create sequential loader with MySourceWithWork for a fair comparison\n","if source_with_work and transformations and index_sampler: # index_sampler from Ex1\n"," data_loader_seq_with_work = grain.DataLoader(\n"," data_source=source_with_work,\n"," operations=transformations, # Reusing from Ex1\n"," sampler=index_sampler, # Reusing from Ex1 (ensure it's fresh or allows re-iteration)\n"," worker_count=0,\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n"," num_batches_to_test = 5 # Small number for quick test\n","\n","if data_loader_seq_with_work:\n"," print(f\"\\nTiming test for {num_batches_to_test} batches:\")\n"," # Sequential\n"," iterator_seq = iter(data_loader_seq_with_work)\n"," start_time = time.time()\n"," try:\n"," for i in range(num_batches_to_test):\n"," batch = next(iterator_seq)\n"," if i == 0: print(f\" Seq batch 1 label sum: {batch['label'].sum()}\") # to ensure work is done\n"," except StopIteration:\n"," print(\"Sequential loader exhausted early.\")\n"," end_time = time.time()\n"," print(f\"Sequential ({num_batches_to_test} batches) took: {end_time - start_time:.4f} seconds\")\n","\n","if data_loader_parallel:\n"," # Parallel\n"," # Ensure sampler is fresh for parallel loader if it was used above\n"," # For this optional part, let's use a fresh sampler for the parallel loader\n"," # to avoid StopIteration if the previous sampler was single-epoch and exhausted.\n"," fresh_parallel_sampler = grain.IndexSampler(\n"," num_records=len(source_with_work),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=None, # Indefinite\n"," seed=43 # Different seed or same, for this test it's about speed\n"," )\n"," data_loader_parallel_for_timing = grain.DataLoader(\n"," data_source=source_with_work,\n"," operations=transformations, # Reusing from Ex1\n"," sampler=fresh_parallel_sampler,\n"," worker_count=num_workers if num_workers > 0 else 2, # Ensure parallelism\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n"," iterator_parallel = iter(data_loader_parallel_for_timing)\n"," start_time = time.time()\n"," try:\n"," for i in range(num_batches_to_test):\n"," batch = next(iterator_parallel)\n"," if i == 0: print(f\" Parallel batch 1 label sum: {batch['label'].sum()}\") # to ensure work is done\n"," except StopIteration:\n"," print(\"Parallel loader exhausted early.\")\n"," end_time = time.time()\n"," print(f\"Parallel ({num_batches_to_test} batches, {num_workers if num_workers > 0 else 2} workers) took: {end_time - start_time:.4f} seconds\")\n","else:\n"," print(\"Skipping optional timing: source_with_work, transformations, or index_sampler not defined.\")"],"metadata":{"id":"kJ9xr9N43tcH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 2: Solution\n","# 1. Reuse/Recreate components\n","# Define MySourceWithWork, adding time.sleep(0.01) in getitem\n","class MySourceWithWork(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n","\n"," def __len__(self) -> int:\n"," return self._num_records\n","\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records\n"," time.sleep(0.01) # Simulate 10ms of work per item\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10\n"," return {'image': image, 'label': label}\n","\n","source_with_work = MySourceWithWork(num_records=1000)\n","print(f\"MySourceWithWork created with {len(source_with_work)} records.\")\n","\n","# Sampler for parallel loading (indefinite epochs for robust testing)\n","parallel_sampler = grain.IndexSampler(\n"," num_records=len(source_with_work),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=None, # Run indefinitely\n"," seed=42\n"," )\n","print(\"Parallel IndexSampler created.\")\n","\n","# Transformations can be reused from Exercise 1 solution\n","# Ensure ConvertToFloat is defined (it was in Ex1 solution cell)\n","if 'ConvertToFloat' not in globals(): # Basic check\n"," class ConvertToFloat(grain.MapTransform): # Redefine if not in current scope\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," image = features['image'].astype(np.float32) / 255.0\n"," return {'image': image, 'label': features['label']}\n","\n"," print(\"Redefined ConvertToFloat for safety.\")\n","\n","transformations_ex2 = [\n"," ConvertToFloat(),\n"," grain.Batch(batch_size=64, drop_remainder=True)\n"," ]\n","print(\"Transformations for Ex2 ready.\")\n","\n","# 2. Instantiate DataLoader with worker_count > 0\n","num_workers = 4 # Use 4 workers; JAX is configured for 8 virtual CPUs\n","# Max useful workers often related to num CPU cores available.\n","data_loader_parallel = grain.DataLoader(\n"," data_source=source_with_work,\n"," operations=transformations_ex2,\n"," sampler=parallel_sampler,\n"," worker_count=num_workers,\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0) # Data source simulates work but is \"in-memory\"\n"," )\n","\n","# 3. Iterate and print batch info\n","if data_loader_parallel:\n"," print(f\"DataLoader configured with worker_count={num_workers}.\")\n"," data_iterator_parallel = iter(data_loader_parallel)\n"," try:\n"," first_batch_parallel = next(data_iterator_parallel)\n"," print(f\"Parallel - First batch image shape: {first_batch_parallel['image'].shape}\")\n"," print(f\"Parallel - First batch label shape: {first_batch_parallel['label'].shape}\")\n"," except StopIteration:\n"," print(\"Parallel DataLoader is empty or exhausted.\")\n","else:\n"," print(\"Parallel DataLoader not configured yet.\")\n","\n","# 4. (Optional) Timing comparison\n","# Create a fresh IndexSampler for the sequential loader for a fair comparison start\n","# (num_epochs=1 to match typical test for sequential pass)\n","seq_sampler_for_timing = grain.IndexSampler(\n"," num_records=len(source_with_work),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=1, # Single epoch for this timing test\n"," seed=42\n"," )\n","\n","data_loader_seq_with_work = grain.DataLoader(\n"," data_source=source_with_work,\n"," operations=transformations_ex2,\n"," sampler=seq_sampler_for_timing,\n"," worker_count=0,\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","num_batches_to_test = 5 # Number of batches to fetch for timing\n","print(f\"\\nTiming test for {num_batches_to_test} batches (each item has 0.01s simulated work):\")\n","\n","# Sequential\n","iterator_seq = iter(data_loader_seq_with_work)\n","start_time_seq = time.time()\n","try:\n"," for i in range(num_batches_to_test):\n"," batch_seq = next(iterator_seq)\n"," if i == 0 and num_batches_to_test > 0 : print(f\" Seq batch 1 label sum: {batch_seq['label'].sum()}\") # to ensure work is done\n","except StopIteration:\n"," print(f\"Sequential loader exhausted before {num_batches_to_test} batches.\")\n","end_time_seq = time.time()\n","print(f\"Sequential ({num_batches_to_test} batches) took: {end_time_seq - start_time_seq:.4f} seconds\")\n","\n","# Parallel\n","# Use a fresh sampler for the parallel loader for timing to ensure it's not exhausted\n","# and runs for enough batches.\n","parallel_sampler_for_timing = grain.IndexSampler(\n"," num_records=len(source_with_work),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=None, # Indefinite, or ensure enough for num_batches_to_test\n"," seed=43 # Can be same or different seed\n"," )\n","\n","data_loader_parallel_for_timing = grain.DataLoader(\n"," data_source=source_with_work,\n"," operations=transformations_ex2,\n"," sampler=parallel_sampler_for_timing,\n"," worker_count=num_workers,\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","\n","iterator_parallel_timed = iter(data_loader_parallel_for_timing)\n","start_time_parallel = time.time()\n","try:\n"," for i in range(num_batches_to_test):\n"," batch_par = next(iterator_parallel_timed)\n"," if i == 0 and num_batches_to_test > 0 : print(f\" Parallel batch 1 label sum: {batch_par['label'].sum()}\") # to ensure work is done\n","except StopIteration:\n"," print(f\"Parallel loader exhausted before {num_batches_to_test} batches.\")\n","end_time_parallel = time.time()\n","print(f\"Parallel ({num_batches_to_test} batches, {num_workers} workers) took: {end_time_parallel - start_time_parallel:.4f} seconds\")\n","\n","if end_time_parallel - start_time_parallel < end_time_seq - start_time_seq:\n"," print(\"Parallel loading was faster, as expected!\")\n","else:\n"," print(\"Parallel loading was not significantly faster. This might happen for very small num_batches_to_test due to overhead, or if simulated work is too little.\")"],"metadata":{"id":"WvQYsi5C5e7_"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["\n","---\n","## Exercise 3: Custom Deterministic Transformations (`MapTransform`)\n","\n","**Goal:** Implement a custom data transformation that behaves deterministically.\n","\n","**Instructions:**\n","1. Define a custom class `OneHotEncodeLabel` that inherits from `grain.MapTransform`.\n"," * Its `__init__` method should take `num_classes`.\n"," * Its `map(self, features: Dict[str, Any])` method should:\n"," * Take the input `features` dictionary.\n"," * Convert the `features['label']` (an integer) into a one-hot encoded NumPy array of type `np.float32`. The length of this array should be `num_classes`.\n"," * Update `features['label']` with this new one-hot array.\n"," * Return the modified `features` dictionary.\n","2. Reuse `MySource` (the one without `time.sleep`) and `IndexSampler` from Exercise 1 (or create new ones).\n","3. Create a new list of `operations` that includes:\n"," * An instance of your `OneHotEncodeLabel` (e.g., with `num_classes=10`, matching `idx % 10` from `MySource`).\n"," * The `ConvertToFloat` transform (if not already applied to image).\n"," * `grain.Batch`.\n","4. Instantiate a `grain.DataLoader` (you can use `worker_count=0` or `>0`).\n","5. Iterate to get the first batch and print the shape of the one-hot encoded labels and an example label vector."],"metadata":{"id":"ELtzCFF7-a-E"}},{"cell_type":"code","source":["# @title Exercise 3: Student Code\n","# 1. Define OneHotEncodeLabel\n","class OneHotEncodeLabel(grain.MapTransform):\n"," def __init__(self, num_classes: int):\n"," # TODO: Store num_classes\n"," # YOUR CODE HERE\n"," self._num_classes = 0 # Replace this\n","\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," label = features['label']\n"," # TODO: Create one-hot encoded version of the label\n"," # one_hot_label = np.zeros(...)\n"," # one_hot_label[label] = 1.0\n"," # YOUR CODE HERE\n"," one_hot_label = np.array([label]) # Replace this\n","\n"," features['label'] = one_hot_label\n"," return features\n","\n","# 2. Reuse/Create DataSource and Sampler\n","# TODO: Instantiate MySource (from Ex1, no sleep)\n","# source_ex3 = ...\n","# YOUR CODE HERE\n","source_ex3 = None # Replace this\n","\n","# TODO: Instantiate an IndexSampler (e.g., from Ex1, or a new one)\n","# sampler_ex3 = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","sampler_ex3 = None # Replace this\n","\n","# 3. Create new list of operations\n","# TODO: Instantiate OneHotEncodeLabel\n","num_classes_for_ohe = 10\n","one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)\n","# YOUR CODE HERE\n","one_hot_encoder = None # Replace this\n","\n","# TODO: Define transformations_ex3 list including one_hot_encoder,\n","# ConvertToFloat (if not already applied), and grain.Batch\n","# (Assuming ConvertToFloat is defined from Ex1 solution)\n","# transformations_ex3 = [...]\n","# YOUR CODE HERE\n","transformations_ex3 = [] # Replace this\n","\n","# 4. Instantiate DataLoader\n","# TODO: Create data_loader_ex3\n","# data_loader_ex3 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","data_loader_ex3 = None # Replace this\n","\n","# 5. Iterate and print batch info\n","if data_loader_ex3:\n"," print(\"DataLoader with OneHotEncodeLabel configured.\")\n"," iterator_ex3 = iter(data_loader_ex3)\n"," try:\n"," first_batch_ex3 = next(iterator_ex3)\n"," print(f\"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}\")\n"," print(f\"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}\") # Expected: (batch_size, num_classes)\n"," if first_batch_ex3['label'].size > 0:\n"," print(f\"Custom MapTransform - Example one-hot label: {first_batch_ex3['label'][0]}\")\n"," except StopIteration:\n"," print(\"DataLoader for Ex3 is empty or exhausted.\")\n","else:\n"," print(\"DataLoader for Ex3 not configured yet.\")"],"metadata":{"id":"POurO6hp7-Mo"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 3: Solution\n","# 1. Define OneHotEncodeLabel\n","class OneHotEncodeLabel(grain.MapTransform):\n"," def __init__(self, num_classes: int):\n"," self._num_classes = num_classes\n","\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," label_scalar = features['label']\n"," one_hot_label = np.zeros(self._num_classes, dtype=np.float32)\n"," one_hot_label[label_scalar] = 1.0\n","\n"," # Create a new dictionary to avoid modifying the input dict in place if it's reused\n"," # by other transforms or parts of the pipeline, though often direct modification is fine.\n"," # For safety and clarity, let's return a new dict or an updated copy.\n"," updated_features = features.copy()\n"," updated_features['label'] = one_hot_label\n"," return updated_features\n","\n","# 2. Reuse/Create DataSource and Sampler\n","# Using MySource from Exercise 1 solution (no artificial sleep)\n","if 'MySource' not in globals(): # Basic check\n"," class MySource(grain.RandomAccessDataSource): # Redefine if not in current scope\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n"," def __len__(self) -> int:\n"," return self._num_records\n","\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10\n"," return {'image': image, 'label': label}\n"," print(\"Redefined MySource for Ex3.\")\n","\n","source_ex3 = MySource(num_records=1000)\n","sampler_ex3 = grain.IndexSampler(\n"," num_records=len(source_ex3),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=1,\n"," seed=42\n"," )\n","print(\"DataSource and Sampler for Ex3 ready.\")\n","\n","# 3. Create new list of operations\n","num_classes_for_ohe = 10 # Matches idx % 10 in MySource\n","one_hot_encoder = OneHotEncodeLabel(num_classes=num_classes_for_ohe)\n","\n","# Ensure ConvertToFloat is defined\n","if 'ConvertToFloat' not in globals():\n"," class ConvertToFloat(grain.MapTransform):\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," image = features['image'].astype(np.float32) / 255.0\n"," return {'image': image, 'label': features['label']} # Pass label through\n","print(\"Redefined ConvertToFloat for Ex3.\")\n","\n","transformations_ex3 = [\n"," ConvertToFloat(), # Apply first to have float images\n"," one_hot_encoder, # Then one-hot encode labels\n"," grain.Batch(batch_size=64, drop_remainder=True)\n"," ]\n","print(\"Transformations for Ex3 defined.\")\n","\n","# 4. Instantiate DataLoader\n","data_loader_ex3 = grain.DataLoader(\n"," data_source=source_ex3,\n"," operations=transformations_ex3,\n"," sampler=sampler_ex3,\n"," worker_count=0, # Can be > 0 as well, OneHotEncodeLabel is picklable\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","\n","# 5. Iterate and print batch info\n","if data_loader_ex3:\n"," print(\"DataLoader with OneHotEncodeLabel configured.\")\n"," iterator_ex3 = iter(data_loader_ex3)\n"," try:\n"," first_batch_ex3 = next(iterator_ex3)\n"," print(f\"Custom MapTransform - Batch image shape: {first_batch_ex3['image'].shape}\")\n"," print(f\"Custom MapTransform - Batch label shape: {first_batch_ex3['label'].shape}\") # Expected: (64, 10)\n"," if first_batch_ex3['label'].size > 0:\n"," print(f\"Custom MapTransform - Example one-hot label (first item): {first_batch_ex3['label'][0]}\")\n"," original_label_example = np.argmax(first_batch_ex3['label'][0])\n"," print(f\"Custom MapTransform - Decoded original label (first item): {original_label_example}\")\n"," except StopIteration:\n"," print(\"DataLoader for Ex3 is empty or exhausted.\")\n","else:\n"," print(\"DataLoader for Ex3 not configured yet.\")"],"metadata":{"id":"asNGKCuy_TMz"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["\n","---\n","## Exercise 4: Custom Randomized Transformations (`RandomMapTransform`)\n","\n","**Goal:** Implement a custom transformation that involves randomness while ensuring reproducibility using Grain's mechanisms.\n","\n","**Instructions:**\n","1. Define a custom class `RandomBrightnessAdjust` that inherits from `grain.RandomMapTransform`.\n"," * Its `random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]` method should:\n"," * Take `features` and an `rng` (NumPy random number generator).\n"," * **Crucially, use the provided `rng` for all random operations.** This ensures that the same record, when processed with the same initial seed for the sampler, gets the same \"random\" augmentation.\n"," * Generate a random brightness factor using `rng.uniform(0.7, 1.3)`.\n"," * Multiply the `features['image']` (assuming it's already float and normalized) by this factor.\n"," * Clip the image values to stay within `[0.0, 1.0]` using `np.clip()`.\n"," * Return the modified `features`.\n","2. Reuse `MySource`, `IndexSampler` (ensure it has a `seed`), and `ConvertToFloat` from previous exercises.\n","3. Create a list of `operations` including `ConvertToFloat`, your `RandomBrightnessAdjust`, and `grain.Batch`.\n","4. Instantiate two `DataLoader` instances (`dl_run1`, `dl_run2`) with the **exact same configuration** (same source, sampler instance or sampler with same seed, operations, worker_count).\n","5. Iterate and get the first batch from `dl_run1`. Print a sample pixel value.\n","6. Reset the iterator or re-create the sampler if necessary (if `num_epochs=1`). Then, get the first batch from `dl_run2`. Print the same sample pixel value.\n","7. **Verify:** The pixel values should be identical, demonstrating reproducible random augmentation.\n","8. (Optional) Change the seed in the `IndexSampler` for `dl_run2` and observe that the pixel values now differ."],"metadata":{"id":"Lx0nsNlQBWDx"}},{"cell_type":"code","source":["# @title Exercise 4: Student Code\n","# 1. Define RandomBrightnessAdjust\n","class RandomBrightnessAdjust(grain.RandomMapTransform):\n"," def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:\n"," # TODO: Ensure image is float (e.g. by placing ConvertToFloat before this in ops)\n"," image = features['image']\n","\n"," # TODO: Generate a random brightness factor using the provided rng\n"," # brightness_factor = rng.uniform(...)\n"," # YOUR CODE HERE\n"," brightness_factor = 1.0 # Replace this\n","\n"," # TODO: Apply brightness adjustment and clip\n"," # adjusted_image = np.clip(...)\n"," # YOUR CODE HERE\n"," adjusted_image = image # Replace this\n","\n"," # Create a new dictionary or update a copy\n"," updated_features = features.copy()\n"," updated_features['image'] = adjusted_image\n"," return updated_features\n","\n","# 2. Reuse/Create DataSource, Sampler, ConvertToFloat\n","# TODO: Instantiate MySource (from Ex1)\n","# source_ex4 = ...\n","# YOUR CODE HERE\n","source_ex4 = None # Replace this\n","\n","# TODO: Instantiate an IndexSampler with a seed (e.g., seed=42, num_epochs=1 or None)\n","# sampler_ex4_seed42 = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","sampler_ex4_seed42 = None # Replace this\n","\n","# (Assuming ConvertToFloat is defined from Ex1 solution)\n","\n","# 3. Create list of operations\n","# TODO: Instantiate RandomBrightnessAdjust\n","# random_brightness_adjuster = ...\n","# YOUR CODE HERE\n","random_brightness_adjuster = None # Replace this\n","\n","# TODO: Define transformations_ex4 list: ConvertToFloat, random_brightness_adjuster, grain.Batch\n","# transformations_ex4 = [...]\n","# YOUR CODE HERE\n","transformations_ex4 = [] # Replace this\n","\n","# 4. Instantiate two DataLoaders with the same config\n","# TODO: Create dl_run1\n","# dl_run1 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","dl_run1 = None # Replace this\n","\n","# TODO: Create dl_run2 (using the exact same sampler instance or a new one with the same seed)\n","# dl_run2 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","dl_run2 = None # Replace this\n","\n","# 5. & 6. Iterate and compare\n","pixel_to_check = (0, 0, 0, 0) # Batch_idx, H, W, C\n","if dl_run1:\n"," print(\"--- Run 1 (seed 42) ---\")\n"," iterator_run1 = iter(dl_run1)\n"," try:\n"," batch1_run1 = next(iterator_run1)\n"," value_run1 = batch1_run1['image'][pixel_to_check]\n"," print(f\"Run 1 - Pixel {pixel_to_check} value: {value_run1}\")\n"," except StopIteration:\n"," print(\"dl_run1 exhausted.\")\n"," value_run1 = None\n","else:\n"," print(\"dl_run1 not configured.\")\n"," value_run1 = None\n","\n","if dl_run2:\n"," print(\"\\n--- Run 2 (seed 42, same sampler) ---\")\n"," # If sampler_ex4_seed42 was single-epoch and already used by dl_run1,\n"," # dl_run2 might be empty. For robust test, ensure sampler allows re-iteration\n"," # or use a new sampler instance with the same seed.\n"," # If sampler_ex4_seed42 had num_epochs=None, iter(dl_run2) is fine.\n"," # If num_epochs=1, you might need to re-create sampler_ex4_seed42 for dl_run2\n"," # or ensure dl_run1 didn't exhaust it (e.g. by not fully iterating it).\n"," # For this exercise, assume sampler_ex4_seed42 can be re-used or is fresh for dl_run2.\n"," iterator_run2 = iter(dl_run2)\n"," try:\n"," batch1_run2 = next(iterator_run2)\n"," value_run2 = batch1_run2['image'][pixel_to_check]\n"," print(f\"Run 2 - Pixel {pixel_to_check} value: {value_run2}\")\n","\n"," # 7. Verify\n"," if value_run1 is not None and value_run2 is not None:\n"," if np.allclose(value_run1, value_run2):\n"," print(\"\\nSUCCESS: Pixel values are identical. Randomness is reproducible!\")\n"," else:\n"," print(f\"\\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}\")\n"," except StopIteration:\n"," print(\"dl_run2 exhausted. This might happen if the sampler was single-epoch and already used.\")\n"," value_run2 = None\n","else:\n"," print(\"dl_run2 not configured.\")\n","\n","# 8. (Optional) Test with a different seed\n","# TODO: Create sampler_ex4_seed100 (seed=100)\n","# sampler_ex4_seed100 = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","sampler_ex4_seed100 = None # Replace this\n","\n","# TODO: Create dl_run3 with sampler_ex4_seed100\n","# dl_run3 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","dl_run3 = None # Replace this\n","\n","if dl_run3:\n"," print(\"\\n--- Run 3 (seed 100) ---\")\n"," iterator_run3 = iter(dl_run3)\n"," try:\n"," batch1_run3 = next(iterator_run3)\n"," value_run3 = batch1_run3['image'][pixel_to_check]\n"," print(f\"Run 3 - Pixel {pixel_to_check} value: {value_run3}\")\n"," if value_run1 is not None and not np.allclose(value_run1, value_run3):\n"," print(\"SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new seed.\")\n"," elif value_run1 is not None:\n"," print(\"NOTE: Pixel values are the same as Run 1. Check seed or logic.\")\n"," except StopIteration:\n"," print(\"dl_run3 exhausted.\")\n","else:\n"," print(\"\\nOptional part (dl_run3) not configured.\")"],"metadata":{"id":"nrZtIbOiAyOC"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 4: Solution\n","# 1. Define RandomBrightnessAdjust\n","class RandomBrightnessAdjust(grain.RandomMapTransform):\n"," def random_map(self, features: Dict[str, Any], rng: np.random.Generator) -> Dict[str, Any]:\n"," image = features['image'] # Assumes image is already float, e.g. from ConvertToFloat\n"," # Generate a random brightness factor using the provided rng\n"," brightness_factor = rng.uniform(0.7, 1.3)\n","\n"," # Apply brightness adjustment and clip\n"," adjusted_image = image * brightness_factor\n"," adjusted_image = np.clip(adjusted_image, 0.0, 1.0)\n","\n"," updated_features = features.copy()\n"," updated_features['image'] = adjusted_image\n"," return updated_features\n","\n","# 2. Reuse/Create DataSource, Sampler, ConvertToFloat\n","if 'MySource' not in globals(): # Basic check for MySource\n"," class MySource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n"," def __len__(self) -> int:\n"," return self._num_records\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10\n"," return {'image': image, 'label': label}\n"," print(\"Redefined MySource for Ex4.\")\n","source_ex4 = MySource(num_records=1000)\n","\n","# Sampler with a fixed seed. num_epochs=None allows re-iteration for multiple DataLoaders.\n","# If num_epochs=1, the sampler instance can only be fully iterated once.\n","# For this test, using num_epochs=None or re-creating the sampler for each DataLoader is safest.\n","# Let's use num_epochs=None to allow the same sampler instance to be used.\n","sampler_ex4_seed42 = grain.IndexSampler(\n"," num_records=len(source_ex4),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=None, # Allow indefinite iteration\n"," seed=42\n"," )\n","print(\"DataSource and Sampler (seed 42) for Ex4 ready.\")\n","\n","if 'ConvertToFloat' not in globals(): # Basic check for ConvertToFloat\n"," class ConvertToFloat(grain.MapTransform):\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," image = features['image'].astype(np.float32) / 255.0\n"," return {'image': image, 'label': features['label']}\n"," print(\"Redefined ConvertToFloat for Ex4.\")\n","\n","# 3. Create list of operations\n","random_brightness_adjuster = RandomBrightnessAdjust()\n","transformations_ex4 = [\n"," ConvertToFloat(),\n"," random_brightness_adjuster,\n"," grain.Batch(batch_size=64, drop_remainder=True)\n","]\n","print(\"Transformations for Ex4 defined.\")\n","\n","# 4. Instantiate two DataLoaders with the same config\n","# Using worker_count > 0 to also test picklability of RandomBrightnessAdjust\n","num_workers_ex4 = 2\n","dl_run1 = grain.DataLoader(\n"," data_source=source_ex4,\n"," operations=transformations_ex4,\n"," sampler=sampler_ex4_seed42, # Same sampler instance\n"," worker_count=num_workers_ex4,\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","\n","dl_run2 = grain.DataLoader(\n"," data_source=source_ex4,\n"," operations=transformations_ex4,\n"," sampler=sampler_ex4_seed42, # Same sampler instance\n"," worker_count=num_workers_ex4,\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","print(f\"DataLoaders for Run1 and Run2 created with worker_count={num_workers_ex4}.\")\n","\n","# 5. & 6. Iterate and compare\n","pixel_to_check = (0, 0, 0, 0) # Batch_idx=0, H=0, W=0, C=0\n","print(\"\\n--- Run 1 (seed 42) ---\")\n","iterator_run1 = iter(dl_run1)\n","\n","try:\n"," batch1_run1 = next(iterator_run1)\n"," value_run1 = batch1_run1['image'][pixel_to_check]\n"," print(f\"Run 1 - Pixel {pixel_to_check} value: {value_run1}\")\n","except StopIteration:\n"," print(\"dl_run1 exhausted.\")\n"," value_run1 = None\n","\n","print(\"\\n--- Run 2 (seed 42, same sampler instance) ---\")\n","iterator_run2 = iter(dl_run2) # Gets a new iterator from the DataLoader\n","\n","try:\n"," batch1_run2 = next(iterator_run2)\n"," value_run2 = batch1_run2['image'][pixel_to_check]\n"," print(f\"Run 2 - Pixel {pixel_to_check} value: {value_run2}\")\n"," # 7. Verify\n"," if value_run1 is not None and value_run2 is not None:\n"," if np.allclose(value_run1, value_run2):\n"," print(\"\\nSUCCESS: Pixel values are identical. Randomness is reproducible with the same sampler instance!\")\n"," else:\n"," print(f\"\\nFAILURE: Pixel values differ. value1={value_run1}, value2={value_run2}. This shouldn't happen if sampler is re-used correctly.\")\n","except StopIteration:\n"," print(\"dl_run2 exhausted.\")\n","\n","# 8. (Optional) Test with a different seed\n","sampler_ex4_seed100 = grain.IndexSampler(\n"," num_records=len(source_ex4),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=None,\n"," seed=100 # Different seed\n"," )\n","\n","dl_run3 = grain.DataLoader(\n"," data_source=source_ex4,\n"," operations=transformations_ex4,\n"," sampler=sampler_ex4_seed100, # Sampler with different seed\n"," worker_count=num_workers_ex4,\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","print(\"\\nDataLoader for Run3 (seed 100) created.\")\n","print(\"\\n--- Run 3 (seed 100) ---\")\n","iterator_run3 = iter(dl_run3)\n","\n","try:\n"," batch1_run3 = next(iterator_run3)\n"," value_run3 = batch1_run3['image'][pixel_to_check]\n"," print(f\"Run 3 - Pixel {pixel_to_check} value: {value_run3}\")\n"," if value_run1 is not None and not np.allclose(value_run1, value_run3):\n"," print(\"SUCCESS: Pixel values differ from Run 1 (seed 42), as expected with a new sampler seed.\")\n"," elif value_run1 is not None:\n"," print(\"NOTE: Pixel values are the same as Run 1. This is unexpected if seeds are different.\")\n","except StopIteration:\n"," print(\"dl_run3 exhausted.\")"],"metadata":{"id":"JmdOMp4sIInr"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["---\n","## Exercise 5: Data Sharding for Distributed Training\n","\n","**Goal:** Understand how Grain handles data sharding, essential for distributed training where each JAX process needs a unique slice of data.\n","\n","**Background:**\n","In a real distributed JAX setup, you'd have multiple Python processes. Each process would call `jax.process_index()` to know its ID and `jax.process_count()` for the total number of processes. `grain.sharding.ShardByJaxProcess()` is a helper that automatically uses these values.\n","\n","Since we are in a single Colab notebook (simulating one JAX process, even with multiple virtual devices), we can't directly run multiple JAX processes. Instead, we will manually create `grain.ShardOptions` to simulate what would happen on two different processes.\n","\n","**Instructions:**\n","1. Reuse `MySource` and `transformations` (e.g., `ConvertToFloat` and `grain.Batch`) from previous exercises.\n","2. Define `shard_count = 2`.\n","3. **Simulate Process 0:**\n"," * Create `shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True)`.\n"," * Create an `IndexSampler` (`sampler_p0`) using these `shard_options_p0`. Ensure it shuffles and uses a common seed (e.g., 42).\n"," * Create a `DataLoader` (`dl_p0`) using this `sampler_p0` and the `shard_options_p0` passed to the DataLoader itself.\n"," * Iterate through `dl_p0` and collect all unique labels from the first few batches (or all batches if `num_epochs=1`).\n","4. **Simulate Process 1:**\n"," * Create `shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True)`.\n"," * Create an `IndexSampler` (`sampler_p1`) using `shard_options_p1` (same seed as `sampler_p0`).\n"," * Create a `DataLoader` (`dl_p1`) using `sampler_p1` and `shard_options_p1`.\n"," * Iterate through `dl_p1` and collect all unique labels.\n","5. **Verify:**\n"," * Print the set of unique labels obtained by \"Process 0\" and \"Process 1\".\n"," * Confirm that these two sets of labels are largely distinct (they might have minor overlaps if shuffling leads to boundary items being similar by chance, but the bulk of data indices processed should be different). The key is that the *indices* sampled by `sampler_p0` and `sampler_p1` should be disjoint.\n"," * The `drop_remainder=True` in `ShardOptions` ensures that if the dataset size isn't perfectly divisible by `shard_count`, some data might be dropped to ensure shards are equal or nearly equal (depending on implementation details).\n","\n","**Note on `shard_options` in `IndexSampler` vs `DataLoader`:**\n","The `shard_options` argument to `grain.DataLoader` is the primary way to enable sharding for a JAX process. The `DataLoader` will then ensure its underlying sampler (even if you provide a non-sharded one) respects these global sharding options for the current JAX process. If you provide an `IndexSampler` that is *already* sharded, its sharding must be compatible with the `DataLoader`'s `shard_options`. For simplicity and clarity in distributed settings, passing `ShardByJaxProcess()` or manually configured `ShardOptions` to the `DataLoader` is typical."],"metadata":{"id":"OGMGYA7wK1Nx"}},{"cell_type":"code","source":["# @title Exercise 5: Student Code\n","# 1. Reuse DataSource and basic transformations\n","# TODO: Instantiate MySource (from Ex1)\n","# source_ex5 = ...\n","# YOUR CODE HERE\n","source_ex5 = None # Replace this\n","\n","# TODO: Define basic_transformations_ex5 (e.g., ConvertToFloat, Batch)\n","# (Assuming ConvertToFloat is defined)\n","# basic_transformations_ex5 = [...]\n","# YOUR CODE HERE\n","basic_transformations_ex5 = [] # Replace this\n","\n","# 2. Define shard_count\n","shard_count = 2\n","common_seed = 42\n","num_epochs_for_sharding_test = 1 # To make collection of all labels feasible\n","\n","# 3. Simulate Process 0\n","# TODO: Create shard_options_p0\n","# shard_options_p0 = grain.ShardOptions(...)\n","# YOUR CODE HERE\n","shard_options_p0 = None # Replace this\n","\n","# TODO: Create sampler_p0. Pass shard_options_p0 to the IndexSampler.\n","# sampler_p0 = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","sampler_p0 = None # Replace this\n","\n","# TODO: Create dl_p0. Pass shard_options_p0 to the DataLoader as well.\n","# dl_p0 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","dl_p0 = None # Replace this\n","\n","labels_p0 = set()\n","if dl_p0:\n"," print(\"--- Simulating Process 0 ---\")\n"," # YOUR CODE HERE: Iterate through dl_p0 and collect all unique original labels.\n"," # Remember that labels might be batched. You need to iterate through items in a batch.\n"," # For simplicity, if your MySource generates labels like idx % 10,\n"," # you can try to collect the indices that were sampled.\n"," # Or, more directly, collect the 'label' field from each item.\n"," # To get original indices, you might need a transform that passes index through.\n"," # Let's collect the 'label' values directly.\n"," pass # Replace with iteration logic\n","\n","# 4. Simulate Process 1\n","# TODO: Create shard_options_p1\n","# shard_options_p1 = grain.ShardOptions(...)\n","# YOUR CODE HERE\n","shard_options_p1 = None # Replace this\n","\n","# TODO: Create sampler_p1\n","# sampler_p1 = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","sampler_p1 = None # Replace this\n","\n","# TODO: Create dl_p1\n","# dl_p1 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","dl_p1 = None # Replace this\n","\n","labels_p1 = set()\n","if dl_p1:\n"," print(\"\\n--- Simulating Process 1 ---\")\n"," # YOUR CODE HERE: Iterate through dl_p1 and collect all unique labels.\n"," pass # Replace with iteration logic\n","\n","# 5. Verify\n","print(f\"\\n--- Verification (Total records in source: {len(source_ex5) if source_ex5 else 'N/A'}) ---\")\n","print(f\"Unique labels collected by Process 0 (count {len(labels_p0)}): sorted {sorted(list(labels_p0))[:20]}...\")\n","print(f\"Unique labels collected by Process 1 (count {len(labels_p1)}): sorted {sorted(list(labels_p1))[:20]}...\")\n","if labels_p0 and labels_p1:\n"," intersection = labels_p0.intersection(labels_p1)\n"," if not intersection:\n"," print(\"\\nSUCCESS: No overlap in labels between Process 0 and Process 1. Sharding works as expected!\")\n"," else:\n"," print(f\"\\nNOTE: Some overlap in labels found (count {len(intersection)}): {intersection}.\")\n"," print(\"This can happen if labels are not unique per index, or if sharding logic has issues.\")\n"," print(\"With MySource's label = idx % 10, an overlap in labels is expected even if indices are disjoint.\")\n"," print(\"A better test would be to collect original indices if possible.\")\n","\n","# For a more direct test of sharding of indices:\n","# We can define a DataSource that returns the index itself.\n","class IndexSource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int):\n"," self._num_records = num_records\n"," def __len__(self) -> int:\n"," return self._num_records\n"," def __getitem__(self, idx: int) -> int:\n"," return idx % self._num_records # Return the index\n","index_source = IndexSource(num_records=100) # Smaller source for easier inspection\n","\n","idx_sampler_p0 = grain.IndexSampler(len(index_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)\n","idx_sampler_p1 = grain.IndexSampler(len(index_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)\n","\n","# DataLoader for indices (no batching, just to see raw sampled indices)\n","# Note: DataLoader expects dicts. Let's make IndexSource return {'index': idx}\n","class IndexDictSource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int): self._num_records = num_records\n"," def __len__(self) -> int:\n"," return self._num_records\n"," def __getitem__(self, idx: int) -> Dict[str,int]:\n"," return {'index': idx % self._num_records}\n","index_dict_source = IndexDictSource(num_records=100)\n","\n","# Samplers for IndexDictSource\n","idx_dict_sampler_p0 = grain.IndexSampler(len(index_dict_source), shard_options_p0, shuffle=False, num_epochs=1, seed=common_seed)\n","idx_dict_sampler_p1 = grain.IndexSampler(len(index_dict_source), shard_options_p1, shuffle=False, num_epochs=1, seed=common_seed)\n","\n","# DataLoaders for IndexDictSource\n","# Pass shard_options to DataLoader as well.\n","if shard_options_p0 and shard_options_p1:\n"," dl_indices_p0 = grain.DataLoader(index_dict_source, [], idx_dict_sampler_p0, worker_count=0, shard_options=shard_options_p0)\n"," dl_indices_p1 = grain.DataLoader(index_dict_source, [], idx_dict_sampler_p1, worker_count=0, shard_options=shard_options_p1)\n"," indices_from_p0 = {item['index'] for item in dl_indices_p0} if dl_indices_p0 else set()\n"," indices_from_p1 = {item['index'] for item in dl_indices_p1} if dl_indices_p1 else set()\n","\n","print(f\"\\n--- Verification of INDICES (Source size: {len(index_dict_source)}) ---\")\n","print(f\"Indices from P0 (count {len(indices_from_p0)}, shuffle=False): {sorted(list(indices_from_p0))}\")\n","print(f\"Indices from P1 (count {len(indices_from_p1)}, shuffle=False): {sorted(list(indices_from_p1))}\")\n","\n","if indices_from_p0 and indices_from_p1:\n"," idx_intersection = indices_from_p0.intersection(indices_from_p1)\n"," if not idx_intersection:\n"," print(\"SUCCESS: No overlap in INDICES. Sharding of data sources works correctly!\")\n"," else:\n"," print(f\"FAILURE: Overlap in INDICES found: {idx_intersection}\")\n","else:\n"," print(\"Skipping index verification part as shard_options are not defined.\")\n"," print(\"\\nReminder: In a real distributed setup, you'd use grain.sharding.ShardByJaxProcess() \"\n"," \"and JAX would manage jax.process_index() automatically for each process.\")"],"metadata":{"id":"206ZCYxCKpSc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 5: Solution\n","# Redefine MySource for Ex5 to include 'original_index'.\n","class MySource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n"," def __len__(self) -> int:\n"," return self._num_records\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10 # Label is idx % 10\n"," # For better sharding verification, let's also pass the original index\n"," return {'image': image, 'label_test': label, 'original_index': effective_idx}\n","\n","print(\"Redefined MySource for Ex5 to include 'original_index'.\")\n","source_ex5 = MySource(num_records=1000)\n","\n","# Redefine ConvertToFloat for Ex5 to include 'original_index'.\n","class ConvertToFloat(grain.MapTransform):\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," # This transform should pass through all keys it doesn't modify\n"," updated_features = features.copy()\n"," updated_features['image'] = features['image'].astype(np.float32) / 255.0\n"," return updated_features\n","print(\"Redefined ConvertToFloat for Ex5.\")\n","\n","# We will collect 'original_index' after batching, so batch must preserve it.\n","# grain.Batch by default collates features with the same name.\n","basic_transformations_ex5 = [\n"," ConvertToFloat(),\n"," grain.Batch(batch_size=64, drop_remainder=True) # drop_remainder for batching\n"," ]\n","print(\"DataSource and Transformations for Ex5 ready.\")\n","\n","# 2. Define shard_count\n","shard_count = 2\n","common_seed = 42\n","num_epochs_for_sharding_test = 1\n","\n","# 3. Simulate Process 0\n","shard_options_p0 = grain.ShardOptions(shard_index=0, shard_count=shard_count, drop_remainder=True) # drop_remainder for sharding\n","\n","# Sampler for Process 0. It's important that the sampler itself is sharded.\n","# The DataLoader's shard_options will also apply this sharding if the sampler isn't already sharded,\n","# or verify consistency if it is.\n","sampler_p0 = grain.IndexSampler(\n"," num_records=len(source_ex5),\n"," shard_options=shard_options_p0, # Shard the sampler\n"," shuffle=True, # Shuffle for more realistic scenario\n"," num_epochs=num_epochs_for_sharding_test,\n"," seed=common_seed\n"," )\n","\n","dl_p0 = grain.DataLoader(\n"," data_source=source_ex5,\n"," operations=basic_transformations_ex5,\n"," sampler=sampler_p0,\n"," worker_count=0, # Keep it simple for verification\n"," shard_options=shard_options_p0, # Also inform DataLoader about the sharding context\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","\n","indices_p0 = set()\n","if dl_p0:\n"," print(\"--- Simulating Process 0 ---\")\n"," for batch in dl_p0:\n"," indices_p0.update(batch['original_index'].tolist()) # Collect original indices\n"," print(f\"Process 0 collected {len(indices_p0)} unique indices.\")\n","\n","# 4. Simulate Process 1\n","shard_options_p1 = grain.ShardOptions(shard_index=1, shard_count=shard_count, drop_remainder=True)\n","sampler_p1 = grain.IndexSampler(\n"," num_records=len(source_ex5),\n"," shard_options=shard_options_p1, # Shard the sampler\n"," shuffle=True, # Use same shuffle setting and seed for apples-to-apples comparison of sharding logic\n"," num_epochs=num_epochs_for_sharding_test,\n"," seed=common_seed # Same seed ensures shuffle order is same before sharding\n"," )\n","\n","dl_p1 = grain.DataLoader(\n"," data_source=source_ex5,\n"," operations=basic_transformations_ex5,\n"," sampler=sampler_p1,\n"," worker_count=0,\n"," shard_options=shard_options_p1, # Inform DataLoader\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","\n","indices_p1 = set()\n","if dl_p1:\n"," print(\"\\n--- Simulating Process 1 ---\")\n"," for batch in dl_p1:\n"," indices_p1.update(batch['original_index'].tolist()) # Collect original indices\n"," print(f\"Process 1 collected {len(indices_p1)} unique indices.\")\n","\n","# 5. Verify\n","print(f\"\\n--- Verification of original_indices (Total records in source: {len(source_ex5)}) ---\")\n","# Showing a few from each for brevity\n","print(f\"Unique original_indices from P0 (first 20 sorted): {sorted(list(indices_p0))[:20]}...\")\n","print(f\"Unique original_indices from P1 (first 20 sorted): {sorted(list(indices_p1))[:20]}...\")\n","expected_per_shard = len(source_ex5) // shard_count # Due to drop_remainder=True in ShardOptions\n","print(f\"Expected records per shard (approx, due to drop_remainder in sharding): {expected_per_shard}\")\n","print(f\"Actual for P0: {len(indices_p0)}, P1: {len(indices_p1)}\")\n","\n","if indices_p0 and indices_p1:\n"," intersection = indices_p0.intersection(indices_p1)\n"," if not intersection:\n"," print(\"\\nSUCCESS: No overlap in original_indices between Process 0 and Process 1. Sharding works!\")\n"," else:\n"," print(f\"\\nFAILURE: Overlap in original_indices found (count {len(intersection)}): {sorted(list(intersection))[:20]}...\")\n"," print(\"This should not happen if sharding is correct and seeds/shuffle are consistent.\")\n","else:\n"," print(\"Could not perform intersection test as one or both sets of indices are empty.\")\n"," total_unique_indices_seen = len(indices_p0.union(indices_p1))\n"," print(f\"Total unique indices seen across both simulated processes: {total_unique_indices_seen}\")\n","\n","# With drop_remainder=True in sharding, total might be less than len(source_ex5)\n","# if len(source_ex5) is not divisible by shard_count.\n","# Example: 1000 records, 2 shards. Each gets 500. Total 1000.\n","# Example: 1001 records, 2 shards. drop_remainder=True means each gets 500. Total 1000. 1 record dropped.\n","print(\"\\nReminder: In a real distributed JAX application:\")\n","print(\"1. Each JAX process would run this script (or similar).\")\n","print(\"2. shard_options = grain.sharding.ShardByJaxProcess(drop_remainder=True) would be used.\")\n","print(\"3. jax.process_index() and jax.process_count() would provide the correct shard info automatically.\")\n","print(\"4. The IndexSampler and DataLoader would be configured with these auto-detected shard_options.\")"],"metadata":{"id":"J-BAVny4NRDH"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["\n","---\n","## Exercise 6: Integrating Grain with a JAX/Flax NNX (Conceptual) Loop\n","\n","**Goal:** Understand how a Grain `DataLoader` feeds data into a typical JAX/Flax NNX training loop. This exercise is conceptual regarding model training (no actual weight updates) but practical in terms of data flow.\n","\n","**Instructions:**\n","1. **Define a Simple Flax NNX Model:**\n"," * Create a class `SimpleNNXModel` inheriting from `nnx.Module`.\n"," * In `__init__`, initialize an `nnx.Linear` layer. The input features should match the flattened image dimensions (e.g., `32*32*3`), and output features can be `num_classes` (e.g., 10). Remember to pass `rngs` for parameter initialization.\n"," * Implement `__call__(self, x)`: it should flatten the input image `x` (if it's `B, H, W, C`) and pass it through the linear layer.\n","2. **Define a Conceptual `train_step`:**\n"," * This JAX function should be JIT-compiled (`@jax.jit`).\n"," * It takes the `model` (your `SimpleNNXModel` instance) and a `batch` from Grain.\n"," * Inside, it performs a forward pass: `logits = model(batch['image'])`.\n"," * It calculates a dummy loss, e.g., `loss = jnp.mean(logits)`. (No real loss computation or gradients needed for this exercise).\n"," * It returns the `loss` and the `model`. In a real training scenario using `nnx.Optimizer`, the optimizer would update the model's parameters in-place. The `train_step` function would typically return the `loss`, the updated `model`, and the updated `optimizer` state to be used in the next iteration.\n","3. **Set up DataLoader:**\n"," * Use `MySource` (the one that yields `{'image': ..., 'label': ...}`), an `IndexSampler` (e.g., for a few epochs), and `transformations` (e.g., `ConvertToFloat`, `grain.Batch`).\n"," * Instantiate a `grain.DataLoader`.\n","4. **Write the Training Loop:**\n"," * Initialize your `SimpleNNXModel` with an appropriate JAX PRNG key.\n"," * Get an iterator from your `DataLoader`.\n"," * Loop for a fixed number of steps (e.g., 100):\n"," * Get the `next_batch` from the iterator. Handle `StopIteration` if the loader is exhausted.\n"," * Call your `train_step` function with the current `model` and `next_batch`.\n"," * Print the dummy loss occasionally."],"metadata":{"id":"rFyUOM_PbAEV"}},{"cell_type":"code","source":["# @title Exercise 6: Student Code\n","# 1. Define SimpleNNXModel\n","class SimpleNNXModel(nnx.Module):\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," # TODO: Initialize an nnx.Linear layer\n"," # self.linear = nnx.Linear(...)\n"," # YOUR CODE HERE\n"," self.linear = None # Replace this\n"," def __call__(self, x: jax.Array):\n"," # TODO: Flatten the input image (if B, H, W, C) and pass through linear layer\n"," # x_flat = x.reshape((x.shape[0], -1))\n"," # return self.linear(x_flat)\n"," # YOUR CODE HERE\n"," return x # Replace this\n","\n","# 2. Define train_step\n","def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):\n"," # TODO: Perform forward pass: model(batch['image'])\n"," # logits = ...\n"," # YOUR CODE HERE\n"," logits = model(batch['image']) # Assuming model handles it\n"," # TODO: Calculate a dummy loss (e.g., mean of logits)\n"," # loss = ...\n"," # YOUR CODE HERE\n"," loss = jnp.array(0.0) # Replace this\n","\n"," # In a real scenario, you'd also compute gradients and update model parameters here.\n"," # For this exercise, we just return the loss and the original model.\n"," return loss, model\n","\n","# 3. Set up DataLoader\n","# TODO: Instantiate MySource (from Ex1, or the one with 'original_index' if you prefer)\n","# source_ex6 = ...\n","# YOUR CODE HERE\n","source_ex6 = None # Replace this\n","\n","# TODO: Instantiate an IndexSampler for a few epochs (e.g., 2 epochs)\n","# sampler_ex6 = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","sampler_ex6 = None # Replace this\n","\n","# TODO: Define transformations_ex6 (e.g., ConvertToFloat, grain.Batch)\n","# (Assuming ConvertToFloat is defined)\n","# transformations_ex6 = [...]\n","# YOUR CODE HERE\n","transformations_ex6 = [] # Replace this\n","\n","# TODO: Instantiate data_loader_ex6\n","# data_loader_ex6 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","data_loader_ex6 = None # Replace this\n","\n","# 4. Write the Training Loop\n","if data_loader_ex6: # Proceed only if DataLoader is configured\n"," # TODO: Initialize SimpleNNXModel\n"," # image_height, image_width, image_channels = 32, 32, 3\n"," # num_classes_ex6 = 10\n"," # model_key = jax.random.key(0)\n"," # model_ex6 = SimpleNNXModel(...)\n"," # YOUR CODE HERE\n"," model_ex6 = None # Replace this\n","\n","if model_ex6:\n"," # TODO: Get an iterator from data_loader_ex6\n"," # grain_iterator_ex6 = ...\n"," # YOUR CODE HERE\n"," grain_iterator_ex6 = iter([]) # Replace this\n","\n"," num_steps = 100\n"," print(f\"\\nStarting conceptual training loop for {num_steps} steps...\")\n"," for step in range(num_steps):\n"," try:\n"," # TODO: Get next_batch from iterator\n"," # next_batch = ...\n"," # YOUR CODE HERE\n"," next_batch = None # Replace this\n"," if next_batch is None:\n"," raise StopIteration # Simulate exhaustion if not implemented\n"," except StopIteration:\n"," print(f\"DataLoader exhausted at step {step}. Ending loop.\")\n"," break\n","\n"," # TODO: Call train_step\n"," # loss, model_ex6 = train_step(model_ex6, next_batch) # model_ex6 isn't actually updated here\n"," # YOUR CODE HERE\n"," loss = jnp.array(0.0) # Replace this\n","\n"," if step % 20 == 0 or step == num_steps - 1:\n"," print(f\"Step {step}: Dummy Loss = {loss.item():.4f}\")\n"," print(\"Conceptual training loop finished.\")\n"," else:\n"," print(\"Model for Ex6 not initialized.\")\n","else:\n"," print(\"DataLoader for Ex6 not configured.\")"],"metadata":{"id":"Ml0niZnwQSIF"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 6: Solution\n","# 1. Define SimpleNNXModel\n","class SimpleNNXModel(nnx.Module):\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," self.linear = nnx.Linear(din, dout, rngs=rngs)\n"," def __call__(self, x: jax.Array) -> jax.Array:\n"," # Assuming x is (B, H, W, C)\n"," batch_size = x.shape[0]\n"," x_flat = x.reshape((batch_size, -1)) # Flatten H, W, C dimensions\n"," return self.linear(x_flat)\n","\n","# 2. Define conceptual train_step\n","def train_step(model: SimpleNNXModel, batch: Dict[str, jax.Array]):\n"," # Perform forward pass\n"," logits = model(batch['image']) # model.call is invoked\n"," # Calculate a dummy loss\n"," loss = jnp.mean(logits**2) # Example: mean of squared logits\n","\n"," # In a real training step:\n"," # # 1. Define a loss function.\n"," # def loss_fn(model):\n"," # logits = model(batch['image'])\n"," # # loss_value = ... (e.g., optax.softmax_cross_entropy_with_integer_labels)\n"," # return jnp.mean(logits**2) # Using dummy loss from exercise\n"," #\n"," # # 2. Calculate gradients.\n"," # grads = nnx.grad(loss_fn, wrt=nnx.Param)(model)\n"," #\n"," # # 3. Update the model's parameters in-place using the optimizer.\n"," # # Note: The optimizer is defined outside the train step.\n"," # # e.g., optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)\n"," # optimizer.update(model, grads)\n"," #\n"," # # 4. Return the updated model and optimizer state.\n"," return model, optimizer, loss\n","\n","# 3. Set up DataLoader\n","# Redefine MySource for Ex6.\n","class MySource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n"," def __len__(self) -> int:\n"," return self._num_records\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10\n"," return {'image': image, 'label': label}\n","print(\"Redefined MySource for Ex6.\")\n","\n","source_ex6 = MySource(num_records=1000)\n","sampler_ex6 = grain.IndexSampler(\n"," num_records=len(source_ex6),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True,\n"," num_epochs=2, # Run for 2 epochs\n"," seed=42\n"," )\n","\n","# Redefine ConvertToFloat for Ex6.\n","class ConvertToFloat(grain.MapTransform):\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," updated_features = features.copy()\n"," updated_features['image'] = features['image'].astype(np.float32) / 255.0\n"," return updated_features\n","print(\"Redefined ConvertToFloat for Ex6.\")\n","\n","transformations_ex6 = [\n"," ConvertToFloat(),\n"," grain.Batch(batch_size=64, drop_remainder=True)\n"," ]\n","\n","data_loader_ex6 = grain.DataLoader(\n"," data_source=source_ex6,\n"," operations=transformations_ex6,\n"," sampler=sampler_ex6,\n"," worker_count=2, # Use a couple of workers\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","print(\"DataLoader for Ex6 configured.\")\n","\n","# 4. Write the Training Loop\n","# Define image dimensions and number of classes\n","image_height, image_width, image_channels = 32, 32, 3\n","input_dim = image_height * image_width * image_channels\n","num_classes_ex6 = 10\n","\n","# Initialize SimpleNNXModel\n","# NNX modules are typically initialized outside JIT, then their state can be passed.\n","# For this conceptual example, the model instance itself is passed.\n","model_key = jax.random.key(0)\n","model_ex6 = SimpleNNXModel(din=input_dim, dout=num_classes_ex6, rngs=nnx.Rngs(params=model_key))\n","print(f\"SimpleNNXModel initialized. Input dim: {input_dim}, Output dim: {num_classes_ex6}\")\n","\n","# Get an iterator from data_loader_ex6\n","grain_iterator_ex6 = iter(data_loader_ex6)\n","num_steps = 100 # Total steps for the conceptual loop\n","\n","print(f\"\\nStarting conceptual training loop for {num_steps} steps...\")\n","for step_idx in range(num_steps):\n"," try:\n"," # Get next_batch from iterator\n"," next_batch = next(grain_iterator_ex6)\n"," except StopIteration:\n"," print(f\"DataLoader exhausted at step {step_idx}. Ending loop.\")\n"," # Example: Re-initialize iterator if sampler allows multiple epochs\n"," # if sampler_ex6.num_epochs is None or sampler_ex6.num_epochs > 1 (and we tracked current epoch):\n"," # print(\"Re-initializing iterator for new epoch...\")\n"," # grain_iterator_ex6 = iter(data_loader_ex6)\n"," # next_batch = next(grain_iterator_ex6)\n"," # else:\n"," break # Exit loop if truly exhausted\n","\n"," # Call train_step\n"," # JAX arrays in batch are automatically handled by jax.jit\n"," _, loss = train_step(model_ex6, next_batch) # model_ex6 state isn't actually updated here\n","\n"," if step_idx % 20 == 0 or step_idx == num_steps - 1:\n"," print(f\"Step {step_idx}: Dummy Loss = {loss.item():.4f}\") # .item() to get Python scalar from JAX array\n","print(\"Conceptual training loop finished.\")"],"metadata":{"id":"qNHGUhUbdAgF"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["---\n","## Exercise 7: Checkpointing and Resuming Data Iteration\n","\n","**Goal:** Understand how to save and restore the state of a Grain data iterator for reproducible experiments, especially when resuming long training runs.\n","\n","**Background:**\n","Grain's `DataLoader` produces an iterator when you call `iter(data_loader)`. This iterator (`grain.PyGrainDatasetIterator`) has `get_state()` and `set_state()` methods. These allow you to capture the internal state of the iteration (e.g., current position, RNG states for samplers/transforms) and restore it later. For full experiment checkpointing, this data iterator state should be saved alongside your model parameters (often using a library like Orbax).\n","\n","**Instructions:**\n","1. Set up a `DataLoader` (e.g., using `MySource`, an `IndexSampler` with `num_epochs=None` for indefinite iteration and a seed, and some basic `transformations`).\n","2. Get an iterator (`iterator1`) from this `DataLoader`.\n","3. Iterate a few times (e.g., 3 batches) using `next(iterator1)` and store the last batch obtained.\n","4. **Save State:** Call `saved_iterator_state = iterator1.get_state()`.\n","5. **Simulate Resumption:**\n"," * Get a *new* iterator (`iterator2`) from the *same* `DataLoader` instance.\n"," * **Restore State:** Call `iterator2.set_state(saved_iterator_state)`.\n","6. Iterate once using `next(iterator2)` to get a batch (`resumed_batch`).\n","7. **Verify:**\n"," * The `resumed_batch` obtained from `iterator2` should be the *same* as the batch that *would have come after* the last batch from `iterator1`.\n"," * To verify this:\n"," * After getting `saved_iterator_state` from `iterator1`, call `next(iterator1)` one more time to get the `expected_next_batch_from_iterator1`.\n"," * Compare `resumed_batch` (from `iterator2` after `set_state`) with `expected_next_batch_from_iterator1`. Their contents (e.g., image data) should be identical."],"metadata":{"id":"cr7s9TH_h0B-"}},{"cell_type":"code","source":["# @title Exercise 7: Student Code\n","# 1. Set up DataLoader\n","# TODO: Instantiate MySource (e.g., from Ex1)\n","# source_ex7 = ...\n","# YOUR CODE HERE\n","source_ex7 = None # Replace this\n","\n","# TODO: Instantiate an IndexSampler (num_epochs=None, seed=42)\n","# sampler_ex7 = grain.IndexSampler(...)\n","# YOUR CODE HERE\n","sampler_ex7 = None # Replace this\n","\n","# TODO: Define transformations_ex7 (e.g., ConvertToFloat, Batch)\n","# (Assuming ConvertToFloat is defined)\n","# transformations_ex7 = [...]\n","# YOUR CODE HERE\n","transformations_ex7 = [] # Replace this\n","\n","# TODO: Instantiate data_loader_ex7\n","# data_loader_ex7 = grain.DataLoader(...)\n","# YOUR CODE HERE\n","data_loader_ex7 = None # Replace this\n","\n","if data_loader_ex7:\n"," # 2. Get iterator1\n"," # TODO: iterator1 = iter(...)\n"," # YOUR CODE HERE\n"," iterator1 = iter([]) # Replace this\n","\n","# 3. Iterate a few times\n","num_initial_iterations = 3\n","print(f\"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---\")\n","last_batch_iterator1 = None\n","for i in range(num_initial_iterations):\n"," try:\n"," # TODO: last_batch_iterator1 = next(...)\n"," # YOUR CODE HERE\n"," last_batch_iterator1 = {} # Replace this\n"," print(f\"iterator1, batch {i+1} - first label: {last_batch_iterator1.get('label', [None])[0]}\")\n"," except StopIteration:\n"," print(\"iterator1 exhausted prematurely.\")\n"," break\n","\n","# 4. Save State\n","# TODO: saved_iterator_state = iterator1.get_state()\n","# YOUR CODE HERE\n","saved_iterator_state = None # Replace this\n","print(f\"\\nIterator state saved. Type: {type(saved_iterator_state)}\")\n","\n","# For verification: get the *next* batch from iterator1 *after* saving state\n","expected_next_batch_from_iterator1 = None\n","if saved_iterator_state is not None: # Ensure state was actually saved\n"," try:\n"," # TODO: expected_next_batch_from_iterator1 = next(...)\n"," # YOUR CODE HERE\n"," expected_next_batch_from_iterator1 = {} # Replace this\n"," print(f\"Expected next batch (from iterator1 after get_state) - first label: {expected_next_batch_from_iterator1.get('label', [None])[0]}\")\n"," except StopIteration:\n"," print(\"iterator1 exhausted when trying to get expected_next_batch.\")\n","\n","# 5. Simulate Resumption\n","# TODO: Get iterator2 from the same data_loader_ex7\n","# iterator2 = iter(...)\n","# YOUR CODE HERE\n","iterator2 = iter([]) # Replace this\n","\n","if saved_iterator_state is not None:\n"," # TODO: iterator2.set_state(...)\n"," # YOUR CODE HERE\n"," print(\"\\n--- Resumed Iteration (iterator2) ---\")\n"," print(\"Iterator state restored to iterator2.\")\n","else:\n"," print(\"\\nSkipping resumption, saved_iterator_state is None.\")\n","\n","# 6. Iterate once from iterator2\n","resumed_batch = None\n","if saved_iterator_state is not None: # Only if state was set\n"," try:\n"," # TODO: resumed_batch = next(...)\n"," # YOUR CODE HERE\n"," resumed_batch = {} # Replace this\n"," print(f\"Resumed batch (from iterator2 after set_state) - first label: {resumed_batch.get('label', [None])[0]}\")\n"," except StopIteration:\n"," print(\"iterator2 exhausted immediately after set_state.\")\n","\n","# 7. Verify\n","if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:\n"," # Compare 'image' data of the first element in the batch\n"," # TODO: Perform comparison (e.g., np.allclose on image data)\n"," # are_identical = np.allclose(...)\n"," # YOUR CODE HERE\n"," are_identical = False # Replace this\n","\n"," if are_identical:\n"," print(\"\\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!\")\n"," else:\n"," print(\"\\nFAILURE: Resumed batch differs from the expected next batch.\")\n"," # print(f\"Expected image data (sample): {expected_next_batch_from_iterator1['image'][0,0,0,:3]}\")\n"," # print(f\"Resumed image data (sample): {resumed_batch['image'][0,0,0,:3]}\")\n","elif saved_iterator_state is not None:\n"," print(\"\\nVerification inconclusive: could not obtain both expected and resumed batches.\")\n","else:\n"," print(\"DataLoader for Ex7 not configured.\")"],"metadata":{"id":"M-y5f4x-gzmA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 7: Solution\n","# 1. Set up DataLoader\n","# Redefine MySource for Ex7.\n","class MySource(grain.RandomAccessDataSource):\n"," def __init__(self, num_records: int = 1000):\n"," self._num_records = num_records\n"," def __len__(self) -> int:\n"," return self._num_records\n"," def __getitem__(self, idx: int) -> Dict[str, Any]:\n"," effective_idx = idx % self._num_records\n"," image = np.ones((32, 32, 3), dtype=np.uint8) * (effective_idx % 255)\n"," label = effective_idx % 10\n"," # Add original_index for easier verification if needed, though label might suffice\n"," return {'image': image, 'label': label, 'original_index': effective_idx}\n","print(\"Redefined MySource for Ex7.\")\n","\n","source_ex7 = MySource(num_records=1000)\n","sampler_ex7 = grain.IndexSampler(\n"," num_records=len(source_ex7),\n"," shard_options=grain.NoSharding(),\n"," shuffle=True, # Shuffling makes the test more robust\n"," num_epochs=None, # Indefinite iteration\n"," seed=42\n"," )\n","\n","# Redefine ConvertToFloat for Ex7.\n","class ConvertToFloat(grain.MapTransform):\n"," def map(self, features: Dict[str, Any]) -> Dict[str, Any]:\n"," updated_features = features.copy()\n"," updated_features['image'] = features['image'].astype(np.float32) / 255.0\n"," return updated_features\n","print(\"Redefined ConvertToFloat for Ex7.\")\n","\n","transformations_ex7 = [\n"," ConvertToFloat(),\n"," grain.Batch(batch_size=64, drop_remainder=True)\n"," ]\n","\n","data_loader_ex7 = grain.DataLoader(\n"," data_source=source_ex7,\n"," operations=transformations_ex7,\n"," sampler=sampler_ex7,\n"," worker_count=0, # Simpler for state verification, but works with >0 too\n"," shard_options=grain.NoSharding(),\n"," read_options=grain.ReadOptions(num_threads=0)\n"," )\n","print(\"DataLoader for Ex7 configured.\")\n","\n","# 2. Get iterator1\n","iterator1 = iter(data_loader_ex7)\n","\n","# 3. Iterate a few times\n","num_initial_iterations = 3\n","print(f\"--- Initial Iteration (iterator1) for {num_initial_iterations} batches ---\")\n","last_batch_iterator1 = None\n","for i in range(num_initial_iterations):\n"," try:\n"," last_batch_iterator1 = next(iterator1)\n"," # Using 'original_index' for more robust check than just 'label'\n"," print(f\"iterator1, batch {i+1} - first original_index: {last_batch_iterator1['original_index'][0]}\")\n"," except StopIteration:\n"," print(\"iterator1 exhausted prematurely.\")\n"," break\n","\n","# 4. Save State\n","# Make a deep copy if you plan to continue using iterator1 and don't want\n","# its state object to be modified if Python passes by reference internally (usually not an issue for simple state).\n","# For PyGrainDatasetIterator, get_state() returns a new state object.\n","saved_iterator_state = iterator1.get_state()\n","print(f\"\\nIterator state saved. Type: {type(saved_iterator_state)}\")\n","\n","# For verification: get the next batch from iterator1 after saving state\n","expected_next_batch_from_iterator1 = None\n","if saved_iterator_state is not None:\n"," try:\n"," expected_next_batch_from_iterator1 = next(iterator1)\n"," print(f\"Expected next batch (from iterator1 after get_state) - first original_index: {expected_next_batch_from_iterator1['original_index'][0]}\")\n"," except StopIteration:\n"," print(\"iterator1 exhausted when trying to get expected_next_batch.\")\n","\n","# 5. Simulate Resumption\n","# Get a new iterator from the same DataLoader instance.\n","iterator2 = iter(data_loader_ex7)\n","if saved_iterator_state is not None:\n"," iterator2.set_state(saved_iterator_state)\n"," print(\"\\n--- Resumed Iteration (iterator2) ---\")\n"," print(\"Iterator state restored to iterator2.\")\n","else:\n"," print(\"\\nSkipping resumption, saved_iterator_state is None.\")\n","\n","# 6. Iterate once from iterator2\n","resumed_batch = None\n","if saved_iterator_state is not None:\n"," try:\n"," resumed_batch = next(iterator2)\n"," print(f\"Resumed batch (from iterator2 after set_state) - first original_index: {resumed_batch['original_index'][0]}\")\n"," except StopIteration:\n"," print(\"iterator2 exhausted immediately after set_state. This means the saved state was at the very end.\")\n","\n","# 7. Verify\n","if expected_next_batch_from_iterator1 is not None and resumed_batch is not None:\n"," # Compare 'image' data and 'original_index' of the first element in the batch for robustness\n"," # (Labels might repeat, indices are better for this check if available)\n"," expected_img_sample = expected_next_batch_from_iterator1['image'][0]\n"," resumed_img_sample = resumed_batch['image'][0]\n"," expected_idx_sample = expected_next_batch_from_iterator1['original_index'][0]\n"," resumed_idx_sample = resumed_batch['original_index'][0]\n"," are_indices_identical = (expected_idx_sample == resumed_idx_sample)\n"," are_images_identical = np.allclose(expected_img_sample, resumed_img_sample)\n","\n"," are_identical = are_indices_identical and are_images_identical\n","\n"," if are_identical:\n"," print(\"\\nSUCCESS: Resumed batch is identical to the expected next batch. Checkpointing works!\")\n"," else:\n"," print(\"\\nFAILURE: Resumed batch differs from the expected next batch.\")\n"," if not are_indices_identical:\n"," print(f\" - Mismatch in first original_index: Expected {expected_idx_sample}, Got {resumed_idx_sample}\")\n"," if not are_images_identical:\n"," print(f\" - Mismatch in image data for first element.\")\n"," # print(f\" Expected image data (sample [0,0,0]): {expected_img_sample[0,0,0]}\")\n"," # print(f\" Resumed image data (sample [0,0,0]): {resumed_img_sample[0,0,0]}\")\n","elif saved_iterator_state is not None: # If state was saved but verification couldn't complete\n"," if expected_next_batch_from_iterator1 is None and resumed_batch is None:\n"," print(\"\\nVERIFICATION NOTE: Both iterators seem to be at the end of the dataset after the initial iterations. This is valid if the dataset was short.\")\n"," else:\n"," print(\"\\nVerification inconclusive: could not obtain both expected and resumed batches for comparison.\")\n"," print(f\" expected_next_batch_from_iterator1 is None: {expected_next_batch_from_iterator1 is None}\")\n"," print(f\" resumed_batch is None: {resumed_batch is None}\")\n","else: # If DataLoader itself wasn't configured\n"," print(\"DataLoader for Ex7 not configured.\")"],"metadata":{"id":"vVcRIZdujpFW"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["\n","---\n","## Conclusion and Further Exploration\n","\n","Congratulations on completing the exercises! You should now have a good understanding of:\n","* The fundamental components of Grain (`DataSource`, `Sampler`, `Operations`).\n","* How to construct and use `grain.DataLoader` for efficient data input, including parallel loading.\n","* Implementing custom deterministic and random transformations.\n","* The basics of data sharding for distributed training.\n","* How Grain iterators fit into a JAX/Flax NNX training loop.\n","* Saving and restoring data iterator state for reproducibility.\n","\n","**Key Takeaways (Recap from Slides):**\n","* **Use Grain:** Solves JAX data bottlenecks for better performance.\n","* **Boost Speed:** Use `DataLoader(worker_count > 0)` for parallelism.\n","* **Ensure Reproducibility:** Use samplers/seeds & `RandomMapTransform`'s provided `rng`.\n","* **Distribute:** Use `grain.sharding.ShardByJaxProcess` (or manual `ShardOptions`) for JAX sharding.\n","* **Save Everything:** Checkpoint data iterator state (e.g., via Orbax for comprehensive checkpointing) along with your model state.\n","\n","**Further Exploration:**\n","* **Orbax Integration:** For robust checkpointing in real-world projects, explore integrating Grain with [Orbax](https://github.com/google/orbax). Orbax can manage saving and loading your Grain iterator state alongside your Flax model parameters and optimizer states atomically. **Note:** If you are migrating a project from NNX v0.10, be aware that the checkpoint structure for models with RNGs (like `Dropout` or `BatchNorm`) has changed. You will need to use a migration script to update old checkpoints to the v0.11 format, as described in the [official migration guide](https://flax.readthedocs.io/en/latest/migrating/nnx_010_to_nnx_011.html).\n","* **Different Data Sources:** Explore reading from various on-disk formats (e.g., TFRecord, RecordIO) using appropriate `DataSource` implementations or by integrating with libraries like TFDS.\n","* **Performance Profiling:** Use JAX's profiling tools to identify and optimize data loading bottlenecks in more complex scenarios.\n","\n","We hope these exercises have been helpful in your journey with JAX and Grain!"],"metadata":{"id":"yDU9idjEmcRD"}}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/8 - Orbax for checkpointing.ipynb b/docs/learning_jax/code-exercises/8 - Orbax for checkpointing.ipynb new file mode 100644 index 0000000..f3cfa88 --- /dev/null +++ b/docs/learning_jax/code-exercises/8 - Orbax for checkpointing.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"13nlRxnUp2Y8sLDpqVQDjMSs_soWB-JWQ","timestamp":1755114004925}],"toc_visible":true,"authorship_tag":"ABX9TyO1TCkH6hEPMVvV1U3DTETQ"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Introduction to Checkpointing with Flax NNX and Orbax\n","\n","Welcome to this hands-on exercise! We'll explore how to save and load your JAX/Flax NNX models, a crucial skill for any serious machine learning project.\n","\n","## Why Checkpoint?\n","Training deep learning models can take a long time. Checkpointing allows you to:\n","\n","* Save your progress (model parameters, optimizer state) to resume training later if it gets interrupted.\n","* Preserve model states at different stages for analysis or inference.\n","* Implement fault tolerance in long training runs.\n","\n","## Flax NNX: A Quick Recap\n","\n","* **Stateful Modules**: NNX modules are Python classes that directly hold their own state (like parameters) as attributes. This often feels more intuitive, especially if you're coming from PyTorch.\n","* `nnx.Module`: The base class for creating these stateful components.\n","* `nnx.Variable`: Special types like nnx.Param and nnx.BatchStat are used to define learnable parameters and other stateful variables within an nnx.Module.\n","* `nnx.State`: A JAX Pytree (like a nested dictionary) that holds all the nnx.Variable values from a module. This is what Orbax saves and restores.\n","\n","## The Functional Bridge:\n","\n","* `nnx.split(module)`: Separates a module into its static structure (GraphDef) and its dynamic state (nnx.State). This is key for getting the state to save.\n","* `nnx.merge(graphdef, state)`: Reconstructs a module instance from its GraphDef and nnx.State. Used after restoring.\n","* `nnx.update(module, state)`: Updates an existing module's state in-place. Also used after restoring.\n","\n","## Orbax: The JAX Checkpointing Library\n","\n","Orbax is the standard library for checkpointing in JAX, designed to be robust and scalable.\n","\n","* `ocp.CheckpointManager`: A high-level utility that simplifies managing multiple checkpoints over a training run (e.g., keeping the last N checkpoints, handling versions). We'll be using this extensively.\n","* `ocp.args`: Namespace for specifying how to save/restore different parts of your state (e.g., ocp.args.StandardSave, ocp.args.StandardRestore, ocp.args.Composite).\n","\n","Let's get started!"],"metadata":{"id":"LCP8B52UdhGP"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"0dHcdEQETEfy"},"outputs":[],"source":["# @title Setup: Install and Import Libraries\n","# Install necessary libraries\n","!pip install -Uq flax orbax-checkpoint chex optax\n","\n","import jax\n","import jax.numpy as jnp\n","from jax.sharding import Mesh, PartitionSpec, NamedSharding\n","from jax.experimental import mesh_utils\n","import flax\n","from flax import nnx\n","import orbax.checkpoint as ocp\n","import optax\n","import os\n","import shutil # For cleaning up directories\n","import chex # For faking devices\n","\n","# Suppress some JAX warnings for cleaner output in the notebook\n","import warnings\n","warnings.filterwarnings(\"ignore\", message=\"No GPU/TPU found, falling back to CPU.\")\n","warnings.filterwarnings(\"ignore\", message=\"Custom node type GlobalDeviceArray is not handled by Pytree traversal.\") # Orbax/NNX interactions\n","\n","print(f\"JAX version: {jax.__version__}\")\n","print(f\"Flax version: {flax.__version__}\")\n","print(f\"Orbax version: {ocp.__version__}\")\n","print(f\"Optax version: {optax.__version__}\")\n","print(f\"Chex version: {chex.__version__}\")\n","\n","# --- Setup for Distributed Exercises ---\n","# Simulate an environment with 8 CPUs for distributed examples\n","# This allows us to test sharding logic even on a single-CPU Colab machine.\n","try:\n"," chex.set_n_cpu_devices(8)\n","except RuntimeError as e:\n"," print(f\"Note: Could not set_n_cpu_devices (may have been set already): {e}\")\n","\n","print(f\"Number of JAX devices available: {jax.device_count()}\")\n","print(f\"Available devices: {jax.devices()}\")\n","\n","# Helper function to clean up checkpoint directories\n","def cleanup_ckpt_dir(ckpt_dir):\n"," if os.path.exists(ckpt_dir):\n"," shutil.rmtree(ckpt_dir)\n"," print(f\"Cleaned up checkpoint directory: {ckpt_dir}\")\n","\n","# Create a default checkpoint directory for exercises\n","CKPT_BASE_DIR = '/tmp/nnx_orbax_workshop_checkpoints'\n","if not os.path.exists(CKPT_BASE_DIR):\n"," os.makedirs(CKPT_BASE_DIR)\n","\n","print(f\"Base checkpoint directory: {CKPT_BASE_DIR}\")"]},{"cell_type":"markdown","source":["## Exercise 1: Basic Checkpointing - Saving nnx.State\n","\n","**Goal**: Learn to save the state of a simple Flax NNX module using Orbax.\n","\n","### Topics:\n","\n","* Defining an nnx.Module.\n","* Instantiating an nnx.Module with initial parameters.\n","* Using nnx.split() to extract the nnx.State Pytree.\n","* Setting up ocp.CheckpointManager.\n","* Saving the state using mngr.save() with ocp.args.StandardSave.\n","\n","### Instructions:\n","\n","1. Define a simple linear layer SimpleLinear that inherits from nnx.Module.\n"," - In its __init__, define a weight matrix and a bias vector as nnx.Param attributes. Initialize them with JAX random functions (e.g., jax.random.uniform for weights, jnp.zeros for bias). Remember nnx.Rngs for key management!\n"," - Implement the __call__ method for the forward pass: y = x @ weight + bias.\n","2. Instantiate this SimpleLinear module.\n","3. Specify a directory for saving checkpoints.\n","4. Create an ocp.CheckpointManagerOptions object to configure checkpointing (e.g., max_to_keep=3).\n","5. Instantiate ocp.CheckpointManager with the directory and options.\n","6. Use nnx.split(model) to get the graphdef and the state_to_save.\n","7. Save the state_to_save at a specific training step (e.g., step 100) using mngr.save(). You'll need to wrap state_to_save with ocp.args.StandardSave().\n","8. Call mngr.wait_until_finished() to ensure the save operation completes (important if saving is asynchronous).\n","9. Close the manager using mngr.close()."],"metadata":{"id":"H__qBFj1eyVK"}},{"cell_type":"code","source":["# --- Define the NNX Module ---\n","class SimpleLinear(nnx.Module):\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," key_w, key_b = rngs.params(), rngs.params() # Example of splitting keys if needed, or use one key for multiple params\n"," # TODO: Define self.weight as an nnx.Param with shape (din, dout)\n"," # self.weight = ...\n"," # TODO: Define self.bias as an nnx.Param with shape (dout,)\n"," # self.bias = ...\n","\n"," def __call__(self, x: jax.Array) -> jax.Array:\n"," # TODO: Implement the forward pass\n"," # return ...\n","\n","# --- Instantiate the Model ---\n","din, dout = 10, 5\n","# TODO: Create an nnx.Rngs object for parameter initialization\n","# rngs = ...\n","# TODO: Instantiate SimpleLinear\n","# model = ...\n","\n","print(f\"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}\")\n","\n","# --- Setup CheckpointManager ---\n","ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')\n","cleanup_ckpt_dir(ckpt_dir_ex1) # Clean up from previous runs\n","\n","# TODO: Create CheckpointManagerOptions\n","# options = ...\n","# TODO: Instantiate CheckpointManager\n","# mngr = ...\n","\n","# --- Split the model to get the state ---\n","# TODO: Split the model into graphdef and state_to_save\n","# _graphdef, state_to_save = ...\n","# Alternatively, for just the state: state_to_save = nnx.state(model)\n","# print(f\"State to save: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, state_to_save)}\")\n","\n","# --- Save the state ---\n","step = 100\n","# TODO: Save the state_to_save at the given step. Use ocp.args.StandardSave.\n","# mngr.save(...)\n","# TODO: Wait for saving to complete\n","# mngr.wait_until_finished()\n","\n","print(f\"Checkpoint saved for step {step} in {ckpt_dir_ex1}.\")\n","print(f\"Available checkpoints: {mngr.all_steps()}\")\n","\n","# TODO: Close the manager\n","# mngr.close()"],"metadata":{"id":"5jXjWVyVdF5G"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 1: Solution\n","# --- Define the NNX Module ---\n","class SimpleLinear(nnx.Module):\n"," def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n"," # Parameters defined using nnx.Param (a type of nnx.Variable)\n"," self.weight = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))\n"," self.bias = nnx.Param(jnp.zeros((dout,)))\n","\n"," def __call__(self, x: jax.Array) -> jax.Array:\n"," # Parameters used directly via self.weight, self.bias\n"," return x @ self.weight.value + self.bias.value\n","\n","# --- Instantiate the Model ---\n","din, dout = 10, 5\n","rngs = nnx.Rngs(params=jax.random.key(0)) # NNX requires explicit RNG management\n","model = SimpleLinear(din=din, dout=dout, rngs=rngs)\n","\n","print(f\"Model created. Weight shape: {model.weight.value.shape}, Bias shape: {model.bias.value.shape}\")\n","\n","# --- Setup CheckpointManager ---\n","ckpt_dir_ex1 = os.path.join(CKPT_BASE_DIR, 'ex1_basic_save')\n","cleanup_ckpt_dir(ckpt_dir_ex1)\n","\n","options = ocp.CheckpointManagerOptions(max_to_keep=3, save_interval_steps=1)\n","mngr = ocp.CheckpointManager(ckpt_dir_ex1, options=options)\n","\n","# --- Split the model to get the state ---\n","_graphdef, state_to_save = nnx.split(model)\n","# Alternatively: state_to_save = nnx.state(model)\n","print(f\"State to save structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), state_to_save)}\")\n","\n","# --- Save the state ---\n","step = 100\n","mngr.save(step, args=ocp.args.StandardSave(state_to_save))\n","mngr.wait_until_finished() # Ensure save completes if async\n","\n","print(f\"Checkpoint saved for step {step} in {ckpt_dir_ex1}.\")\n","print(f\"Available checkpoints: {mngr.all_steps()}\")\n","\n","mngr.close() # Clean up resources"],"metadata":{"id":"N72kLQJff3Ex"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 2: Basic Checkpointing - Restoring nnx.State\n","\n","**Goal**: Learn to restore a model's state from a checkpoint using Orbax.\n","\n","###Topics:\n","\n","* Using nnx.eval_shape() to create an \"abstract\" model template.\n","* Splitting the abstract model to get an abstract_state (a Pytree of ShapeDtypeStruct objects).\n","* Restoring the state using mngr.restore() with the abstract_state and ocp.args.StandardRestore.\n","* Reconstructing the model using nnx.merge() with the original graphdef and the restored_state.\n","* Alternatively, updating an existing model instance with nnx.update().\n","\n","### Instructions:\n","\n","1. Re-open the CheckpointManager pointing to the directory from Exercise 1 (ckpt_dir_ex1).\n","2. Define a function create_abstract_model() that instantiates your SimpleLinear module. This function will be passed to nnx.eval_shape().\n"," - Inside this function, use dummy RNG keys and input shapes as nnx.eval_shape only cares about the structure and dtypes, not actual values.\n","3. Create an abstract_model by calling abstract_model = nnx.eval_shape(create_abstract_model).\n","4. Split the abstract_model using graphdef_for_restore, abstract_state = nnx.split(abstract_model). The abstract_state now contains ShapeDtypeStruct leaves, which Orbax uses as a template for restoration.\n","5. Find the latest checkpoint step using mngr.latest_step().\n","6. If a checkpoint exists, restore the state using mngr.restore(step_to_restore, args=ocp.args.StandardRestore(abstract_state)).\n","7. Reconstruct the model using restored_model = nnx.merge(graphdef_for_restore, restored_state).\n","8. (Optional) Print a value from the restored model (e.g., restored_model.bias.value) to verify.\n","9. Close the manager."],"metadata":{"id":"kdhOVjSOgTKy"}},{"cell_type":"code","source":["# Ensure the SimpleLinear class definition from Exercise 1 is available\n","\n","# --- Re-open CheckpointManager ---\n","# TODO: Instantiate CheckpointManager for ckpt_dir_ex1 (no need for options if just restoring)\n","# mngr_restore = ...\n","\n","# --- Create Abstract Model for Restoration ---\n","def create_abstract_model():\n"," # Use dummy RNG key/inputs for abstract creation\n"," # TODO: Return an instance of SimpleLinear, same din/dout as before\n"," # return ...\n","\n","# TODO: Create the abstract_model using nnx.eval_shape\n","# abstract_model = ...\n","\n","# --- Split Abstract Model to get Abstract State Structure ---\n","# TODO: Split the abstract_model to get graphdef_for_restore and abstract_state\n","# graphdef_for_restore, abstract_state = ...\n","print(f\"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}\")\n","\n","\n","# --- Restore the State ---\n","# TODO: Get the latest step to restore\n","# step_to_restore = ...\n","\n","if step_to_restore is not None:\n"," # TODO: Restore the state using mngr_restore.restore() and ocp.args.StandardRestore with abstract_state\n"," # restored_state = mngr_restore.restore(...)\n","\n"," # --- Reconstruct the Model ---\n"," # TODO: Reconstruct the model using nnx.merge with graphdef_for_restore and restored_state\n"," # restored_model = ...\n"," print(f\"Model restored from step {step_to_restore}.\")\n"," # You can now use 'restored_model'\n"," print(f\"Restored bias (first 3 values): {restored_model.bias.value[:3]}\")\n","\n"," # Alternative: Update an existing model instance\n"," # model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance\n"," # nnx.update(model_to_update, restored_state)\n"," # print(f\"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}\")\n","else:\n"," print(\"No checkpoint found to restore.\")\n","\n","# TODO: Close the manager\n","# mngr_restore.close()"],"metadata":{"id":"CugEuqYCgB-5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 2: Solution\n","\n","# Ensure the SimpleLinear class definition from Exercise 1 is available\n","\n","# --- Re-open CheckpointManager ---\n","mngr_restore = ocp.CheckpointManager(ckpt_dir_ex1) # Re-open manager\n","\n","# --- Create Abstract Model for Restoration ---\n","def create_abstract_model():\n"," # Use dummy RNG key/inputs for abstract creation\n"," return SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(0))) # din, dout from Ex1\n","\n","abstract_model = nnx.eval_shape(create_abstract_model)\n","\n","# --- Split Abstract Model to get Abstract State Structure ---\n","graphdef_for_restore, abstract_state = nnx.split(abstract_model)\n","# abstract_state now contains ShapeDtypeStruct leaves\n","print(f\"Abstract state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else x, abstract_state)}\")\n","\n","# --- Restore the State ---\n","step_to_restore = mngr_restore.latest_step()\n","\n","if step_to_restore is not None:\n"," restored_state = mngr_restore.restore(step_to_restore,\n"," args=ocp.args.StandardRestore(abstract_state))\n"," print(f\"Restored state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), restored_state)}\")\n","\n"," # --- Reconstruct the Model ---\n"," restored_model = nnx.merge(graphdef_for_restore, restored_state)\n"," print(f\"Model restored from step {step_to_restore}.\")\n"," # You can now use 'restored_model'\n"," print(f\"Restored bias (first 3 values): {restored_model.bias.value[:3]}\")\n","\n"," # Compare with original model's bias (optional, if 'model' from Ex1 is still in scope)\n"," # print(f\"Original bias (first 3 values): {model.bias.value[:3]}\")\n"," # chex.assert_trees_all_close(restored_model.bias.value, model.bias.value)\n","\n"," # Alternative: Update an existing model instance\n"," model_to_update = SimpleLinear(din=din, dout=dout, rngs=nnx.Rngs(params=jax.random.key(99))) # Fresh instance\n"," # Initialize with different values to see update working\n"," model_to_update.bias.value = jnp.ones_like(model_to_update.bias.value) * 55.0\n"," print(f\"Bias before update: {model_to_update.bias.value[:3]}\")\n"," nnx.update(model_to_update, restored_state)\n"," print(f\"Updated model bias (first 3 values): {model_to_update.bias.value[:3]}\")\n"," if 'model' in globals(): # Check if original model exists\n"," chex.assert_trees_all_close(model_to_update.bias.value, model.bias.value)\n","else:\n"," print(\"No checkpoint found to restore.\")\n","\n","mngr_restore.close()"],"metadata":{"id":"QJz_4Pm9g-AG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 3: Saving Model Parameters and Optimizer State\n","\n","**Goal**: Learn to save both model parameters and optimizer state together in a single checkpoint.\n","\n","### Topics:\n","\n","* Using nnx.Optimizer to manage model parameters and an Optax optimizer state.\n","* Extracting model parameters (e.g., using nnx.split(model, nnx.Param)).\n","* Extracting the full optimizer state (nnx.state(optimizer)).\n","* Using ocp.args.Composite to save multiple named items (model params, optimizer state) in one checkpoint.\n","\n","### Instructions:\n","\n","1. Reuse the SimpleLinear module definition. Instantiate a new SimpleLinear model.\n","2. Create an Optax optimizer (e.g., optax.adam(learning_rate=1e-3)).\n","3. Wrap the model and the Optax optimizer with nnx.Optimizer.\n","4. (Optional) Simulate a few training steps to update the optimizer's internal state (e.g., momentum). You don't need actual data; just update the step count and imagine gradients were applied.\n"," - Access optimizer step via optimizer.step.value. Update it: optimizer.step.value += 1.\n","5. Set up a new CheckpointManager in a new directory (ckpt_dir_ex3).\n","6. Extract the model's parameters: _graphdef_params, params_state = nnx.split(model_ex3, nnx.Param). Note that the optimizer.model attribute has been removed, so we split the original model variable directly.\n","7. Extract the full optimizer state: optimizer_state_tree = nnx.state(optimizer). This includes optimizer internal states (like momentum) and its own step count.\n","8. Define a dictionary save_items where keys are names (e.g., 'params', 'optimizer') and values are ocp.args.StandardSave() wrapped Pytrees (i.e., params_state and optimizer_state_tree).\n","9. Save these items using mngr.save(step, args=ocp.args.Composite(**save_items)). Use the optimizer's current step.\n","10. Wait and close the manager."],"metadata":{"id":"9ZhJATu8hkYw"}},{"cell_type":"code","source":["# Ensure SimpleLinear class definition is available\n","# --- Instantiate Model and Optimizer ---\n","rngs_ex3 = nnx.Rngs(params=jax.random.key(1))\n","model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)\n","\n","# TODO: Create an Optax optimizer (e.g., Adam)\n","# tx = ...\n","# TODO: Create an nnx.Optimizer, wrapping the model and tx\n","# optimizer = ...\n","\n","# Simulate a few \"training\" steps to populate optimizer state\n","# For a real scenario, this would involve gradients and updates\n","if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'): # Check for NNX Optimizer structure\n"," optimizer.step.value += 10 # Simulate 10 steps\n"," # In a real loop: optimizer.update_fn(grads, optimizer.state) -> optimizer.state would be updated\n"," # For this exercise, just advancing step is enough to see it saved/restored.\n"," # Let's also change a parameter slightly to see it saved\n"," original_bias_val_ex3 = model_ex3.bias.value.copy()\n"," model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1\n"," print(f\"Optimizer step: {optimizer.step.value}\")\n"," print(f\"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}\")\n","else:\n"," print(\"Skipping optimizer step update as structure might differ from expected nnx.Optimizer.\")\n","\n","\n","# --- Setup CheckpointManager for Composite Save ---\n","ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')\n","cleanup_ckpt_dir(ckpt_dir_ex3)\n","# TODO: Instantiate CheckpointManager for ckpt_dir_ex3\n","# mngr_comp = ...\n","\n","# --- Extract States for Saving ---\n","# TODO: Extract model parameters state from optimizer.model using nnx.split with nnx.Param filter\n","# _graphdef_params, params_state = ...\n","# TODO: Extract the full optimizer state tree using nnx.state()\n","# optimizer_state_tree = ...\n","\n","print(f\"Parameter state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, params_state)}\")\n","print(f\"Optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, optimizer_state_tree)}\")\n","\n","# --- Save Composite State ---\n","current_step_val = 0\n","if hasattr(optimizer, 'step') and hasattr(optimizer.step, 'value'):\n"," current_step_val = optimizer.step.value\n","else: # Fallback for safety, though nnx.Optimizer should have .step\n"," current_step_val = 10\n","\n","\n","# TODO: Define save_items dictionary for 'params' and 'optimizer'\n","# Each item should be wrapped with ocp.args.StandardSave\n","# save_items = {\n","# 'params': ...,\n","# 'optimizer': ...\n","# }\n","\n","# TODO: Save using mngr_comp.save() and ocp.args.Composite\n","# mngr_comp.save(...)\n","# TODO: Wait and close the manager\n","# mngr_comp.wait_until_finished()\n","# print(f\"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.\")\n","# print(f\"Available checkpoints: {mngr_comp.all_steps()}\")\n","# mngr_comp.close()"],"metadata":{"id":"9ZTdYGw3hGJq"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 3: Solution\n","\n","# Ensure SimpleLinear class definition is available\n","# --- Instantiate Model and Optimizer ---\n","rngs_ex3 = nnx.Rngs(params=jax.random.key(1))\n","model_ex3 = SimpleLinear(din=10, dout=5, rngs=rngs_ex3)\n","\n","tx = optax.adam(learning_rate=1e-3)\n","optimizer = nnx.Optimizer(model_ex3, tx, wrt=nnx.Param)\n","\n","# Simulate a few \"training\" steps to populate optimizer state\n","# For a real scenario, this would involve gradients and updates\n","optimizer.step.value += 10 # Simulate 10 steps\n","original_bias_val_ex3 = model_ex3.bias.value.copy()\n","# Simulate a parameter update that would happen during training\n","model_ex3.bias.value = model_ex3.bias.value * 0.5 + 0.1 # Arbitrary change\n","print(f\"Optimizer step: {optimizer.step.value}\")\n","print(f\"Bias modified. Original first val: {original_bias_val_ex3[0]}, New first val: {model_ex3.bias.value[0]}\")\n","\n","# --- Setup CheckpointManager for Composite Save ---\n","ckpt_dir_ex3 = os.path.join(CKPT_BASE_DIR, 'ex3_composite_save')\n","cleanup_ckpt_dir(ckpt_dir_ex3)\n","mngr_comp = ocp.CheckpointManager(ckpt_dir_ex3, options=ocp.CheckpointManagerOptions(max_to_keep=3))\n","\n","# --- Extract States for Saving ---\n","# Extract model parameters (e.g., using nnx.split(model, nnx.Param))\n","_graphdef_params, params_state = nnx.split(model_ex3, nnx.Param)\n","# Extract optimizer state (nnx.state(optimizer))\n","optimizer_state_tree = nnx.state(optimizer)\n","\n","print(f\"Parameter state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), params_state)}\")\n","print(f\"Optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), optimizer_state_tree)}\")\n","# Note: optimizer_state_tree also contains the model's state within optimizer.model_variables\n","\n","# --- Save Composite State ---\n","current_step_val = optimizer.step.value # Get current step from optimizer\n","\n","# Save using Composite args\n","save_items = {\n"," 'params': ocp.args.StandardSave(params_state),\n"," 'optimizer': ocp.args.StandardSave(optimizer_state_tree)\n","}\n","\n","# Can generate args per item using orbax_utils too\n","mngr_comp.save(current_step_val, args=ocp.args.Composite(**save_items))\n","mngr_comp.wait_until_finished()\n","print(f\"Composite checkpoint saved for step {current_step_val} in {ckpt_dir_ex3}.\")\n","print(f\"Available checkpoints: {mngr_comp.all_steps()}\")\n","mngr_comp.close()"],"metadata":{"id":"V0lKNBMQiKBh"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 4: Restoring Model Parameters and Optimizer State\n","\n","**Goal**: Learn to restore both model parameters and optimizer state from a composite checkpoint.\n","\n","### Topics:\n","\n","* Creating abstract versions of both model and optimizer using nnx.eval_shape.\n","* Getting abstract state templates for both parameter state and optimizer state.\n","* Using ocp.args.Composite with ocp.args.StandardRestore for restoring multiple items.\n","* Instantiating new concrete model and optimizer instances.\n","* Updating these instances using nnx.update() with the restored states.\n","\n","### Instructions:\n","\n","1. Re-open the CheckpointManager from Exercise 3 (ckpt_dir_ex3).\n","2. Define a function create_abstract_model_and_optimizer():\n"," - Inside, create an abstract model instance (e.g., SimpleLinear) using nnx.eval_shape on a creation lambda.\n"," - Then, create an abstract nnx.Optimizer instance using nnx.eval_shape, passing the abstract model and a new Optax optimizer instance to its creation lambda.\n"," - Return both abs_model and abs_optimizer.\n","3. Call this function to get abs_model and abs_optimizer.\n","4. Get the abstract state for parameters: _graphdef_abs_params, abs_params_state = nnx.split(abs_model, nnx.Param).\n","5. Get the abstract state for the optimizer: abs_optimizer_state = nnx.state(abs_optimizer).\n","6. Find the latest step to restore.\n","7. If a checkpoint exists, define a restore_targets dictionary for ocp.args.Composite. Keys should match those used during save ('params', 'optimizer'), and values should be ocp.args.StandardRestore() wrapped abstract states.\n","8. Restore using mngr_comp.restore(step, args=ocp.args.Composite(**restore_targets)). This will return a dictionary restored_items.\n","9. Create new, \"fresh\" instances of your SimpleLinear model and nnx.Optimizer.\n","10. Update the fresh model in-place using nnx.update(fresh_model, restored_items['params']).\n","11. Update the fresh optimizer in-place using nnx.update(fresh_optimizer, restored_items['optimizer']).\n","12. Verify by checking the optimizer's step and a model parameter.\n","13. Close the manager."],"metadata":{"id":"xPqCsnJNidgw"}},{"cell_type":"code","source":["# Ensure SimpleLinear class definition is available\n","# --- Re-open CheckpointManager ---\n","# TODO: Instantiate CheckpointManager for ckpt_dir_ex3\n","# mngr_comp_restore = ...\n","\n","# --- Create Abstract Model and Optimizer ---\n","def create_abstract_model_and_optimizer():\n"," rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation\n"," # TODO: Create abstract model. Model class: SimpleLinear(din=10, dout=5, ...)\n"," # abs_model = SimpleLinear(...)\n","\n"," # TODO: Create abstract optimizer. Pass abs_model and an optax.adam instance.\n"," # abs_opt = nnx.Optimizer(...)\n"," # return abs_model, abs_opt\n","\n","# TODO: Call the function to get abstract model and optimizer\n","# abs_model_restore, abs_optimizer_restore = ...\n","\n","# --- Get Abstract States ---\n","# TODO: Get abstract parameter state from abs_model_restore (filter with nnx.Param)\n","# _graphdef_abs_params, abs_params_state = ...\n","# TODO: Get abstract optimizer state from abs_optimizer_restore\n","# abs_optimizer_state = ...\n","\n","print(f\"Abstract params state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_params_state)}\")\n","print(f\"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: x.shape if hasattr(x, 'shape') else x, abs_optimizer_state)}\")\n","\n","# --- Restore Composite State ---\n","# TODO: Get the latest step\n","# step_to_restore_comp = ...\n","\n","if step_to_restore_comp is not None:\n"," # TODO: Define restore_targets dictionary for 'params' and 'optimizer'\n"," # Each item should be wrapped with ocp.args.StandardRestore and its corresponding abstract state.\n"," # restore_targets = {\n"," # 'params': ...,\n"," # 'optimizer': ...\n"," # }\n"," # TODO: Restore items using mngr_comp_restore.restore() and ocp.args.Composite\n"," # restored_items = mngr_comp_restore.restore(...)\n","\n"," # --- Instantiate and Update Concrete Model/Optimizer ---\n"," # TODO: Create a fresh SimpleLinear model instance (use a new RNG key, e.g., key(2))\n"," # fresh_model = ...\n"," # TODO: Create a fresh nnx.Optimizer instance with fresh_model and a new optax.adam instance\n"," # fresh_optimizer = ...\n","\n"," # Store pre-update values for comparison\n"," pre_update_bias = fresh_model.bias.value.copy()\n"," pre_update_opt_step = fresh_optimizer.step.value\n","\n"," # TODO: Update fresh_model with restored_items['params'] using nnx.update()\n"," # nnx.update(...)\n"," # TODO: Update fresh_optimizer with restored_items['optimizer'] using nnx.update()\n"," # nnx.update(...)\n","\n"," print(f\"Restored and updated. Optimizer step: {fresh_optimizer.step.value}\")\n"," print(f\"Fresh model bias before update (first val): {pre_update_bias[0]}\")\n"," print(f\"Fresh model bias after update (first val): {fresh_model.bias.value[0]}\")\n"," print(f\"Original bias from Ex3 (first val): {model_ex3.bias.value[0]}\") # model_ex3 is from previous cell\n","\n"," # Verification\n"," # chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value) # Compare with the state that was saved\n"," # assert fresh_optimizer.step.value == optimizer.step.value # Compare with optimizer state that was saved\n","else:\n"," print(\"No composite checkpoint found.\")\n","\n","# TODO: Close the manager\n","# mngr_comp_restore.close()"],"metadata":{"id":"FIoLTU7tiUZv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 4: Solution\n","\n","# Ensure SimpleLinear class definition is available\n","# --- Re-open CheckpointManager ---\n","mngr_comp_restore = ocp.CheckpointManager(ckpt_dir_ex3)\n","\n","# --- Create Abstract Model and Optimizer ---\n","def create_abstract_model_and_optimizer():\n"," rngs_abs = nnx.Rngs(params=jax.random.key(0)) # Dummy key for abstract creation\n"," # Create abstract model\n"," abs_model = SimpleLinear(din=10, dout=5, rngs=rngs_abs)\n"," # Create abstract optimizer\n"," abs_opt = nnx.Optimizer(abs_model, optax.adam(1e-3), wrt=nnx.Param)\n"," return abs_model, abs_opt\n","\n","abs_model_restore, abs_optimizer_restore = create_abstract_model_and_optimizer()\n","\n","# --- Get Abstract States ---\n","_graphdef_abs_params, abs_params_state = nnx.split(abs_model_restore, nnx.Param)\n","abs_optimizer_state = nnx.state(abs_optimizer_restore)\n","\n","print(f\"Abstract params state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_params_state)}\")\n","print(f\"Abstract optimizer state structure: {jax.tree_util.tree_map(lambda x: (x.shape, x.dtype) if hasattr(x, 'shape') else type(x), abs_optimizer_state)}\")\n","\n","# --- Restore Composite State ---\n","step_to_restore_comp = mngr_comp_restore.latest_step()\n","\n","if step_to_restore_comp is not None:\n"," restore_targets = {\n"," 'params': ocp.args.StandardRestore(abs_params_state),\n"," 'optimizer': ocp.args.StandardRestore(abs_optimizer_state)\n"," }\n"," restored_items = mngr_comp_restore.restore(step_to_restore_comp, args=ocp.args.Composite(**restore_targets))\n","\n"," # --- Instantiate and Update Concrete Model/Optimizer ---\n"," # Create fresh instances\n"," fresh_rngs = nnx.Rngs(params=jax.random.key(2)) # Use a different key for the fresh model\n"," fresh_model = SimpleLinear(din=10, dout=5, rngs=fresh_rngs)\n"," fresh_optimizer = nnx.Optimizer(fresh_model, optax.adam(1e-3), wrt=nnx.Param) # Matching optax optimizer\n","\n"," # Store pre-update values for comparison\n"," pre_update_bias = fresh_model.bias.value.copy()\n"," pre_update_opt_step = fresh_optimizer.step.value\n","\n"," # Update using restored states\n"," nnx.update(fresh_model, restored_items['params'])\n"," nnx.update(fresh_optimizer, restored_items['optimizer'])\n","\n"," print(f\"Restored and updated. Optimizer step: {fresh_optimizer.step.value}\")\n"," print(f\"Fresh model bias before update (first val): {pre_update_bias[0]}\") # Will be from key(2)\n"," print(f\"Fresh model bias after update (first val): {fresh_model.bias.value[0]}\") # Should match model_ex3 bias\n","\n"," # Verification (model_ex3 and optimizer are from the previous cell where they were saved)\n"," chex.assert_trees_all_close(fresh_model.bias.value, model_ex3.bias.value)\n"," assert fresh_optimizer.step.value == optimizer.step.value\n"," print(\"Verification successful: Restored model parameters and optimizer step match the saved state.\")\n","else:\n"," print(\"No composite checkpoint found.\")\n","\n","mngr_comp_restore.close()"],"metadata":{"id":"xB50SxmBjMJr"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 5: Distributed Checkpointing - Saving Sharded State\n","\n","**Goal**: Understand how to save model state that has been sharded across multiple devices. Orbax handles sharded JAX arrays efficiently.\n","\n","### Topics:\n","\n","* Setting up a JAX device Mesh.\n","* Defining PartitionSpec for sharding arrays.\n","* Creating sharded parameters within an nnx.Module. One way is to initialize parameters and then use jax.device_put with NamedSharding to shard them, then update the module's state. NNX also allows attaching sharding annotations directly to nnx.Variable metadata.\n","* Saving sharded state: Orbax handles sharded arrays transparently during saving if the JAX arrays in the state Pytree are already sharded.\n","\n","### Instructions:\n","\n","1. Define the number of devices and create a device mesh (e.g., a 1D mesh with all available devices).\n","2. Modify the SimpleLinear module (or create ShardedSimpleLinear):\n","* In `__init__`, after initializing parameters, you'll shard them.\n","* For the weight matrix (din, dout), let's shard it along the dout dimension (e.g., PartitionSpec(None, 'data')).\n","* The bias vector (dout,) will also be sharded along its only dimension (PartitionSpec('data')).\n","* To apply sharding:\n"," - Create NamedSharding objects from your PartitionSpec and the mesh.\n"," - Use jax.device_put(param_value, named_sharding) to get sharded JAX arrays.\n"," - Update the .value of your nnx.Param attributes with these sharded arrays.\n","3. Instantiate your sharded model within the mesh context manager (with mesh:). This ensures operations are aware of the mesh.\n","4. Set up a CheckpointManager in a new directory (ckpt_dir_ex5).\n","5. Split the sharded model to get its state: _graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model). The arrays within sharded_state_to_save should now be jax.Array objects with sharding information.\n","6. Save this sharded_state_to_save using mngr.save(). The process is the same as non-sharded saving from Orbax's perspective.\n","7. Wait and close."],"metadata":{"id":"1n58wd9gm1Pq"}},{"cell_type":"code","source":["# --- Setup JAX Mesh ---\n","num_devices = jax.device_count()\n","# If num_devices is 1 after chex.set_n_cpu_devices(8), it means JAX didn't pick up the fakes.\n","# This can happen if JAX initializes its backends before chex runs.\n","# Forcing a rerun of this cell or restarting runtime and running setup first might help.\n","print(f\"Using {num_devices} devices for sharding.\")\n","device_mesh = mesh_utils.create_device_mesh((num_devices,))\n","mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh\n","print(mesh)\n","\n","# --- Define Sharded NNX Module ---\n","class ShardedSimpleLinear(nnx.Module):\n"," def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):\n"," self.din = din\n"," self.dout = dout\n"," self.mesh = mesh\n","\n"," key_w, key_b = rngs.params(), rngs.params()\n","\n"," # Initialize as regular JAX arrays first\n"," initial_weight = jax.random.uniform(key_w, (din, dout))\n"," initial_bias = jnp.zeros((dout,))\n","\n"," # TODO: Define PartitionSpec for weight (shard dout across 'data' axis)\n"," # e.g., PartitionSpec(None, 'data') means not sharded on dim 0, sharded on dim 1\n"," # weight_pspec = ...\n"," # TODO: Define PartitionSpec for bias (shard along 'data' axis)\n"," # bias_pspec = ...\n","\n"," # TODO: Create NamedSharding for weight and bias using self.mesh and the pspecs\n"," # weight_sharding = NamedSharding(...)\n"," # bias_sharding = NamedSharding(...)\n","\n"," # TODO: Shard the initial arrays using jax.device_put and the NamedSharding\n"," # sharded_weight_value = jax.device_put(...)\n"," # sharded_bias_value = jax.device_put(...)\n","\n"," # TODO: Assign these sharded arrays to nnx.Param attributes\n"," # self.weight = nnx.Param(sharded_weight_value)\n"," # self.bias = nnx.Param(sharded_bias_value)\n","\n"," # Alternative (more direct with nnx.Variable metadata if supported well for this case):\n"," # self.weight = nnx.Param(initial_weight, sharding=weight_sharding) # This depends on NNX API\n"," # For this exercise, jax.device_put is explicit and clear.\n","\n"," def __call__(self, x: jax.Array) -> jax.Array:\n"," # x is assumed to be replicated or appropriately sharded for the matmul\n"," # For simplicity, assume x is replicated if din is not sharded, or sharded compatibly.\n"," return x @ self.weight.value + self.bias.value\n","\n","# --- Instantiate Sharded Model within Mesh context ---\n","din_s, dout_s = 8, num_devices * 2 # Ensure dout is divisible by num_devices for even sharding\n","rngs_sharded = nnx.Rngs(params=jax.random.key(3))\n","\n","# TODO: Instantiate ShardedSimpleLinear within the mesh context\n","# with mesh:\n","# sharded_model = ...\n","\n","# print(f\"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}\")\n","# print(f\"Sharded model bias sharding: {sharded_model.bias.value.sharding}\")\n","\n","\n","# --- Setup CheckpointManager for Sharded Save ---\n","ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')\n","cleanup_ckpt_dir(ckpt_dir_ex5)\n","# TODO: Instantiate CheckpointManager\n","# mngr_sharded_save = ...\n","\n","# --- Split and Save Sharded State ---\n","# TODO: Split the sharded_model\n","# _graphdef_sharded, sharded_state_to_save = ...\n","\n","# print(f\"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}\")\n","# print(f\"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}\")\n","\n","# current_step_sharded = 200\n","# TODO: Save the sharded_state_to_save\n","# mngr_sharded_save.save(...)\n","# TODO: Wait and close\n","# mngr_sharded_save.wait_until_finished()\n","# print(f\"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.\")\n","# mngr_sharded_save.close()"],"metadata":{"id":"tzeRStI1jVuf"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 5: Solution\n","\n","# --- Setup JAX Mesh ---\n","num_devices = jax.device_count()\n","if num_devices == 1 and chex.set_n_cpu_devices.called_in_process: # If we faked 8 but only see 1\n"," print(\"Warning: JAX might not be using the faked CPU devices. Restart runtime and run Setup cell first if sharding tests fail.\")\n","print(f\"Using {num_devices} devices for sharding.\")\n","# Ensure a 1D mesh for simplicity, using all available (or faked) devices.\n","device_mesh = mesh_utils.create_device_mesh((num_devices,))\n","mesh = Mesh(devices=device_mesh, axis_names=('data',)) # 1D mesh for 'data' parallelism\n","print(mesh)\n","\n","# --- Define Sharded NNX Module ---\n","class ShardedSimpleLinear(nnx.Module):\n"," def __init__(self, din: int, dout: int, mesh: Mesh, *, rngs: nnx.Rngs):\n"," self.din = din\n"," self.dout = dout\n"," self.mesh = mesh # Store mesh for creating NamedSharding\n","\n"," key_w, key_b = rngs.params(), rngs.params()\n","\n"," initial_weight = jax.random.uniform(key_w, (din, dout))\n"," initial_bias = jnp.zeros((dout,))\n","\n"," # Define PartitionSpec for sharding\n"," # Shard weight's second dimension (dout) across the 'data' mesh axis\n"," weight_pspec = PartitionSpec(None, 'data')\n"," # Shard bias's only dimension (dout) across the 'data' mesh axis\n"," bias_pspec = PartitionSpec('data',)\n","\n"," # Create NamedSharding from PartitionSpec and mesh\n"," weight_sharding = NamedSharding(self.mesh, weight_pspec)\n"," bias_sharding = NamedSharding(self.mesh, bias_pspec)\n","\n"," # Shard the initial arrays using jax.device_put\n"," # This ensures the arrays are created with the specified sharding\n"," sharded_weight_value = jax.device_put(initial_weight, weight_sharding)\n"," sharded_bias_value = jax.device_put(initial_bias, bias_sharding)\n","\n"," self.weight = nnx.Param(sharded_weight_value)\n"," self.bias = nnx.Param(sharded_bias_value)\n"," # Note: Flax NNX aims to allow sharding annotations directly in nnx.Variable metadata\n"," # e.g., using nnx.spmd.with_partitioning or passing sharding to nnx.Param.\n"," # Explicit jax.device_put is also a valid way to get sharded arrays into the state.\n","\n"," def __call__(self, x: jax.Array) -> jax.Array:\n"," return x @ self.weight.value + self.bias.value\n","\n","# --- Instantiate Sharded Model within Mesh context ---\n","din_s, dout_s = 8, num_devices * 2 # Make dout divisible by num_devices\n","rngs_sharded = nnx.Rngs(params=jax.random.key(3))\n","\n","with mesh: # Operations within this context are aware of the mesh\n"," sharded_model = ShardedSimpleLinear(din_s, dout_s, mesh, rngs=rngs_sharded)\n","\n","print(f\"Sharded model created. Weight sharding: {sharded_model.weight.value.sharding}\")\n","print(f\"Sharded model bias sharding: {sharded_model.bias.value.sharding}\")\n","\n","# --- Setup CheckpointManager for Sharded Save ---\n","ckpt_dir_ex5 = os.path.join(CKPT_BASE_DIR, 'ex5_sharded_save')\n","cleanup_ckpt_dir(ckpt_dir_ex5)\n","mngr_sharded_save = ocp.CheckpointManager(ckpt_dir_ex5, options=ocp.CheckpointManagerOptions(max_to_keep=1))\n","\n","# --- Split and Save Sharded State ---\n","# The live state already contains sharded jax.Array objects\n","_graphdef_sharded, sharded_state_to_save = nnx.split(sharded_model)\n","\n","print(f\"Sharded state to save (bias type): {type(sharded_state_to_save['bias'].value)}\")\n","print(f\"Sharded state to save (bias sharding): {sharded_state_to_save['bias'].value.sharding}\")\n","# The actual arrays in sharded_state_to_save are now GlobalDeviceArrays (or jax.Array with sharding)\n","\n","current_step_sharded = 200\n","# Orbax handles sharded-array saving under the hood\n","mngr_sharded_save.save(current_step_sharded, args=ocp.args.StandardSave(sharded_state_to_save))\n","mngr_sharded_save.wait_until_finished()\n","print(f\"Sharded checkpoint saved for step {current_step_sharded} in {ckpt_dir_ex5}.\")\n","mngr_sharded_save.close()"],"metadata":{"id":"-0Yvg9non-Jw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Exercise 6: Distributed Checkpointing - Restoring Sharded State\n","\n","**Goal**: Learn to restore sharded model state, which requires providing an abstract state Pytree that includes the target sharding specifications.\n","\n","### Topics:\n","\n","* Creating an abstract model using nnx.eval_shape.\n","* Splitting it to get an abstract state.\n","* Crucial Step: Applying sharding specifications to this abstract state to create a \"sharding-aware template\" or abstract_target. This is often done using jax.lax.with_sharding_constraint or by ensuring the nnx.eval_shape process (if the module itself defines sharding during abstract construction) yields abstract states with correct sharding.\n","* Using StandardRestore with this sharding-aware abstract_target.\n","* Merging the restored sharded state with a graph definition to reconstruct the model.\n","\n","### Instructions:\n","\n","1. Reuse the mesh from Exercise 5.\n","2. Re-open the CheckpointManager pointing to ckpt_dir_ex5.\n","3. Define a function, e.g., create_abstract_sharded_model_for_restore(mesh).\n","* Inside, instantiate your ShardedSimpleLinear module (or a similar one intended for sharded restoration) with the provided mesh. This instantiation should ensure its parameters would be sharded if it were a concrete model.\n","* Pass a lambda creating this module to nnx.eval_shape() to get an abstract_model.\n","* The key is that nnx.split(abstract_model) should yield an abstract_state where leaves corresponding to sharded parameters are ShapeDtypeStructs that already encode the target sharding. This happens if ShardedSimpleLinear's `__init__` uses jax.device_put with NamedSharding on dummy data when nnx.is_abstract_eval() is true, or if NNX's sharding annotation system propagates this to the abstract state.\n","* A more explicit way (if the above is tricky with eval_shape directly embedding sharding into abstract state leaves) is shown in the slides:\n"," 1. abstract_model = nnx.eval_shape(...) for a non-sharded-at-init version.\n"," 2. _graphdef, abstract_state_plain = nnx.split(abstract_model).\n"," 3. Define sharding_specs (Pytree of PartitionSpec).\n"," 4. abstract_target = jax.tree_util.tree_map(lambda x, spec: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=NamedSharding(mesh, spec)), abstract_state_plain, sharding_specs) OR use jax.lax.with_sharding_constraint(abstract_state_plain, sharding_specs) on the abstract state within a jax.jit and mesh context as shown in slide conceptual code. Let's try to make ShardedSimpleLinear work with eval_shape directly if possible, or fall back to explicit constraint.\n","\n","4. To make ShardedSimpleLinear directly produce an abstract state with sharding during nnx.eval_shape:\n","* Modify ShardedSimpleLinear`.__init__`.\n","* When nnx.is_abstract_eval() is true, instead of jax.device_put(real_data, ...) use jax.ShapeDtypeStruct(shape, dtype, sharding=NamedSharding(mesh, pspec)) for the .value of the nnx.Param.\n","5. Call your function within the mesh context and jax.jit it (as per slides ) to get the abstract_target_state and graphdef_for_restore. graphdef_for_restore, abstract_target_state = nnx.split(nnx.eval_shape(lambda: ShardedSimpleLinear(..., mesh=mesh,...))) (simplified)\n","6. Restore using mngr.restore(step, args=ocp.args.StandardRestore(abstract_target_state)).\n","7. Reconstruct the model using nnx.merge(graphdef_for_restore, restored_sharded_state).\n","8. Verify the sharding of the restored model's parameters.\n","9. Close the manager.\n","\n","Self-correction for instruction 4 & 5: Instead of modifying ShardedSimpleLinear to behave differently under nnx.is_abstract_eval(), it's cleaner and more aligned with typical Orbax/JAX patterns to:\n","a. Get a plain abstract state (shapes/dtypes only) from a version of the model that doesn't try to shard during abstract init.\n","b. Then, explicitly create the abstract_target by adding sharding to this plain abstract state.\n","Let's refine ShardedSimpleLinear to accept an init_sharded flag. For eval_shape, we'll pass init_sharded=False (or rely on nnx.eval_shape not creating real arrays), then apply sharding to the resulting abstract state.\n","\n","A more direct approach for step 5, if the ShardedSimpleLinear from Ex5 is used for eval_shape: nnx.eval_shape will create ShapeDtypeStruct for parameters. If jax.device_put was part of the module's __init__, nnx.eval_shape might not execute it to produce sharded ShapeDtypeStructs directly. The critical part is that abstract_target passed to StandardRestore must have the sharding information.\n","\n","Let's use the method from slide \"Distributed Checkpointing: Restoring Sharded State\":\n","\n","1. abstract_model = nnx.eval_shape(lambda: ModelClass(...)) (ModelClass here doesn't apply sharding during this abstract init).\n","2. _graphdef, abstract_state_struct_only = nnx.split(abstract_model).\n","3. Define sharding_pytree (same Pytree structure as state, but with NamedSharding objects at leaves).\n","4. abstract_target = jax.tree.map(lambda s, n: jax.ShapeDtypeStruct(s.shape, s.dtype, sharding=n), abstract_state_struct_only, sharding_pytree). This abstract_target is then used in StandardRestore."],"metadata":{"id":"9y6cURi8pKIQ"}},{"cell_type":"code","source":["# Ensure ShardedSimpleLinear class definition and mesh from Ex5 are available.\n","\n","# --- Re-open CheckpointManager for Sharded Restore ---\n","# TODO: Instantiate CheckpointManager for ckpt_dir_ex5\n","# mngr_sharded_restore = ...\n","\n","# --- Create Abstract Target State with Sharding Information ---\n","# Method:\n","# 1. Create a \"plain\" abstract model (shapes/dtypes only).\n","# 2. Split it to get graphdef and plain abstract_state.\n","# 3. Define the desired sharding for each parameter (Pytree of NamedSharding).\n","# 4. Combine plain abstract_state with sharding to create the final abstract_target.\n","\n","def create_abstract_model_for_sharded_restore():\n"," # This lambda should instantiate the model structure without applying sharding during this phase.\n"," # We'll use the ShardedSimpleLinear class, but its sharding logic inside __init__\n"," # might be skipped by eval_shape if it involves actual data.\n"," # Alternatively, provide a version of the model that takes sharding specs externally.\n"," # For simplicity, let's assume nnx.eval_shape on ShardedSimpleLinear gives us ShapeDtypeStructs,\n"," # and we will then OVERWRITE their sharding attribute if necessary, or construct them fresh.\n","\n"," # Let's make a 'template' instance of ShardedSimpleLinear just to get its structure via split.\n"," # The actual sharding for the abstract target will be defined explicitly.\n"," temp_rngs = nnx.Rngs(params=jax.random.key(99))\n"," # Create an instance of ShardedSimpleLinear as it was defined in Ex5.\n"," # nnx.eval_shape will trace its construction.\n"," # TODO: abstract_model_proto = nnx.eval_shape(lambda: ShardedSimpleLinear(... pass din_s, dout_s, mesh from Ex5 ...))\n"," # abstract_model_proto = ...\n"," # return abstract_model_proto\n","\n","# Run within mesh context for operations that might interact with sharding\n","# with mesh:\n"," # TODO: Create the abstract_model_proto by calling the function above.\n"," # abstract_model_for_target = create_abstract_model_for_sharded_restore()\n"," # TODO: Split it to get graphdef_for_restore and an abstract_state (which might have None for sharding)\n"," # graphdef_for_restore_sharded, abstract_state_struct_only = ...\n","\n"," # Define the target sharding (PartitionSpecs, then NamedSharding)\n"," # These must match the sharding used when the checkpoint was SAVED.\n"," # weight_pspec_target = PartitionSpec(None, 'data') # As in Ex5\n"," # bias_pspec_target = PartitionSpec('data',) # As in Ex5\n","\n"," # weight_sharding_target = NamedSharding(mesh, weight_pspec_target)\n"," # bias_sharding_target = NamedSharding(mesh, bias_pspec_target)\n","\n"," # Create the sharding pytree for the abstract_target\n"," # It needs to match the structure of abstract_state_struct_only['params'] or similar,\n"," # depending on how ShardedSimpleLinear structures its state.\n"," # Assuming state is flat { 'weight': ..., 'bias': ... } within the nnx.State object.\n"," # If ShardedSimpleLinear created params like self.weight = nnx.Param(...),\n"," # then abstract_state_struct_only will look like {'weight': {'value': ShapeDtypeStruct}, 'bias': {'value': ShapeDtypeStruct}}\n","\n"," # TODO: Construct the `sharding_for_abstract_state` Pytree.\n"," # It should mirror the structure of `abstract_state_struct_only` but contain NamedSharding objects at the leaves\n"," # where parameters are.\n"," # Example if state is {'weight': {'value':...}, 'bias': {'value':...}}:\n"," # sharding_for_abstract_state = {\n"," # 'weight': {'value': weight_sharding_target},\n"," # 'bias': {'value': bias_sharding_target}\n"," # }\n"," # Verify this structure based on print(abstract_state_struct_only) from split.\n","\n"," # TODO: Create the final abstract_target by combining shapes/dtypes with new sharding.\n"," # abstract_target_state = jax.tree.map(\n"," # lambda sds, sh: jax.ShapeDtypeStruct(sds.shape, sds.dtype, sharding=sh) if isinstance(sds, jax.ShapeDtypeStruct) else sds,\n"," # abstract_state_struct_only,\n"," # sharding_for_abstract_state\n"," # )\n"," # print(f\"Abstract target for restore (bias sharding): {abstract_target_state['bias'].value.sharding}\")\n","\n","# --- Restore Sharded State ---\n","# step_to_restore_sharded = mngr_sharded_restore.latest_step()\n","# if step_to_restore_sharded is not None:\n"," # with mesh: # Restoration happens within the mesh context\n"," # TODO: Restore sharded state using abstract_target_state\n"," # restored_sharded_state_dict = mngr_sharded_restore.restore(...)\n","\n"," # TODO: Reconstruct the model using nnx.merge\n"," # reconstructed_sharded_model = ...\n","\n"," # print(f\"Sharded model restored from step {step_to_restore_sharded}.\")\n"," # print(f\"Restored weight sharding: {reconstructed_sharded_model.weight.value.sharding}\")\n"," # print(f\"Restored bias sharding: {reconstructed_sharded_model.bias.value.sharding}\")\n","\n"," # Verification (optional): Compare with sharded_model from Ex5 if it's in scope and has same structure\n"," # chex.assert_trees_all_equal_shapes_and_dtypes(nnx.state(reconstructed_sharded_model), nnx.state(sharded_model))\n"," # assert str(reconstructed_sharded_model.weight.value.sharding) == str(sharded_model.weight.value.sharding)\n","\n","# else:\n","# print(\"No sharded checkpoint found to restore.\")\n","\n","# mngr_sharded_restore.close()"],"metadata":{"id":"8ZinY5nyoXI7"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# @title Exercise 6: Solution\n","\n","# Ensure ShardedSimpleLinear class definition and mesh from Ex5 are available.\n","# din_s, dout_s from Ex5 were: din_s = 8, dout_s = num_devices * 2\n","\n","# --- Re-open CheckpointManager for Sharded Restore ---\n","mngr_sharded_restore = ocp.CheckpointManager(ckpt_dir_ex5)\n","\n","# --- Create Abstract Target State with Sharding Information ---\n","# This follows the principle that the abstract target for restore must contain sharding info\n","\n","def create_abstract_model_for_sharded_restore_eval_shape():\n"," # This lambda is for nnx.eval_shape. It should define the *structure*\n"," # ShardedSimpleLinear's __init__ from Ex5 already creates sharded JAX arrays.\n"," # nnx.eval_shape will trace this. The resulting abstract state's leaves\n"," # should be ShapeDtypeStructs that already reflect the sharding\n"," # because jax.device_put (which includes sharding) is part of its traced __init__.\n"," # This is a more integrated way if the module's __init__ handles sharding for abstract eval.\n"," temp_rngs_for_eval = nnx.Rngs(params=jax.random.key(100)) # Dummy key for eval_shape\n"," # Pass the same mesh instance that will be used for restoration\n"," return ShardedSimpleLinear(din=din_s, dout=dout_s, mesh=mesh, rngs=temp_rngs_for_eval)\n","\n","with mesh: # Operations like eval_shape and restore should be within the mesh context\n"," # Create abstract model using nnx.eval_shape.\n"," # The sharding info should ideally be embedded by ShardedSimpleLinear's __init__\n"," # when traced by nnx.eval_shape, because it uses jax.device_put.\n"," abstract_model_sharded_eval = nnx.eval_shape(create_abstract_model_for_sharded_restore_eval_shape)\n"," # Use the graphdef from the abstract sharded model for merging\n"," graphdef_for_restore_sharded = nnx.split(abstract_model_sharded_eval)[0]\n","\n"," # We need the abstract state structure from the plain model (SimpleLinear)\n"," # because nnx.eval_shape on ShardedSimpleLinear might already put sharding\n"," # in the abstract state, and we want to demonstrate the manual creation\n"," # of the abstract target with sharding.\n"," plain_abstract_model = nnx.eval_shape(lambda: SimpleLinear(din_s, dout_s, rngs=nnx.Rngs(0)))\n"," # This state will have ShapeDtypeStructs, but likely with sharding=None\n"," _gdef_plain, abstract_state_struct_only = nnx.split(plain_abstract_model)\n","\n"," # Define target sharding specs\n"," weight_pspec_target = PartitionSpec(None, 'data') # As in Ex5\n"," bias_pspec_target = PartitionSpec('data',) # As in Ex5\n"," weight_sharding_target = NamedSharding(mesh, weight_pspec_target)\n"," bias_sharding_target = NamedSharding(mesh, bias_pspec_target)\n","\n"," # Create the sharding pytree for the abstract_target\n"," # It needs to match the structure of `abstract_state_struct_only` exactly.\n"," # Since abstract_state_struct_only is {'bias': {'value':ShapeDtypeStruct}, 'weight': {'value':ShapeDtypeStruct}},\n"," # the sharding pytree should mirror this structure, placing NamedSharding at the leaves.\n"," sharding_pytree_for_target = nnx.State({\n"," 'weight': nnx.VariableState(type=nnx.Param, value=weight_sharding_target),\n"," 'bias': nnx.VariableState(type=nnx.Param, value=bias_sharding_target)\n"," })\n","\n"," # Create the final abstract_target by mapping over the structure of\n"," # abstract_state_struct_only and sharding_pytree_for_target.\n"," # We want to replace the ShapeDtypeStruct in abstract_state_struct_only.value\n"," # with a new ShapeDtypeStruct that includes the sharding from sharding_pytree_for_target.value.\n","\n"," # Define a function that takes two VariableState objects\n"," def update_variable_state_sharding(sds_variable_state: nnx.VariableState, sharding_variable_state: nnx.VariableState):\n"," if isinstance(sds_variable_state, jax.ShapeDtypeStruct):\n"," # Create a new ShapeDtypeStruct with the desired sharding\n"," new_sds = jax.ShapeDtypeStruct(sds_variable_state.shape, sds_variable_state.dtype, sharding=sharding_variable_state)\n"," # Return a new VariableState with the updated value\n"," return new_sds\n"," else:\n"," # If the value is not a ShapeDtypeStruct (e.g., metadata), keep it as is\n"," # In this specific case, this path might not be strictly needed if abstract_state_struct_only\n"," # only contains VariableState with ShapeDtypeStruct values at the leaves we care about.\n"," return sds_variable_state\n","\n"," # Map this function over the two pytrees. Use a custom is_leaf to map at the VariableState level.\n"," # This ensures the mapping function receives (VariableState, VariableState containing sharding) pairs.\n"," def is_variable_state_node(x):\n"," # Treat VariableState itself as a node (not a leaf) so mapping happens inside it\n"," return not isinstance(x, nnx.VariableState)\n","\n"," # Apply the mapping. The lambda receives items from corresponding positions in both trees.\n"," # Here, lambda `sds_node` is a VariableState from `abstract_state_struct_only`,\n"," # and lambda `sharding_node` is a VariableState from `sharding_pytree_for_target`.\n"," abstract_target_state = jax.tree.map(\n"," update_variable_state_sharding,\n"," abstract_state_struct_only, # This tree has ShapeDtypeStructs nested in VariableState.value\n"," sharding_pytree_for_target # This tree has NamedSharding objects nested in VariableState.value\n"," )\n","\n"," print(f\"Abstract target for restore (bias sharding): {abstract_target_state['bias'].value.sharding}\")\n"," print(f\"Abstract target for restore (weight sharding): {abstract_target_state['weight'].value.sharding}\")\n","\n","# --- Restore Sharded State ---\n","step_to_restore_sharded = mngr_sharded_restore.latest_step()\n","if step_to_restore_sharded is not None:\n"," with mesh: # Restoration happens within the mesh context\n"," # Use StandardRestore with the abstract_target that includes sharding info\n"," restored_sharded_state_dict = mngr_sharded_restore.restore(\n"," step_to_restore_sharded,\n"," args=ocp.args.StandardRestore(abstract_target_state)\n"," )\n","\n"," # Reconstruct the model using nnx.merge\n"," # Use the graphdef obtained from splitting the abstract sharded model\n"," reconstructed_sharded_model = nnx.merge(graphdef_for_restore_sharded, restored_sharded_state_dict)\n","\n"," print(f\"Sharded model restored from step {step_to_restore_sharded}.\")\n"," print(f\"Restored weight sharding: {reconstructed_sharded_model.weight.value.sharding}\")\n"," print(f\"Restored bias sharding: {reconstructed_sharded_model.bias.value.sharding}\")\n","\n"," # Verification\n"," if 'sharded_model' in globals(): # If sharded_model from Ex5 is available\n"," # Compare structure and dtypes\n"," chex.assert_trees_all_equal_structs(nnx.state(reconstructed_sharded_model), nnx.state(sharded_model))\n"," # Compare sharding\n"," assert str(reconstructed_sharded_model.weight.value.sharding) == str(sharded_model.weight.value.sharding)\n"," assert str(reconstructed_sharded_model.bias.value.sharding) == str(sharded_model.bias.value.sharding)\n"," # Compare values - this will involve communication due to sharding\n"," chex.assert_trees_all_close(nnx.state(reconstructed_sharded_model), nnx.state(sharded_model))\n"," print(\"Verification of sharding, structure, and values successful.\")\n","else:\n"," print(\"No sharded checkpoint found to restore.\")\n","\n","mngr_sharded_restore.close()"],"metadata":{"id":"fu1XapHPrVbK"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Advanced Orbax Features & Best Practices (Brief Overview)\n","\n","There are also some more advanced Orbax features. While we won't do full coding exercises for these in this notebook, it's good to be aware of them:\n","\n","* **Asynchronous Checkpointing**: manager.save() can operate in the background. Use manager.wait_until_finished() before your program exits or if you need to use the checkpoint immediately. This improves training throughput by not blocking the main training loop. Our examples used wait_until_finished().\n","\n","* **Atomicity**: CheckpointManager ensures that checkpoints are saved atomically. This means you won't get corrupted checkpoints if your training job crashes mid-save. This is handled for you by Orbax.\n","\n","* **Saving Non-Pytree Data (Metadata)**: Sometimes you need to save extra information like training configuration, dataset iterators, or model version. You can use ocp.args.JsonSave within ocp.args.Composite to save dictionary-like data as JSON alongside your model Pytrees. Restoration uses ocp.args.JsonRestore.\n","\n","### Example Concept:\n","\n","```\n","metadata = {'version': '1.0', 'dataset_info': 'imagenet_split_train'}\n","save_args = ocp.args.Composite(\n"," params=ocp.args.StandardSave(params_state),\n"," metadata=ocp.args.JsonSave(metadata)\n",")\n","mngr.save(step, args=save_args)\n","```\n","\n","* **TensorStore Backend**: For extremely large models or when working with cloud storage, Orbax can use TensorStore. This backend allows for more efficient, potentially parallel I/O for individual array shards, often transparently. This is usually configured at a lower level or might be default in certain JAX environments.\n","\n","### Key Takeaways:\n","\n","* Flax NNX offers a stateful, Pythonic way to define models.\n","* Orbax is the standard for checkpointing NNX State Pytrees.\n","* The general workflow:\n"," - **Saving**: nnx.split -> mngr.save.\n"," - **Restoring**: nnx.eval_shape -> Get abstract_state -> mngr.restore -> nnx.merge or nnx.update.\n","* CheckpointManager is your friend for managing multiple checkpoints.\n","* Use ocp.args.Composite for saving multiple distinct items (e.g., model parameters + optimizer state).\n","* For sharded (distributed) data, ensuring your abstract_target for restoration correctly specifies the target sharding is crucial. StandardRestore handles this if the abstract target has the sharding info.\n","\n","### Congratulations!\n","You've now worked through the fundamentals of checkpointing Flax NNX models with Orbax, from basic saving and restoring to handling optimizer states and distributed (sharded) scenarios.\n","\n","Remember to consult the official documentation for more in-depth details:\n","\n","* Orbax: https://orbax.readthedocs.io\n","* Flax NNX: (Part of the Flax documentation) https://flax.readthedocs.io\n","* JAX: https://jax.readthedocs.io\n","Keep practicing, and happy JAXing!\n","\n","Please send us feedback at https://goo.gle/jax-training-feedback"],"metadata":{"id":"lIT5kKJF15Ew"}},{"cell_type":"code","source":[],"metadata":{"id":"EVAtnk_P3Gk0"},"execution_count":null,"outputs":[]}]} \ No newline at end of file diff --git a/docs/learning_jax/code-exercises/9 - Serving JAX with vLLM_SGLang.ipynb b/docs/learning_jax/code-exercises/9 - Serving JAX with vLLM_SGLang.ipynb new file mode 100644 index 0000000..06e097a --- /dev/null +++ b/docs/learning_jax/code-exercises/9 - Serving JAX with vLLM_SGLang.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[{"file_id":"1wU06hEOn87VZwNKG2c5E-lQpv4LII3y2","timestamp":1755114036632},{"file_id":"1vUmOju_8clAPQ4M0aI0PHwypAIE15dIk","timestamp":1743983895983}],"gpuType":"T4","machine_shape":"hm"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","source":["# Serving a JAX model with vLLM and GPU\n","\n","This notebook shows a simple workflow from a model which is loaded from Hugging Face into JAX, and then served using vLLM. For brevity we leave out the actual fine-tuning or other alterations in JAX, since this is covered in other tutorials. This is right on the edge of what can be done in a free Colab GPU instance, so we restart before installing vLLM to free up memory. As a bonus, this notebook contains a JAX implementation of a Llama 3.2 model, which can be interesting by itself."],"metadata":{"id":"kFSEN6lZVPJ8"}},{"cell_type":"markdown","source":["# Do all the Pips\n","Let's get the downloads out of the way."],"metadata":{"id":"otDKAxr1_Cm7"}},{"cell_type":"code","source":["!pip install -Uq jax[cuda] flax # Install the JAX AI Stack for GPU\n","!pip install -q vllm # We'll need it later"],"metadata":{"id":"DU73PpuC-1Nf"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Hugging Face"],"metadata":{"id":"RTQtpU7GCsJX"}},{"cell_type":"markdown","source":["## Download the model from Hugging Face\n","\n","We'll download the model weights in Safetensors format."],"metadata":{"id":"qKOJ3dLAMWho"}},{"cell_type":"code","source":["!huggingface-cli login"],"metadata":{"id":"g_VCX510CpNH"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import os\n","from huggingface_hub import snapshot_download\n","\n","model_id = \"meta-llama/Llama-3.2-1B\"\n","path_to_model_weights = os.path.join('/content', model_id)\n","\n","snapshot_download(repo_id=model_id, local_dir=path_to_model_weights)"],"metadata":{"id":"fA1VIuSXeU8L"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Load the weights from the Safetensors file in Flax format\n","\n","import jax\n","from pathlib import Path\n","from safetensors import safe_open\n","\n","def load_safetensors():\n"," weights = {}\n"," safetensors_files = Path(path_to_model_weights).glob('*.safetensors')\n","\n"," for file in safetensors_files:\n"," with safe_open(file, framework=\"flax\") as f:\n"," for key in f.keys():\n"," print(f\"Loading {key}\")\n"," weights[key] = f.get_tensor(key)\n"," return weights\n","\n","weights = load_safetensors()"],"metadata":{"id":"VQxt8nEsekn2"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Llama 3.2-1B JAX Implementation"],"metadata":{"id":"qHz4Zkd6bpib"}},{"cell_type":"code","source":["# # Install the JAX AI Stack for GPU\n","# !pip install -q jax[cuda] jax-ai-stack\n","\n","import jax\n","print(jax.devices())\n","print(jax.__version__)"],"metadata":{"id":"PJvcveFoadvw"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from flax import nnx\n","from dataclasses import dataclass\n","import jax.numpy as jnp\n","\n","@dataclass\n","class LlamaConfig:\n"," def __init__(self):\n"," self.dim = 2048\n"," self.n_layers = 16\n"," self.n_heads = 32\n"," self.n_kv_heads = 8\n"," self.head_dim = self.dim // self.n_heads\n"," self.intermediate_size = 14336\n"," self.vocab_size = 128256\n"," self.multiple_of = 256\n"," self.norm_eps = 1e-05\n"," self.rope_theta = 500000.0\n","\n","config = LlamaConfig()\n","\n","class LlamaRMSNorm(nnx.Module):\n","\n"," def __init__(self, dim: int, rngs=None):\n"," self.norm_weights = nnx.Param(jnp.zeros((dim,), dtype=jnp.bfloat16))\n","\n"," @nnx.jit()\n"," def __call__(self, hidden_states):\n"," input_dtype = hidden_states.dtype\n"," hidden_states = hidden_states.astype(jnp.float32)\n"," squared_mean = jnp.mean(jnp.square(hidden_states), axis=-1, keepdims=True)\n"," hidden_states = hidden_states * jnp.reciprocal(jnp.sqrt(squared_mean + config.norm_eps))\n"," return self.norm_weights * hidden_states.astype(input_dtype)\n","\n","class LlamaRotaryEmbedding(nnx.Module):\n","\n"," def __init__(self, dim, base=10000, rngs=None):\n"," self.dim = dim\n"," self.base = base\n","\n"," @nnx.jit()\n"," def __call__(self, position_ids):\n"," inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))\n"," inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1))\n"," position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32)\n"," freqs = jnp.einsum('bij,bjk->bijk', position_ids_expanded, inv_freq_expanded)\n"," emb = jnp.concatenate([freqs, freqs], axis=-1)\n"," cos = jnp.cos(emb).squeeze(2).astype(jnp.bfloat16)\n"," sin = jnp.sin(emb).squeeze(2).astype(jnp.bfloat16)\n"," return cos, sin\n","\n","class LlamaAttention(nnx.Module):\n","\n"," def __init__(self, layer_idx, rngs=None):\n"," self.q_proj = nnx.Linear(config.dim, config.n_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n"," self.k_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n"," self.v_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n"," self.o_proj = nnx.Linear(config.n_heads * config.head_dim, config.dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n"," self.rotary_emb = LlamaRotaryEmbedding(config.head_dim, base=config.rope_theta, rngs=rngs)\n","\n"," # Alternative implementation:\n"," # https://github.com/google/flax/blob/5d896bc1a2c68e2099d147cd2bc18ebb6a46a0bd/examples/gemma/positional_embeddings.py#L45\n"," def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):\n"," cos = jnp.expand_dims(cos, axis=unsqueeze_dim)\n"," sin = jnp.expand_dims(sin, axis=unsqueeze_dim)\n"," q_embed = (q * cos) + (self.rotate_half(q) * sin)\n"," k_embed = (k * cos) + (self.rotate_half(k) * sin)\n"," return q_embed, k_embed\n","\n"," def rotate_half(self, x):\n"," x1 = x[..., : x.shape[-1] // 2]\n"," x2 = x[..., x.shape[-1] // 2 :]\n"," return jnp.concatenate([-x2, x1], axis=-1)\n","\n"," def repeat_kv(self, hidden_states, n_repeat):\n"," batch, n_kv_heads, seq_len, head_dim = hidden_states.shape\n"," if n_repeat == 1:\n"," return hidden_states\n"," hidden_states = hidden_states[:, :, None, :, :].repeat(n_repeat, axis=2)\n"," return hidden_states.reshape(batch, n_kv_heads * n_repeat, seq_len, head_dim)\n","\n"," @nnx.jit()\n"," def __call__(self, x, position_ids):\n"," batch_size, seq_len, _ = x.shape\n"," query = self.q_proj(x).reshape(batch_size, seq_len, config.n_heads, config.head_dim).transpose((0, 2, 1, 3))\n"," key = self.k_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))\n"," value = self.v_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))\n"," # Assuming batch_size=1\n"," cos, sin = self.rotary_emb(position_ids[0])\n"," query, key = self.apply_rotary_pos_emb(query, key, cos, sin)\n","\n"," key = self.repeat_kv(key, config.n_heads // config.n_kv_heads)\n"," value = self.repeat_kv(value, config.n_heads // config.n_kv_heads)\n","\n"," attn_weights = jnp.matmul(query, jnp.transpose(key, (0, 1, 3, 2)))\n"," attn_weights = (attn_weights.astype(jnp.float32) / jnp.sqrt(config.head_dim)).astype(jnp.bfloat16)\n"," attn_weights = jax.nn.softmax(attn_weights.astype(jnp.float32), axis=-1).astype(jnp.bfloat16)\n"," attn_output = jnp.matmul(attn_weights, value).transpose((0, 2, 1, 3)).reshape(batch_size, seq_len, -1)\n"," output = self.o_proj(attn_output)\n"," return output\n","\n","class LlamaMLP(nnx.Module):\n","\n"," def __init__(self, layer_idx, rngs=None):\n"," self.gate_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n"," self.up_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n"," self.down_proj = nnx.Linear(config.intermediate_size, config.dim, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n","\n"," @nnx.jit()\n"," def __call__(self, x):\n"," return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))\n","\n","class LlamaTransformerBlock(nnx.Module):\n","\n"," def __init__(self, layer_idx, rngs=None):\n"," self.input_layernorm = LlamaRMSNorm(dim=config.dim, rngs=rngs)\n"," self.attention = LlamaAttention(layer_idx=layer_idx, rngs=rngs)\n"," self.post_attention_layernorm = LlamaRMSNorm(dim=config.dim, rngs=rngs)\n"," self.mlp = LlamaMLP(layer_idx=layer_idx, rngs=rngs)\n","\n"," @nnx.jit()\n"," def __call__(self, x, position_ids):\n"," residual = x\n"," x = self.input_layernorm(x)\n"," x = self.attention(x, position_ids)\n"," x = residual + x\n","\n"," residual = x\n"," x = self.post_attention_layernorm(x)\n"," x = self.mlp(x)\n"," x = residual + x\n"," return x\n","\n","class LlamaForCausalLM(nnx.Module):\n","\n"," def __init__(self, rngs=None):\n"," self.token_embed = nnx.Embed(num_embeddings=config.vocab_size, features=config.dim, param_dtype=jnp.bfloat16, rngs=rngs)\n","\n"," self.layers = [LlamaTransformerBlock(layer_idx=idx, rngs=rngs) for idx in range(config.n_layers)]\n"," self.lm_head = nnx.Linear(config.dim, config.vocab_size, use_bias=False, rngs=rngs, param_dtype=jnp.bfloat16)\n"," self.norm = LlamaRMSNorm(dim=config.head_dim, rngs=rngs)\n","\n"," @nnx.jit()\n"," def __call__(self, input_ids, position_ids):\n"," assert input_ids.shape[0] == 1, \"Only batch size 1 is supported\"\n"," x = self.token_embed(input_ids)\n"," for layer in self.layers:\n"," x = layer(x, position_ids)\n"," x = self.norm(x)\n"," logits = self.lm_head(x)\n"," return logits"],"metadata":{"id":"QztvReU0T40B"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model = LlamaForCausalLM(rngs=nnx.Rngs(0))\n","state = nnx.state(model)\n","nnx.display(state) # This can be very useful"],"metadata":{"id":"dYzGVw1fLT3l"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Map the PyTorch weights to Flax NNX\n","\n","Because of differences in the layer definitions between PyTorch and JAX/Flax NNX we need to alter the shapes of some of the weights. Here's a quick summary:\n","\n","* **Linear (FC)**: Transpose\n","* **Convolutions**: Transpose from `[outC, inC, kH, kW]` to `[kH, kW, inC, outC]`\n","```\n","# [outC, inC, kH, kW] -> [kH, kW, inC, outC]\n","kernel = jnp.transpose(kernel, (2, 3, 1, 0))\n","```\n","\n","* **Convolutions and FC Layers**:\n","We have to be careful, when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc). In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then reshaped to [N, C * H * W] before being fed to the fc layers. When we port our weights from PyTorch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W].\n","\n","* **BatchNorm**: No change"],"metadata":{"id":"L3Nb2DUsgg3y"}},{"cell_type":"code","source":["# This is specific to the format of a Hugging Face Llama 3.2 checkpoint\n","\n","def update_from_HF_checkpoint(state: nnx.State, weights: dict) -> None:\n"," for wholekey in weights:\n"," keys = wholekey.split('.')\n"," if keys[1] == 'layers':\n"," if keys[3] == 'self_attn':\n"," keys[3] = 'attention'\n"," if keys[1] == 'layers' and keys[3] == 'attention':\n"," state['layers'][int(keys[2])][keys[3]][keys[4]]['kernel'].value = weights[wholekey].T\n"," elif keys[1] == 'layers' and keys[3] == 'mlp':\n"," state['layers'][int(keys[2])][keys[3]][keys[4]]['kernel'].value = weights[wholekey].T\n"," elif keys[1] == 'layers' and keys[3] == 'input_layernorm':\n"," state['layers'][int(keys[2])][keys[3]]['norm_weights'].value = weights[wholekey]\n"," elif keys[1] == 'layers' and keys[3] == 'post_attention_layernorm':\n"," state['layers'][int(keys[2])][keys[3]]['norm_weights'].value = weights[wholekey]\n"," elif keys[1] == 'embed_tokens':\n"," state['token_embed'].embedding.value = weights[wholekey]\n"," state['lm_head'].kernel.value = weights[wholekey].T\n"," elif keys[1] == 'norm':\n"," state['norm'].norm_weights.value = weights[wholekey]\n","\n","update_from_HF_checkpoint(state, weights)\n","nnx.update(model, state)\n","# nnx.display(state)"],"metadata":{"id":"GJXOy1QQc5Yi"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from transformers import AutoTokenizer\n","\n","tokenizer = AutoTokenizer.from_pretrained(model_id)\n","input_text = \"The capital of Japan is\"\n","\n","input_ids = tokenizer(input_text, return_tensors=\"jax\")[\"input_ids\"]\n","position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])\n","\n","for _ in range(15):\n"," logits = model(input_ids, position_ids)\n"," next_token = jnp.argmax(logits[:, -1, :], axis=-1)\n"," input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=1)\n"," position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])\n"," print(f\"Generated token: {next_token[0]}\")\n","\n","print(tokenizer.decode(input_ids[0]))"],"metadata":{"id":"x3fPqvgfrdIT"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Get the updated model for serving\n","\n","We loaded up our JAX model, and although we didn't make any changes to it in this notebook, in real life we may have done some fine-tuning, alignment, etc. So now we need to get our updated model so that we can serve it with vLLM."],"metadata":{"id":"j7SlJncxVW7S"}},{"cell_type":"code","source":["state = nnx.state(model) # We already have it, but just to illustrate\n","\n","# This is specific to the format of a Hugging Face Llama 3.2 checkpoint\n","\n","def model_state_to_HF_weights(state: nnx.State) -> dict:\n"," global weights\n","\n"," weights_dict = {}\n"," weights_dict['model.embed_tokens.weight'] = state['token_embed'].embedding.value\n"," weights_dict['model.norm.weight'] = state['norm'].norm_weights.value\n","\n"," for idx, layer in enumerate(state['layers'].values()):\n"," weights_dict[f'model.layers.{idx}.input_layernorm.weight'] = layer['input_layernorm'].norm_weights.value\n"," weights_dict[f'model.layers.{idx}.post_attention_layernorm.weight'] = layer['post_attention_layernorm'].norm_weights.value\n"," weights_dict[f'model.layers.{idx}.self_attn.k_proj.weight'] = layer['attention']['k_proj'].kernel.value.T\n"," weights_dict[f'model.layers.{idx}.self_attn.o_proj.weight'] = layer['attention']['o_proj'].kernel.value.T\n"," weights_dict[f'model.layers.{idx}.self_attn.q_proj.weight'] = layer['attention']['q_proj'].kernel.value.T\n"," weights_dict[f'model.layers.{idx}.self_attn.v_proj.weight'] = layer['attention']['v_proj'].kernel.value.T\n"," weights_dict[f'model.layers.{idx}.mlp.down_proj.weight'] = layer['mlp']['down_proj'].kernel.value.T\n"," weights_dict[f'model.layers.{idx}.mlp.gate_proj.weight'] = layer['mlp']['gate_proj'].kernel.value.T\n"," weights_dict[f'model.layers.{idx}.mlp.up_proj.weight'] = layer['mlp']['up_proj'].kernel.value.T\n"," return weights_dict\n","\n","new_weights = model_state_to_HF_weights(state)"],"metadata":{"id":"Z2C6bWvHHulx"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Now convert the new weights back to Safetensors in preparation for serving"],"metadata":{"id":"VMgtTCefWpz3"}},{"cell_type":"code","source":["import torch\n","import numpy as np\n","\n","# vLLM wants the weight dictionary flattened\n","def flatten_weight_dict(torch_params, prefix=\"\"):\n"," flat_params = {}\n"," for key, value in torch_params.items():\n"," new_key = f\"{prefix}{key}\" if prefix else key\n"," if isinstance(value, dict):\n"," flat_params.update(flatten_weight_dict(value, new_key + \".\"))\n"," else:\n"," flat_params[new_key] = value\n"," return flat_params\n","\n","servable_weights = flatten_weight_dict(new_weights)"],"metadata":{"id":"zR1ycZh10eri"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Replace the old model with the new model. Note that we could also\n","# keep the old and save the new model to a new directory\n","from safetensors.flax import save_file\n","save_file(servable_weights, path_to_model_weights + '/model.safetensors')"],"metadata":{"id":"zLEjIsEm5Tud"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Serving with vLLM"],"metadata":{"id":"zaGqNbxag0wY"}},{"cell_type":"markdown","source":["## Runtime > Restart session to free memory\n","\n","We're right on the edge of our GPU memory for a T4 Colab instance."],"metadata":{"id":"3J9dQW0s7xjS"}},{"cell_type":"markdown","source":["# Which models can you serve with vLLM?\n","\n","While safetensors is a required format for the model's weights, vLLM has two other critical requirements that determine compatibility.\n","\n","## Model Architecture is Key\n","The most important factor is the model's architecture. vLLM achieves its high speed by using custom, highly-optimized compute kernels for specific transformer architectures (like Llama, Mixtral, Gemma, Phi-3, etc.).\n","\n","**Supported Architectures Only:** If the model's architecture is not on vLLM's list of supported models, vLLM will not know how to load or run it, regardless of the file format. However vLLM can also run custom models, see below.\n","\n","**Checking Compatibility:** You can check a model's architecture in its config.json file under the \"architectures\" or \"model_type\" field and compare it against the [official vLLM supported models list](https://docs.vllm.ai/en/latest/models/supported_models.html).\n","\n","## More Than Just Weights\n","A `.safetensors` file only contains the model's weights (the numerical parameters). To function, a model also needs its configuration and tokenizer files. When you point vLLM to a model, it expects a complete directory (or a Hugging Face repository identifier) that includes:\n","\n","* `config.json`: Defines the model's architecture, size, and other essential parameters. vLLM reads this first to check for compatibility.\n","\n","* `tokenizer.json` (and related files): Defines how to convert text into tokens that the model can understand.\n","\n","* `model.safetensors` (or sharded versions): The file(s) containing the actual model weights.\n","\n","## Can I serve a model not in the supported models list?\n","\n","Yes! Check out the [instructions here](https://docs.vllm.ai/en/latest/models/supported_models.html#custom-models to learn how to serve custom models."],"metadata":{"id":"p4OpfaXnCd61"}},{"cell_type":"markdown","source":["## CUDA"],"metadata":{"id":"hc799D81GSVJ"}},{"cell_type":"code","source":["!nvcc --version"],"metadata":{"id":"93aHdF7FDk5-"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["%env CUDA_HOME=/usr/local/cuda-12.5"],"metadata":{"id":"dXlgdHqiF-UL"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Install vLLM"],"metadata":{"id":"fTMbKW2HCUST"}},{"cell_type":"code","execution_count":null,"metadata":{"id":"So9dZbyUCSo0"},"outputs":[],"source":["!pip install -q vllm"]},{"cell_type":"markdown","source":["## Serve the model with vLLM"],"metadata":{"id":"o_TlPnQ9udhj"}},{"cell_type":"code","source":["# Need to restore these after restarting the session\n","import os\n","\n","model_id = \"meta-llama/Llama-3.2-1B\"\n","path_to_model_weights = os.path.join('/content', model_id)"],"metadata":{"id":"4sK7H72b3Bl2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Load the model into vLLM\n","from vllm import LLM, SamplingParams\n","\n","llm = LLM(model=path_to_model_weights, load_format=\"safetensors\", dtype=\"half\")"],"metadata":{"id":"SdKRic9TCjwI"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["prompts = [\n"," \"Hello, my name is\",\n"," \"The president of the United States is\",\n"," \"The capital of France is\",\n"," \"The future of AI is\",\n","]\n","\n","sampling_params = SamplingParams(temperature=0.8, top_p=0.95)\n","\n","outputs = llm.generate(prompts, sampling_params)\n","for output in outputs:\n"," prompt = output.prompt\n"," generated_text = output.outputs[0].text\n"," print(\"===============================\")\n"," print(f\"Prompt: {prompt}\\nGenerated text: {generated_text}\")\n"],"metadata":{"id":"d1E_AgaUcI-C"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":[],"metadata":{"id":"wpXHSzjPEk_u"},"execution_count":null,"outputs":[]}]} \ No newline at end of file diff --git a/docs/learning_jax/quick-references/Chex Quick Reference.pdf b/docs/learning_jax/quick-references/Chex Quick Reference.pdf new file mode 100644 index 0000000..83ef152 Binary files /dev/null and b/docs/learning_jax/quick-references/Chex Quick Reference.pdf differ diff --git a/docs/learning_jax/quick-references/Flax NNX & JAX Quick Reference.pdf b/docs/learning_jax/quick-references/Flax NNX & JAX Quick Reference.pdf new file mode 100644 index 0000000..7d98e6d Binary files /dev/null and b/docs/learning_jax/quick-references/Flax NNX & JAX Quick Reference.pdf differ diff --git a/docs/learning_jax/quick-references/Grain Quick Reference for JAX_Flax NNX.pdf b/docs/learning_jax/quick-references/Grain Quick Reference for JAX_Flax NNX.pdf new file mode 100644 index 0000000..7147c58 Binary files /dev/null and b/docs/learning_jax/quick-references/Grain Quick Reference for JAX_Flax NNX.pdf differ diff --git a/docs/learning_jax/quick-references/JAX & Flax NNX Debugging Quick Reference.pdf b/docs/learning_jax/quick-references/JAX & Flax NNX Debugging Quick Reference.pdf new file mode 100644 index 0000000..12d63a2 Binary files /dev/null and b/docs/learning_jax/quick-references/JAX & Flax NNX Debugging Quick Reference.pdf differ diff --git a/docs/learning_jax/quick-references/JAX AI Stack Quick Reference.pdf b/docs/learning_jax/quick-references/JAX AI Stack Quick Reference.pdf new file mode 100644 index 0000000..44aaaa5 Binary files /dev/null and b/docs/learning_jax/quick-references/JAX AI Stack Quick Reference.pdf differ diff --git a/docs/learning_jax/quick-references/JAX NumPy Quick Reference.pdf b/docs/learning_jax/quick-references/JAX NumPy Quick Reference.pdf new file mode 100644 index 0000000..062dab6 Binary files /dev/null and b/docs/learning_jax/quick-references/JAX NumPy Quick Reference.pdf differ diff --git a/docs/learning_jax/quick-references/JAX Serving with vLLM & SGLang.pdf b/docs/learning_jax/quick-references/JAX Serving with vLLM & SGLang.pdf new file mode 100644 index 0000000..e2b939c Binary files /dev/null and b/docs/learning_jax/quick-references/JAX Serving with vLLM & SGLang.pdf differ diff --git a/docs/learning_jax/quick-references/JAX Sharding & Parallelism with Flax NNX.pdf b/docs/learning_jax/quick-references/JAX Sharding & Parallelism with Flax NNX.pdf new file mode 100644 index 0000000..7f18902 Binary files /dev/null and b/docs/learning_jax/quick-references/JAX Sharding & Parallelism with Flax NNX.pdf differ diff --git a/docs/learning_jax/quick-references/Optax & Flax NNX Quick Reference.pdf b/docs/learning_jax/quick-references/Optax & Flax NNX Quick Reference.pdf new file mode 100644 index 0000000..d9103e2 Binary files /dev/null and b/docs/learning_jax/quick-references/Optax & Flax NNX Quick Reference.pdf differ diff --git a/docs/learning_jax/quick-references/Orbax & Flax NNX Checkpointing_ Quick Reference.pdf b/docs/learning_jax/quick-references/Orbax & Flax NNX Checkpointing_ Quick Reference.pdf new file mode 100644 index 0000000..7371ef4 Binary files /dev/null and b/docs/learning_jax/quick-references/Orbax & Flax NNX Checkpointing_ Quick Reference.pdf differ diff --git a/docs/learning_jax/slide-decks/1 - JAX AI Stack.pdf b/docs/learning_jax/slide-decks/1 - JAX AI Stack.pdf new file mode 100644 index 0000000..e289b73 Binary files /dev/null and b/docs/learning_jax/slide-decks/1 - JAX AI Stack.pdf differ diff --git a/docs/learning_jax/slide-decks/1 - JAX AI Stack.pptx b/docs/learning_jax/slide-decks/1 - JAX AI Stack.pptx new file mode 100644 index 0000000..8179a2d Binary files /dev/null and b/docs/learning_jax/slide-decks/1 - JAX AI Stack.pptx differ diff --git a/docs/learning_jax/slide-decks/10 - Sharding & Parallelism.pdf b/docs/learning_jax/slide-decks/10 - Sharding & Parallelism.pdf new file mode 100644 index 0000000..f9fd25e Binary files /dev/null and b/docs/learning_jax/slide-decks/10 - Sharding & Parallelism.pdf differ diff --git a/docs/learning_jax/slide-decks/10 - Sharding & Parallelism.pptx b/docs/learning_jax/slide-decks/10 - Sharding & Parallelism.pptx new file mode 100644 index 0000000..4144cde Binary files /dev/null and b/docs/learning_jax/slide-decks/10 - Sharding & Parallelism.pptx differ diff --git a/docs/learning_jax/slide-decks/11 - Optax optimizers.pdf b/docs/learning_jax/slide-decks/11 - Optax optimizers.pdf new file mode 100644 index 0000000..7a41639 Binary files /dev/null and b/docs/learning_jax/slide-decks/11 - Optax optimizers.pdf differ diff --git a/docs/learning_jax/slide-decks/11 - Optax optimizers.pptx b/docs/learning_jax/slide-decks/11 - Optax optimizers.pptx new file mode 100644 index 0000000..7a5556b Binary files /dev/null and b/docs/learning_jax/slide-decks/11 - Optax optimizers.pptx differ diff --git a/docs/learning_jax/slide-decks/12 - Conclusion.pdf b/docs/learning_jax/slide-decks/12 - Conclusion.pdf new file mode 100644 index 0000000..0abf6ce Binary files /dev/null and b/docs/learning_jax/slide-decks/12 - Conclusion.pdf differ diff --git a/docs/learning_jax/slide-decks/12 - Conclusion.pptx b/docs/learning_jax/slide-decks/12 - Conclusion.pptx new file mode 100644 index 0000000..da63626 Binary files /dev/null and b/docs/learning_jax/slide-decks/12 - Conclusion.pptx differ diff --git a/docs/learning_jax/slide-decks/2 - NumPy and JAX NumPy.pdf b/docs/learning_jax/slide-decks/2 - NumPy and JAX NumPy.pdf new file mode 100644 index 0000000..00e7a2e Binary files /dev/null and b/docs/learning_jax/slide-decks/2 - NumPy and JAX NumPy.pdf differ diff --git a/docs/learning_jax/slide-decks/2 - NumPy and JAX NumPy.pptx b/docs/learning_jax/slide-decks/2 - NumPy and JAX NumPy.pptx new file mode 100644 index 0000000..cbde46b Binary files /dev/null and b/docs/learning_jax/slide-decks/2 - NumPy and JAX NumPy.pptx differ diff --git a/docs/learning_jax/slide-decks/3 - Intro to NNX for PyTorch users.pdf b/docs/learning_jax/slide-decks/3 - Intro to NNX for PyTorch users.pdf new file mode 100644 index 0000000..9908e2a Binary files /dev/null and b/docs/learning_jax/slide-decks/3 - Intro to NNX for PyTorch users.pdf differ diff --git a/docs/learning_jax/slide-decks/3 - Intro to NNX for PyTorch users.pptx b/docs/learning_jax/slide-decks/3 - Intro to NNX for PyTorch users.pptx new file mode 100644 index 0000000..8d89e98 Binary files /dev/null and b/docs/learning_jax/slide-decks/3 - Intro to NNX for PyTorch users.pptx differ diff --git a/docs/learning_jax/slide-decks/4 - MNIST example.pdf b/docs/learning_jax/slide-decks/4 - MNIST example.pdf new file mode 100644 index 0000000..3a1781e Binary files /dev/null and b/docs/learning_jax/slide-decks/4 - MNIST example.pdf differ diff --git a/docs/learning_jax/slide-decks/4 - MNIST example.pptx b/docs/learning_jax/slide-decks/4 - MNIST example.pptx new file mode 100644 index 0000000..47636ae Binary files /dev/null and b/docs/learning_jax/slide-decks/4 - MNIST example.pptx differ diff --git a/docs/learning_jax/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pdf b/docs/learning_jax/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pdf new file mode 100644 index 0000000..9fa745b Binary files /dev/null and b/docs/learning_jax/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pdf differ diff --git a/docs/learning_jax/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pptx b/docs/learning_jax/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pptx new file mode 100644 index 0000000..d7c2e7d Binary files /dev/null and b/docs/learning_jax/slide-decks/5 - Chex_ JAX & Flax NNX Reliability.pptx differ diff --git a/docs/learning_jax/slide-decks/6 - Debugging JAX and Flax NNX.pdf b/docs/learning_jax/slide-decks/6 - Debugging JAX and Flax NNX.pdf new file mode 100644 index 0000000..8792d9d Binary files /dev/null and b/docs/learning_jax/slide-decks/6 - Debugging JAX and Flax NNX.pdf differ diff --git a/docs/learning_jax/slide-decks/6 - Debugging JAX and Flax NNX.pptx b/docs/learning_jax/slide-decks/6 - Debugging JAX and Flax NNX.pptx new file mode 100644 index 0000000..bd63381 Binary files /dev/null and b/docs/learning_jax/slide-decks/6 - Debugging JAX and Flax NNX.pptx differ diff --git a/docs/learning_jax/slide-decks/7 - Grain for data loading.pdf b/docs/learning_jax/slide-decks/7 - Grain for data loading.pdf new file mode 100644 index 0000000..883f607 Binary files /dev/null and b/docs/learning_jax/slide-decks/7 - Grain for data loading.pdf differ diff --git a/docs/learning_jax/slide-decks/7 - Grain for data loading.pptx b/docs/learning_jax/slide-decks/7 - Grain for data loading.pptx new file mode 100644 index 0000000..83935fc Binary files /dev/null and b/docs/learning_jax/slide-decks/7 - Grain for data loading.pptx differ diff --git a/docs/learning_jax/slide-decks/8 - Orbax for checkpointing.pdf b/docs/learning_jax/slide-decks/8 - Orbax for checkpointing.pdf new file mode 100644 index 0000000..975ac1c Binary files /dev/null and b/docs/learning_jax/slide-decks/8 - Orbax for checkpointing.pdf differ diff --git a/docs/learning_jax/slide-decks/8 - Orbax for checkpointing.pptx b/docs/learning_jax/slide-decks/8 - Orbax for checkpointing.pptx new file mode 100644 index 0000000..6463702 Binary files /dev/null and b/docs/learning_jax/slide-decks/8 - Orbax for checkpointing.pptx differ diff --git a/docs/learning_jax/slide-decks/9 - Serving JAX with vLLM_SGLang.pdf b/docs/learning_jax/slide-decks/9 - Serving JAX with vLLM_SGLang.pdf new file mode 100644 index 0000000..457978b Binary files /dev/null and b/docs/learning_jax/slide-decks/9 - Serving JAX with vLLM_SGLang.pdf differ diff --git a/docs/learning_jax/slide-decks/9 - Serving JAX with vLLM_SGLang.pptx b/docs/learning_jax/slide-decks/9 - Serving JAX with vLLM_SGLang.pptx new file mode 100644 index 0000000..cf1d875 Binary files /dev/null and b/docs/learning_jax/slide-decks/9 - Serving JAX with vLLM_SGLang.pptx differ