diff --git a/machine-learning/jax-haiku-dask-dataframe-distributed-example.ipynb b/machine-learning/jax-haiku-dask-dataframe-distributed-example.ipynb
new file mode 100644
index 00000000..af0fcaee
--- /dev/null
+++ b/machine-learning/jax-haiku-dask-dataframe-distributed-example.ipynb
@@ -0,0 +1,535 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# What are Jax and Haiku?"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[Jax](https://github.com/google/jax) is a library for \"composable transformations of Python+NumPy programs\". It includes helpful function transformations such as `grad`, which computes the gradient of a function, and `jit`, which just-in-time compiles functions with XLA for use on GPUs/TPUs, and much more. It is primarily targeted at high-performance machine learning research.\n",
+    "\n",
+    "[Haiku](https://github.com/deepmind/dm-haiku/) is a neural network library for JAX that adds helpful model and layer abstractions, while allowing full access to JAX's pure function transformations."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Why Dask + Jax/Haiku?"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Whilst Jax has its own high-performance parallelism capabilities, e.g. `pmap`, Dask supports a large number of backends (including [several types of distributed clusters](https://blog.dask.org/2020/07/23/current-state-of-distributed-dask-clusters)) and also integrates very well with much of the PyData stack. In this notebook we therefore explore how we might integrate these libraries."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Example: Learning the sine function with a neural network"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Setup"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First, let's set up a Dask distributed client and install/import the libraries we need."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from dask.distributed import Client, as_completed\n",
+    "client = Client()\n",
+    "client"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "!pip install hvplot jax jaxlib dm-haiku dask-ml"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "from typing import Tuple\n",
+    "\n",
+    "import dask\n",
+    "import dask.dataframe as dd\n",
+    "import haiku as hk\n",
+    "import holoviews as hv\n",
+    "import hvplot.pandas\n",
+    "import jax\n",
+    "import jax.numpy as jnp\n",
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from dask_ml.preprocessing import StandardScaler\n",
+    "from jax.experimental import optix"
+   ]
+  },
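+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick, illustrative aside (a small addition, not required for the rest of the notebook), here is a minimal taste of the two Jax transformations mentioned above: `jax.grad` differentiates a plain Python function and `jax.jit` compiles it with XLA. The function `f` below is just a stand-in for this demonstration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def f(x):\n",
+    "    return x ** 2 + 3 * x  # a simple scalar function\n",
+    "\n",
+    "df_dx = jax.grad(f)  # derivative: df/dx = 2x + 3\n",
+    "fast_f = jax.jit(f)  # XLA-compiled version of f\n",
+    "\n",
+    "df_dx(2.0), fast_f(2.0)  # expect (7.0, 10.0)"
+   ]
+  },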
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here we create the example data (representing the sine wave/function) and convert it into a Pandas dataframe."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "num_points = 50\n",
+    "x = np.linspace(start=0., stop=2*np.pi, num=num_points)\n",
+    "y = np.sin(x)\n",
+    "\n",
+    "df_all = pd.DataFrame({\n",
+    "    \"x\": x.flatten(),\n",
+    "    \"y\": y.flatten(),\n",
+    "})"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Randomly split the data into train/validation datasets."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "train_frac = 0.8\n",
+    "df_train = df_all.sample(frac=train_frac, random_state=42)\n",
+    "df_validation = df_all.drop(labels=df_train.index)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Finally, we convert the Pandas dataframe into a Dask dataframe. Here we also define the batch size for use when training the model, which corresponds to the number of data points in each partition of the Dask dataframe."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_size = 32  # i.e. the number of data samples passed to the neural network model\n",
+    "\n",
+    "ddf_train = dd.from_pandas(df_train, chunksize=batch_size)\n",
+    "ddf_validation = dd.from_pandas(df_validation, chunksize=batch_size)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Visualising the data confirms our example data does indeed correspond to the sine function:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "(df_train.hvplot.scatter(x=\"x\", y=\"y\", label='Training data') *\n",
+    " df_validation.hvplot.scatter(x=\"x\", y=\"y\", label='Validation data')).opts(title=\"Dataset\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Define the model and initialize it for training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "To help the training process, let's standardize `x` (the input or \"feature\") so that it has roughly mean 0 and standard deviation 1. Here we use `dask_ml`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "scaler = StandardScaler()\n",
+    "ddf_train[\"scaled_x\"] = scaler.fit_transform(ddf_train[[\"x\"]]).x\n",
+    "ddf_validation[\"scaled_x\"] = scaler.transform(ddf_validation[[\"x\"]]).x"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Define and initialize a Haiku neural network."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def net_function(x: jnp.ndarray) -> jnp.ndarray:\n",
+    "    net = hk.Sequential([\n",
+    "        hk.Linear(50), jax.nn.relu,\n",
+    "        hk.Linear(100), jax.nn.relu,\n",
+    "        hk.Linear(50), jax.nn.relu,\n",
+    "        hk.Linear(1),\n",
+    "    ])\n",
+    "    pred = net(x)\n",
+    "    return pred\n",
+    "\n",
+    "def initialize_net_function():\n",
+    "    net_transform = hk.transform(net_function)\n",
+    "    rng = jax.random.PRNGKey(42)\n",
+    "    num_features = 1\n",
+    "    example_x = jnp.array(np.random.random([batch_size, num_features]))\n",
+    "    params = net_transform.init(rng, example_x)\n",
+    "    return net_transform, params\n",
+    "\n",
+    "net_transform, params = initialize_net_function()"
+   ]
+  },
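+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an optional sanity check (a small addition to the flow), we can inspect the parameter tree Haiku created: `params` is a nested mapping from module names to weight/bias arrays, which `jax.tree_map` lets us summarise by shape."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Show the shape of every parameter array in the nested params structure\n",
+    "jax.tree_map(lambda p: p.shape, params)"
+   ]
+  },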
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Define a mean squared error loss function."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@jax.jit\n",
+    "def loss_function(params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray:\n",
+    "    def mean_squared_error(y_true: jnp.ndarray, y_pred: jnp.ndarray) -> jnp.ndarray:\n",
+    "        loss = jnp.average((y_true - y_pred) ** 2)\n",
+    "        return loss\n",
+    "\n",
+    "    y_pred: jnp.ndarray = net_transform.apply(params, x)\n",
+    "    loss_value: jnp.ndarray = mean_squared_error(y_true, y_pred)\n",
+    "    return loss_value"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Define and initialize an optimizer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "optimizer: optix.InitUpdate = optix.adam(learning_rate=1e-3)\n",
+    "opt_state: optix.OptState = optimizer.init(params)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Define a predict function, and create a wrapper for use with Dask dataframe `.map_partitions()`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "@jax.jit\n",
+    "def predict(params: hk.Params, x: jnp.ndarray) -> jnp.ndarray:\n",
+    "    return net_transform.apply(params, x)\n",
+    "\n",
+    "\n",
+    "def dask_predict_wrapper(df: pd.DataFrame, params: hk.Params) -> jnp.ndarray:\n",
+    "    scaled_x = jnp.array(df[[\"scaled_x\"]].values)\n",
+    "    return predict(params, scaled_x).flatten()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's predict before training our model to see how well it performs (as expected, it performs poorly):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "ddf_validation[\"y_pred_no_training\"] = ddf_validation.map_partitions(dask_predict_wrapper, params=params)\n",
+    "\n",
+    "(df_train.hvplot.scatter(x=\"x\", y=\"y\", label='Training data') *\n",
+    " ddf_validation.compute().hvplot.scatter(x=\"x\", y=\"y_pred_no_training\", label='Predicted validation data')).opts(title=\"No training\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Training"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "There are a number of ways we can use Dask to help with deep learning/neural network training. For this example we implement the following two approaches, which focus on the situation where you run out of RAM/memory:\n",
+    "\n",
+    "1. Train the model using Dask as a lazy loader of data.\n",
+    "2. Data-parallel training of deep learning models: compute gradients in parallel.\n",
+    "\n",
+    "Other common use cases for distributed deep learning training arise when you are CPU bound, e.g. distributed hyperparameter optimization or distributed training of an ensemble of models. These are omitted here and left as an exercise for the reader."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### CASE 1 - Train model using Dask as a lazy loader of data"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If we have data that is larger than the RAM we have available, we can use Dask to load and train on it batch by batch, one at a time.\n",
+    "\n",
+    "Let's first define an update function (the main learning function that updates the model parameters)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@jax.jit\n",
+    "def update(params: hk.Params, opt_state: optix.OptState, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[hk.Params, optix.OptState]:\n",
+    "    grads = jax.grad(loss_function)(params, x, y)\n",
+    "    updates, opt_state = optimizer.update(grads, opt_state)\n",
+    "    params = optix.apply_updates(params, updates)\n",
+    "    return params, opt_state"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Next we define our training loop."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "for epoch_number in range(200):  # num epochs\n",
+    "    for ddf_one_partition in ddf_train.partitions:  # for each batch\n",
+    "        df_one_partition = ddf_one_partition.compute()\n",
+    "        scaled_x = jnp.array(df_one_partition[[\"scaled_x\"]].values)\n",
+    "        y = jnp.array(df_one_partition[[\"y\"]].values)\n",
+    "        params, opt_state = update(params, opt_state, scaled_x, y)"
+   ]
+  },
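+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a quick, optional check (another small addition), we can also quantify performance by evaluating the same `loss_function` on the validation data. The validation set here is tiny, so computing it locally on the client is fine."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "df_val = ddf_validation.compute()\n",
+    "validation_loss = loss_function(\n",
+    "    params,\n",
+    "    jnp.array(df_val[[\"scaled_x\"]].values),\n",
+    "    jnp.array(df_val[[\"y\"]].values),\n",
+    ")\n",
+    "float(validation_loss)"
+   ]
+  },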
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's now predict after training our model, and visualise the prediction to see how well it performs."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ddf_validation[\"y_pred_with_training_CASE1\"] = ddf_validation.map_partitions(dask_predict_wrapper, params=params)\n",
+    "\n",
+    "(df_train.hvplot.scatter(x=\"x\", y=\"y\", label='Training data') *\n",
+    " ddf_validation.compute().hvplot.scatter(x=\"x\", y=\"y_pred_with_training_CASE1\", label='Predicted validation data')).opts(title=\"With training CASE 1\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### CASE 2 - Data-parallel training of deep learning models: compute gradients in parallel"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First, let's reset the model parameters and optimizer state (so that we don't benefit from the training in Case 1)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "_, params = initialize_net_function()\n",
+    "opt_state = optimizer.init(params)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "In Case 1, we load data only when it's needed (i.e. when we are ready to train on that particular batch of data), preventing potential RAM/memory problems. However, we are not benefiting from Dask's ability to process multiple batches in parallel.\n",
+    "\n",
+    "In this case, we compute the gradients of the loss function in parallel on the workers, then pull them back to the client, where we update the model with the optimizer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@jax.jit\n",
+    "def compute_grads(params: hk.Params, x: jnp.ndarray, y: jnp.ndarray) -> hk.Params:\n",
+    "    grads = jax.grad(loss_function)(params, x, y)\n",
+    "    return grads\n",
+    "\n",
+    "# Note: `client.submit` materialises each Dask dataframe partition before calling\n",
+    "# the function, so the wrapper receives a concrete pandas DataFrame.\n",
+    "def dask_compute_grads_one_partition_wrapper(df: pd.DataFrame, params: hk.Params) -> hk.Params:\n",
+    "    scaled_x = jnp.array(df[[\"scaled_x\"]].values)\n",
+    "    y = jnp.array(df[[\"y\"]].values)\n",
+    "    grads = compute_grads(params, scaled_x, y)\n",
+    "    return grads"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The training loop now looks like:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for epoch_number in range(200):  # num epochs\n",
+    "    futures = []\n",
+    "    for ddf_one_partition in ddf_train.partitions:\n",
+    "        # Compute the gradients in parallel\n",
+    "        futures.append(client.submit(dask_compute_grads_one_partition_wrapper, ddf_one_partition, params))\n",
+    "\n",
+    "    for future, grads in as_completed(futures, with_results=True):\n",
+    "        # Bring the gradients back to the client, and update the model with the optimizer on the client\n",
+    "        updates, opt_state = optimizer.update(grads, opt_state)\n",
+    "        params = optix.apply_updates(params, updates)"
+   ]
+  },
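+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the loop above applies each partition's gradients one at a time, in whatever order the futures happen to complete. A more conventional synchronous data-parallel step instead averages the gradients across all partitions and applies a single optimizer update per epoch. Below is a minimal sketch of that variant (an addition to the original flow, assuming `jax.tree_multimap` is available in your Jax version); note it performs far fewer optimizer updates overall, so the number of epochs and/or learning rate may need retuning."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "for epoch_number in range(200):  # num epochs\n",
+    "    futures = [\n",
+    "        client.submit(dask_compute_grads_one_partition_wrapper, part, params)\n",
+    "        for part in ddf_train.partitions\n",
+    "    ]\n",
+    "    all_grads = client.gather(futures)\n",
+    "\n",
+    "    # Average the gradients leaf-by-leaf across all partitions\n",
+    "    mean_grads = jax.tree_multimap(lambda *g: jnp.mean(jnp.stack(g), axis=0), *all_grads)\n",
+    "\n",
+    "    # One synchronous optimizer update per epoch\n",
+    "    updates, opt_state = optimizer.update(mean_grads, opt_state)\n",
+    "    params = optix.apply_updates(params, updates)"
+   ]
+  },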
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's check our predictions."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": true
+   },
+   "outputs": [],
+   "source": [
+    "ddf_validation[\"y_pred_with_training_CASE2\"] = ddf_validation.map_partitions(dask_predict_wrapper, params=params)\n",
+    "\n",
+    "(df_train.hvplot.scatter(x=\"x\", y=\"y\", label='Training data') *\n",
+    " ddf_validation.compute().hvplot.scatter(x=\"x\", y=\"y_pred_with_training_CASE2\", label='Predicted validation data')).opts(title=\"With training CASE 2\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}