diff --git a/docs/source/JAX_using_LoRA.ipynb b/docs/source/JAX_using_LoRA.ipynb new file mode 100644 index 0000000..765dc5d --- /dev/null +++ b/docs/source/JAX_using_LoRA.ipynb @@ -0,0 +1,1315 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "QEhawzCcCcFR" + }, + "source": [ + "# Using LoRA in JAX\n", + "\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_using_LoRA.ipynb)\n", + "\n", + "\n", + "This tutorial demonstrates how to implement LoRA for efficient fine-tuning of language models in JAX.\n", + "It builds upon the [JAX for LLM pretraining](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html) tutorial by showing how to replace standard linear\n", + "layers with LoRA-enabled linear layers to significantly reduce the number of trainable parameters.\n", + "\n", + "LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that:\n", + "- Keeps pre-trained model weights frozen\n", + "- Adds small trainable low-rank decomposition matrices to certain layers (sketched in code below)\n", + "- Drastically reduces the number of trainable parameters (often by 90%+)\n", + "\n", + "In the first chapter we will build a LoRA-enabled model from scratch, while the next chapter, \"2. Fine-tuning a pre-trained LLM with LoRA\", will demonstrate the more common and practical workflow of applying LoRA to existing pre-trained models.\n", + "\n", + "Both chapters show how to implement these techniques using JAX and Flax's NNX library." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NIOXoY1xgiww" + }, + "source": [ + "# 1. Creating a LoRA-enabled LLM in JAX from scratch\n", + "\n", + "In this chapter, we'll take an unconventional approach by implementing a language model with LoRA from scratch. This is different from standard practice, where LoRA is typically applied to already pre-trained models as a fine-tuning technique.\n", + "\n", + "Why are we doing it this way? While it is not the optimal approach for training a model that achieves good performance (as we'll see in our results), building from scratch makes the integration of LoRA components within the model architecture clearer.\n", + "\n", + "If you're interested in the more practical approach of applying LoRA to an existing pre-trained model, you can skip to the next chapter where we demonstrate that workflow."
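+ , + "\n", + "To make the idea above concrete before we start building, here is a minimal, illustrative sketch of the low-rank update that LoRA adds on top of a frozen weight matrix. The array names and sizes are chosen purely for illustration and are not part of the notebook's model code:\n", + "\n", + "```python\n", + "import jax\n", + "import jax.numpy as jnp\n", + "\n", + "d_in, d_out, rank = 256, 256, 8  # illustrative sizes\n", + "key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)\n", + "\n", + "W = jax.random.normal(key_w, (d_in, d_out))  # frozen pre-trained weight\n", + "A = jax.random.normal(key_a, (d_in, rank)) * 0.02  # trainable low-rank factor\n", + "B = jnp.zeros((rank, d_out))  # trainable low-rank factor, starts at zero\n", + "\n", + "x = jax.random.normal(key_x, (4, d_in))\n", + "y = x @ W + (x @ A) @ B  # frozen path plus the low-rank LoRA update\n", + "\n", + "# Only A and B are trained: rank * (d_in + d_out) = 4,096 parameters\n", + "# versus d_in * d_out = 65,536 for fine-tuning W directly.\n", + "```"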
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hTmz5Cbco7n_" + }, + "source": [ + "## Setup\n", + "Install required packages" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2025-03-17T06:56:05.97404Z", + "iopub.status.busy": "2025-03-17T06:56:05.973829Z", + "iopub.status.idle": "2025-03-17T06:56:33.328581Z", + "shell.execute_reply": "2025-03-17T06:56:33.32728Z", + "shell.execute_reply.started": "2025-03-17T06:56:05.974017Z" + }, + "id": "6zMsOIc7ouCO", + "outputId": "40d84dff-b5c6-45ed-df08-fb22d3eeb01a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/99.7 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.7/99.7 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/424.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.2/424.2 kB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.3/2.3 MB\u001b[0m \u001b[31m51.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m74.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m394.9/394.9 kB\u001b[0m \u001b[31m26.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.2/86.2 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m102.0/102.0 MB\u001b[0m \u001b[31m11.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.2/55.2 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.9/50.9 kB\u001b[0m \u001b[31m3.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m16.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m433.5/433.5 kB\u001b[0m \u001b[31m24.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m102.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m63.0/63.0 kB\u001b[0m \u001b[31m4.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h" + ] + } + ], + "source": [ + "!pip install -q jax-ai-stack\n", + "!pip install -Uq tiktoken grain matplotlib" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rcji_799n4eA" + }, + "source": [ + "Confirm we have TPUs set up." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "execution": { + "iopub.execute_input": "2025-03-17T06:56:33.329853Z", + "iopub.status.busy": "2025-03-17T06:56:33.32961Z", + "iopub.status.idle": "2025-03-17T06:56:41.601645Z", + "shell.execute_reply": "2025-03-17T06:56:41.600599Z", + "shell.execute_reply.started": "2025-03-17T06:56:33.329829Z" + }, + "id": "LS9sQEY3n0mB", + "outputId": "b516c248-777f-4a59-a550-26e12bc2e2fc" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n", + " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n", + " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n", + " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n", + " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n", + " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n", + " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n", + " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "jax.devices()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OHzJ_bokoovZ" + }, + "source": [ + "Get the [TinyStories dataset from Hugging Face](https://huggingface.co/datasets/roneneldan/TinyStories). We only use the training split." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "execution": { + "iopub.execute_input": "2025-03-17T06:56:41.603132Z", + "iopub.status.busy": "2025-03-17T06:56:41.602793Z", + "iopub.status.idle": "2025-03-17T06:57:08.265245Z", + "shell.execute_reply": "2025-03-17T06:57:08.26394Z", + "shell.execute_reply.started": "2025-03-17T06:56:41.603105Z" + }, + "id": "wUjQsgQEmI1N", + "outputId": "90fc683c-696f-4f25-a75c-6a2a5b032cef" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2025-03-22 10:21:26-- https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true\n", + "Resolving huggingface.co (huggingface.co)... 18.172.134.24, 18.172.134.124, 18.172.134.4, ...\n", + "Connecting to huggingface.co (huggingface.co)|18.172.134.24|:443... connected.\n", + "HTTP request sent, awaiting response... 
302 Found\n", + "Location: https://cdn-lfs.hf.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/c5cf5e22ff13614e830afbe61a99fbcbe8bcb7dd72252b989fa1117a368d401f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories-train.txt%3B+filename%3D%22TinyStories-train.txt%22%3B&response-content-type=text%2Fplain&Expires=1742642487&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0MjY0MjQ4N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxL2M1Y2Y1ZTIyZmYxMzYxNGU4MzBhZmJlNjFhOTlmYmNiZThiY2I3ZGQ3MjI1MmI5ODlmYTExMTdhMzY4ZDQwMWY%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=ceYF5N6wd0LwLJeI2W8Kin98I1WRxebZ7JRieOSh2CjLM9zjcqGN1hBljgY57bqPvAwNbdaKDgq1A%7EfawfV%7Ek9bidYFPStA3qDmL6uojttQVzwrTgkNQfIh6Lr1DQx8n0aYtrKsoZHnnCAl4XpK4iZpOixcVkgpxxp44EwiSJGxQNQFRc%7ERgcBsj9rzAS6%7EBb-TCq71jxA%7EQQJbccigJLEubGSIwPK4cSEpvX2AmUKxD2d%7EjdxyDxiISB4H86s0F183g2zXhc-wihfP2B7hgFM589pTNNke-Q0EY8tM5dNWCNN6-AxTu-dkCsOSCGRJsL%7EvM%7ERRwPEYVSZ21M%7ERYCw__&Key-Pair-Id=K3RPWS32NSSJCE [following]\n", + "--2025-03-22 10:21:27-- https://cdn-lfs.hf.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/c5cf5e22ff13614e830afbe61a99fbcbe8bcb7dd72252b989fa1117a368d401f?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories-train.txt%3B+filename%3D%22TinyStories-train.txt%22%3B&response-content-type=text%2Fplain&Expires=1742642487&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0MjY0MjQ4N319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxL2M1Y2Y1ZTIyZmYxMzYxNGU4MzBhZmJlNjFhOTlmYmNiZThiY2I3ZGQ3MjI1MmI5ODlmYTExMTdhMzY4ZDQwMWY%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=ceYF5N6wd0LwLJeI2W8Kin98I1WRxebZ7JRieOSh2CjLM9zjcqGN1hBljgY57bqPvAwNbdaKDgq1A%7EfawfV%7Ek9bidYFPStA3qDmL6uojttQVzwrTgkNQfIh6Lr1DQx8n0aYtrKsoZHnnCAl4XpK4iZpOixcVkgpxxp44EwiSJGxQNQFRc%7ERgcBsj9rzAS6%7EBb-TCq71jxA%7EQQJbccigJLEubGSIwPK4cSEpvX2AmUKxD2d%7EjdxyDxiISB4H86s0F183g2zXhc-wihfP2B7hgFM589pTNNke-Q0EY8tM5dNWCNN6-AxTu-dkCsOSCGRJsL%7EvM%7ERRwPEYVSZ21M%7ERYCw__&Key-Pair-Id=K3RPWS32NSSJCE\n", + "Resolving cdn-lfs.hf.co (cdn-lfs.hf.co)... 18.154.185.84, 18.154.185.78, 18.154.185.64, ...\n", + "Connecting to cdn-lfs.hf.co (cdn-lfs.hf.co)|18.154.185.84|:443... connected.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 1924281556 (1.8G) [text/plain]\n", + "Saving to: ‘TinyStories-train.txt’\n", + "\n", + "TinyStories-train.t 100%[===================>] 1.79G 194MB/s in 9.0s \n", + "\n", + "2025-03-22 10:21:36 (205 MB/s) - ‘TinyStories-train.txt’ saved [1924281556/1924281556]\n", + "\n" + ] + } + ], + "source": [ + "!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sKE2uUafLobI" + }, + "source": [ + "Import necessary libraries" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2025-03-17T06:57:08.26664Z", + "iopub.status.busy": "2025-03-17T06:57:08.266392Z", + "iopub.status.idle": "2025-03-17T06:57:10.140961Z", + "shell.execute_reply": "2025-03-17T06:57:10.140253Z", + "shell.execute_reply.started": "2025-03-17T06:57:08.266614Z" + }, + "id": "MKYFNOhdLq98" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "import flax.nnx as nnx\n", + "from flax.nnx.nn.lora import LoRALinear # Import LoRALinear\n", + "import optax\n", + "from dataclasses import dataclass\n", + "import grain.python as pygrain\n", + "from jax.experimental import mesh_utils\n", + "from jax.sharding import Mesh, PartitionSpec as P, NamedSharding\n", + "import pandas as pd\n", + "import tiktoken\n", + "import time" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rPyt7MV6prz1" + }, + "source": [ + "## Building the Model with LoRA\n", + "\n", + "We'll use the same tokenizer and parallelism strategy as in the [pre-training tutorial](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html).\n", + "The mesh defines how our computation will be distributed across TPU cores." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2025-03-17T06:57:10.142677Z", + "iopub.status.busy": "2025-03-17T06:57:10.142052Z", + "iopub.status.idle": "2025-03-17T06:57:10.147888Z", + "shell.execute_reply": "2025-03-17T06:57:10.147122Z", + "shell.execute_reply.started": "2025-03-17T06:57:10.142651Z" + }, + "id": "xuMlCK3Q8WJD" + }, + "outputs": [], + "source": [ + "tokenizer = tiktoken.get_encoding(\"gpt2\")\n", + "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0XHQ0BQ9-KIj" + }, + "source": [ + "The key difference from the original pre-training model is that we replace standard\n", + "`nnx.Linear` layers with `LoRALinear` layers from [Flax](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/lora.html).\n", + "\n", + "\n", + "This way, only the small rank decomposition matrices need to be trained." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2025-03-17T06:57:13.018347Z", + "iopub.status.busy": "2025-03-17T06:57:13.01792Z", + "iopub.status.idle": "2025-03-17T06:57:13.045763Z", + "shell.execute_reply": "2025-03-17T06:57:13.044457Z", + "shell.execute_reply.started": "2025-03-17T06:57:13.018317Z" + }, + "id": "z0p-IHurrB9i" + }, + "outputs": [], + "source": [ + "def causal_attention_mask(seq_len):\n", + " return jnp.tril(jnp.ones((seq_len, seq_len)))\n", + "\n", + "class TransformerBlock(nnx.Module):\n", + " # update the __init__ function arguments to include lora_rank\n", + " def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1, lora_rank=8):\n", + " self.mha = nnx.MultiHeadAttention(num_heads=num_heads,\n", + " in_features=embed_dim,\n", + " kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),\n", + " bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),\n", + " rngs=rngs)\n", + " self.dropout1 = nnx.Dropout(rate=rate)\n", + " self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,\n", + " num_features=embed_dim,\n", + " scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),\n", + " bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),\n", + " rngs=rngs)\n", + " # here we replace the regular linear layer with the LoRALinear layer\n", + " self.linear1 = LoRALinear(\n", + " in_features=embed_dim,\n", + " out_features=ff_dim,\n", + " lora_rank=lora_rank, # set the rank for the low-rank matrices\n", + " kernel_init=nnx.with_partitioning(nnx.initializers.normal(0.02), P('model', None)),\n", + " bias_init=nnx.with_partitioning(nnx.initializers.zeros, None),\n", + " rngs=rngs\n", + " )\n", + " # here we replace the regular linear layer with the LoRALinear layer\n", + " self.linear2 = LoRALinear(\n", + " in_features=ff_dim,\n", + " out_features=embed_dim,\n", + " lora_rank=lora_rank,\n", + " kernel_init=nnx.with_partitioning(nnx.initializers.normal(0.02), P('model', None)),\n", + " bias_init=nnx.with_partitioning(nnx.initializers.zeros, None),\n", + " rngs=rngs\n", + " )\n", + " self.dropout2 = nnx.Dropout(rate=rate)\n", + " self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,\n", + " num_features=embed_dim,\n", + " scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),\n", + " bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))),\n", + " rngs=rngs)\n", + "\n", + "\n", + " def __call__(self, inputs, training: bool = False):\n", + " input_shape = inputs.shape\n", + " _, seq_len, _ = input_shape\n", + " mask = causal_attention_mask(seq_len)\n", + " attention_output = self.mha(\n", + " inputs_q=inputs,\n", + " mask=mask,\n", + " decode=False\n", + " )\n", + " attention_output = self.dropout1(attention_output, deterministic=not training)\n", + " out1 = self.layer_norm1(inputs + attention_output)\n", + " # feed-forward network with LoRA layer\n", + " ffn_output = self.linear1(out1)\n", + " ffn_output = nnx.relu(ffn_output)\n", + " ffn_output = self.linear2(ffn_output)\n", + " ffn_output = self.dropout2(ffn_output, deterministic=not training)\n", + "\n", + " return self.layer_norm2(out1 + ffn_output)\n", + "\n", + "\n", + "class TokenAndPositionEmbedding(nnx.Module):\n", + "\n", + " def __init__(self, maxlen: int, 
vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs):\n", + " self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs)\n", + " self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs)\n", + "\n", + " def __call__(self, x):\n", + " positions = jnp.arange(0, x.shape[1])[None, :]\n", + " position_embedding = self.pos_emb(positions)\n", + " token_embedding = self.token_emb(x)\n", + " return token_embedding + position_embedding\n", + "\n", + "\n", + "class MiniGPT(nnx.Module):\n", + " # update the __init__ function arguments to include lora_rank\n", + " def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs, lora_rank=8):\n", + " self.embedding_layer = TokenAndPositionEmbedding(\n", + " maxlen, vocab_size, embed_dim, rngs=rngs\n", + " )\n", + " # create transformer blocks with LoRA\n", + " self.transformer_blocks = [TransformerBlock(\n", + " embed_dim, num_heads, feed_forward_dim, rngs=rngs, lora_rank=lora_rank\n", + " ) for _ in range(num_transformer_blocks)]\n", + "\n", + " # modify the output layer to use LoRALinear instead of regular linear layer\n", + " self.output_layer = LoRALinear(\n", + " in_features=embed_dim,\n", + " out_features=vocab_size,\n", + " lora_rank=lora_rank,\n", + " kernel_init=nnx.with_partitioning(nnx.initializers.normal(0.02), P('model', None)),\n", + " bias_init=nnx.with_partitioning(nnx.initializers.zeros, None),\n", + " rngs=rngs\n", + " )\n", + "\n", + "\n", + " def __call__(self, inputs, training: bool = False):\n", + " x = self.embedding_layer(inputs)\n", + " for transformer_block in self.transformer_blocks:\n", + " x = transformer_block(x, training=training)\n", + " outputs = self.output_layer(x)\n", + " return outputs\n", + "\n", + " def generate_text(self, max_tokens: int, start_tokens: [int], top_k=10):\n", + " def sample_from(logits):\n", + " logits, indices = jax.lax.top_k(logits, k=top_k)\n", + " logits = nnx.softmax(logits)\n", + " return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits)\n", + "\n", + " def generate_step(start_tokens):\n", + " pad_len = maxlen - len(start_tokens)\n", + " sample_index = len(start_tokens) - 1\n", + " if pad_len < 0:\n", + " x = jnp.array(start_tokens[:maxlen])\n", + " sample_index = maxlen - 1\n", + " elif pad_len > 0:\n", + " x = jnp.array(start_tokens + [0] * pad_len)\n", + " else:\n", + " x = jnp.array(start_tokens)\n", + "\n", + " x = x[None, :]\n", + " logits = self(x)\n", + " next_token = sample_from(logits[0][sample_index])\n", + " return next_token\n", + "\n", + " generated = []\n", + " for _ in range(max_tokens):\n", + " next_token = generate_step(start_tokens + generated)\n", + " if next_token == tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]:\n", + " break\n", + " generated.append(int(next_token))\n", + " return tokenizer.decode(start_tokens + generated)\n", + "\n", + "# modify the function arguments to include lora_rank\n", + "def create_model(rngs, lora_rank=8):\n", + " return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=4, rngs=rngs,\n", + " lora_rank=lora_rank)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "igX_eoGNMTGR" + }, + "source": [ + "## Set Hyperparameters\n", + "\n", + "We'll use the same hyperparameters as in the [pre-training tutorial](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html) for consistency." 
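+ , + "\n", + "To get a feel for what `lora_rank` controls, here is a rough, illustrative parameter count for one LoRA-wrapped linear layer, using the output projection (from `embed_dim` to `vocab_size`) as an example. The arithmetic is back-of-the-envelope only; exact totals depend on the full model configuration:\n", + "\n", + "```python\n", + "# Back-of-the-envelope count for one LoRA-wrapped linear layer (illustrative).\n", + "embed_dim, vocab_size = 256, 50257  # 50257 = GPT-2 tokenizer vocabulary\n", + "lora_rank = 128\n", + "\n", + "dense_params = embed_dim * vocab_size  # ~12.9M frozen kernel weights\n", + "lora_params = lora_rank * (embed_dim + vocab_size)  # ~6.5M trainable LoRA weights\n", + "\n", + "print(dense_params, lora_params)  # 12865792 6465664\n", + "# A smaller rank (e.g. 8) shrinks the trainable part to about 0.4M.\n", + "```"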
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "execution": { + "iopub.execute_input": "2025-03-17T06:57:13.046952Z", + "iopub.status.busy": "2025-03-17T06:57:13.046709Z", + "iopub.status.idle": "2025-03-17T06:57:13.060711Z", + "shell.execute_reply": "2025-03-17T06:57:13.059813Z", + "shell.execute_reply.started": "2025-03-17T06:57:13.04693Z" + }, + "id": "GRhiDsCrMZRp" + }, + "outputs": [], + "source": [ + "vocab_size = tokenizer.n_vocab\n", + "num_transformer_blocks = 8\n", + "maxlen = 256\n", + "embed_dim = 256\n", + "num_heads = 8\n", + "feed_forward_dim = 256\n", + "batch_size = 256 # You can adjust batch size based on your TP\n", + "num_epochs = 1\n", + "lora_rank = 128 # A higher rank will capture more complex patterns in the LLM, and will also increase the number of trainable parameters" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mI1ci-HyMspJ" + }, + "source": [ + "## Prepare data\n", + "\n", + "Data loading and preprocessing remains the same as in the [pre-training tutorial](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html).\n", + "We create a TextDataset class to handle tokenization and padding." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "execution": { + "iopub.execute_input": "2025-03-17T06:57:13.062239Z", + "iopub.status.busy": "2025-03-17T06:57:13.06186Z", + "iopub.status.idle": "2025-03-17T06:57:45.256964Z", + "shell.execute_reply": "2025-03-17T06:57:45.255537Z", + "shell.execute_reply.started": "2025-03-17T06:57:13.062202Z" + }, + "id": "rGUFsn1GMuzh" + }, + "outputs": [], + "source": [ + "@dataclass\n", + "class TextDataset:\n", + " data: list\n", + " maxlen: int\n", + "\n", + " def __len__(self):\n", + " return len(self.data)\n", + "\n", + " def __getitem__(self, idx: int):\n", + " encoding = tokenizer.encode(self.data[idx], allowed_special={'<|endoftext|>'})[:self.maxlen] # Tokenize and truncate\n", + " return encoding + [0] * (self.maxlen - len(encoding)) # Pad to maxlen\n", + "\n", + "def load_and_preprocess_data(file_path, batch_size, maxlen):\n", + "\n", + " with open(file_path, 'r') as f:\n", + " text = f.read()\n", + "\n", + " stories = text.split('<|endoftext|>')\n", + " stories = [story+'<|endoftext|>' for story in stories if story.strip()]\n", + " df = pd.DataFrame({'text': stories})\n", + " data = df['text'].dropna().tolist()\n", + " dataset = TextDataset(data, maxlen)\n", + "\n", + " sampler = pygrain.IndexSampler(\n", + " len(dataset),\n", + " shuffle=False,\n", + " seed=42,\n", + " shard_options=pygrain.NoSharding(),\n", + " num_epochs=num_epochs,\n", + " )\n", + "\n", + " dl = pygrain.DataLoader(\n", + " data_source=dataset,\n", + " sampler=sampler,\n", + " operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],\n", + " )\n", + "\n", + " return dl\n", + "\n", + "text_dl = load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BKVSD8KSM1um" + }, + "source": [ + "## Train the model with LoRA" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WbSt_MuyaG48" + }, + "source": [ + "LoRA's efficiency lies in how we train only the small adapter matrices while keeping the rest of the model frozen. 
Let's look at how we implement this in JAX:" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": { + "id": "h9hXS0NngSAw" + }, + "outputs": [], + "source": [ + "# Create the model with LoRA\n", + "lora_model = create_model(rngs=nnx.Rngs(0), lora_rank=lora_rank)\n", + "# Filter for LoRA parameters only (look for lora_a and lora_b in the parameter path)\n", + "lora_params = nnx.All(nnx.Param, nnx.Any(nnx.PathContains('lora_a'), nnx.PathContains('lora_b')))\n", + "# Create optimizer to only update LoRA parameters\n", + "optimizer = nnx.Optimizer(lora_model, optax.adam(1e-3), wrt=lora_params)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "e5hooDhBadPb" + }, + "source": [ + "Using `nnx.All` together with `nnx.Any`, we create a filter that matches only our LoRA parameters, i.e. those whose paths contain `lora_a` or `lora_b`. Then we:\n", + "\n", + "- Configure the optimizer to update only these selected parameters using the `wrt` argument\n", + "- Create a special `diff_state` that directs gradient computation to flow only to these parameters\n", + "\n", + "Now we can use this `diff_state` when computing gradients in our training step:" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "id": "reUqnpEtiy0e" + }, + "outputs": [], + "source": [ + "def loss_fn(model, batch):\n", + " logits = model(batch[0])\n", + " loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()\n", + " return loss, logits\n", + "\n", + "@nnx.jit\n", + "def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):\n", + " # Create differentiable state that only includes LoRA parameters\n", + " diff_state = nnx.DiffState(0, lora_params)\n", + " grad_fn = nnx.value_and_grad(loss_fn, argnums=diff_state, has_aux=True)\n", + " (loss, logits), grads = grad_fn(model, batch)\n", + " metrics.update(loss=loss, logits=logits, labels=batch[1])\n", + " optimizer.update(grads)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "collapsed": true, + "execution": { + "iopub.status.busy": "2025-03-17T06:57:45.281022Z", + "iopub.status.idle": "2025-03-17T06:57:45.281857Z", + "shell.execute_reply": "2025-03-17T06:57:45.281296Z", + "shell.execute_reply.started": "2025-03-17T06:57:45.281249Z" + }, + "id": "Ysl6CsfENeJN", + "outputId": "3236e7a1-4d6b-4378-b580-65a509236b23" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Initial generated text:\n", + "Once upon a time choiricity electronically src LaureinterpretHatis Aviation utilizing electronicallyis revealingis Lund353 stationary choiris Ho Ho showdown showdown choir Journalsis MM showdownerreddefault Kag allotted showdownis showdown choir不is Laure showdown Schiff MM KagisFig Aviation markupbysserredis revealing Laure hanging Needs exhausted showdown srcisis Aviationborgh disgis Crystal showdown showdown showdown bystand flippedis showdownisFiguggestsemberredis revealingis revealingADRA Laure hanging/** Aviation Needsisampions nonsatur Stamfordis Laure bystandis wink (£ showdown showdown nons Maul Laure (£ compos showdownampions nonsbyss crafts Rarityis lic busiest stationary markup electronically diedbyssbean electronicallyis dried showdown showdown showdownampions nonsisis exhausteddefault markup electronically Laure showdown electronically MM showdown nonsbyssis showdown showdown roller Maul compos roller Ghost electronically MMis showdown showdown electronically MM 
showdown 1931 revealing showdown bystand dried showdownerredis dried Laure lic W markup markupbyssbeanett electronicallydefaulticitysemb Edinburgh Oxy Aviation markup Severbyss freeway sources Rarity Rarity bystandis Ecology不 showdown showdown ........byss bystanddefault markupis revealingis353 showdown bystandis who rollerEast crafts showdown 1931default markup electronicallyampions showdown bystand lens Ecology markup353 1931soliddefaultassetsis showdownampions (£ Zombiesatur Ecology Hundreds sourcessolid Stamfordbyssis353 showdown nons Piano bystand不assetsis revealingis Ecology died showdown showdown showdown showdown showdown showdown\n", + "\n", + "Step 200, Loss: 6.526708602905273, Elapsed Time: 101.13 seconds\n", + "Generated text:\n", + "Once upon a time, there was a little boy.\n", + "\n", + "\n", + "Step 400, Loss: 5.55381965637207, Elapsed Time: 36.60 seconds\n", + "Generated text:\n", + "Once upon a time there was a boy. He was very he was very he had a big and he had to the460. He had to the certainty he had to the certainty. He was very happy to theInitial he was very happy to the other.\n", + "The little girl was very happy and he had a big and he had to the certainty.\n", + "The little girl was very happy and he had to the other.\n", + "\n", + "\n", + "Step 600, Loss: 5.371800422668457, Elapsed Time: 47.13 seconds\n", + "Generated text:\n", + "Once upon a time, there was a little girl named Lily. She loved to the park. She was very happy and the park. She was very happy to the park and it was very happy to the park.\n", + "The little girl was so happy to the park. She was very happy to the park. She was very happy to the little girl was very happy to the little girl was very happy and the little girl was very happy to the little girl was very happy to the little girl was very happy to the little girl was very happy to the]).\n", + "\n", + "\n", + "Step 800, Loss: 5.213316440582275, Elapsed Time: 52.99 seconds\n", + "Generated text:\n", + "Once upon a time, he had to the park. The little boy was very happy. He had a big, he was very happy to the park.\n", + "The little girl was very happy and the park. He was very happy and he had a big, and he was very happy to theWorks.\n", + "The little girl was very happy and the little girl.\n", + "The little girl was very happy and the little girl was very happy to the little girl was very happy to the little girl was very happy to the little girl was very happy to the little girl was very happy and the little girl.\n", + "\n", + "\n", + "Step 1000, Loss: 5.096640586853027, Elapsed Time: 55.33 seconds\n", + "Generated text:\n", + "Once upon a time, there were two friends. The little girl was very happy and the park. \n", + "The little girl was so happy to the park. She was so happy to the little girl was so happy and the little girl. \n", + "The little girl was so happy that she had a big, she had a big, she had a big, she had a big, and the little girl was so she was so happy to the little girl. \n", + "\n", + "\n", + "Step 1200, Loss: 5.006669044494629, Elapsed Time: 48.46 seconds\n", + "Generated text:\n", + "Once upon a time, there were two friends. He was very happy and he had a big and he had a big and he had a big, he was very happy.\n", + "The little girl was very happy and he had a big and he was very happy. He was very happy to be careful. He was so happy and he had a big and he had a big and he was very happy.\n", + "The little boy. 
He was so happy to the little boy was so happy and he was so happy and he had a big and he was so happy and he had a big and he had a big, and he was so excited to the little boy.\n", + "\n", + "\n", + "Step 1400, Loss: 4.9487762451171875, Elapsed Time: 58.27 seconds\n", + "Generated text:\n", + "Once upon a time, plots were two friends. He was so he had to make his friends.\n", + "The little boy was so he had to the park. He was so he had to theennis, he could not to the man. He was so he was so he had to the little boy.\n", + "The little boy was so he was so he was so he could not to the little boy was so he had to the little boy. He was so he could not to the little boy. He was so he could not to the little boy was so he was so he could not to the little boy.\n", + "The little boy was so he was so happy to the little boy was so he was so he could not to the little boy was so he could be careful.\n", + "\n", + "\n", + "Step 1600, Loss: 4.888418197631836, Elapsed Time: 62.03 seconds\n", + "Generated text:\n", + "Once upon a time, the park. He had a big and the park. The little girl was very happy.\n", + "The girl was very happy. The little girl was very happy. The little girl was so happy.\n", + "The little girl was so happy and the little girl had to the little girl. The little girl was so happy.\n", + "The little girl was so happy and the little girl was so happy. The little girl was so happy.\n", + "The little girl was so happy and the little girl. The little girl was so happy. The little girl was so happy and the little girl was so happy.\n", + "\n", + "\n", + "Step 1800, Loss: 4.831171989440918, Elapsed Time: 55.54 seconds\n", + "Generated text:\n", + "Once upon a time, the demol and the diving. The little girl was so happy. The little girl was so happy.\n", + "The little girl was so happy to the little girl. The little girl was so happy to the little girl.\n", + "The little girl was so happy to the little girl and the little girl.\n", + "The little girl was so happy to the little girl was so happy to the little girl. The little girl and the little girl was so happy to be so excited to the little girl.\n", + "\n", + "\n", + "Step 2000, Loss: 4.822394371032715, Elapsed Time: 49.92 seconds\n", + "Generated text:\n", + "Once upon a time, the Venus was very happy. He was so excited to be a big and he could help. He was so he could not to the park to the other.\n", + "The boy was so excited to the little boy was so he was so he could not to the little boy. He was so he was so he could395 to the little boy.\n", + "The little boy was so excited to the little boy was so excited to the little boy. He was so excited to the little boy. He was so he was so he was so happy that he could be careful with his mom smiled and he was so he was so excited to be careful.\n", + "\n", + "\n", + "Step 2200, Loss: 4.767486572265625, Elapsed Time: 57.48 seconds\n", + "Generated text:\n", + "Once upon a time, the Venus was very happy. He was so excited to the park. He was so excited to be a big, he could be careful.\n", + "The little boy was so excited to the little boy. He was so excited to the little boy. He was so he could help his mom.\n", + "The little boy was so happy to the little boy was so happy to the little boy. He was so happy to the little boy. He was so happy that he could not to the little boy.\n", + "\n", + "\n", + "Step 2400, Loss: 4.686198711395264, Elapsed Time: 50.59 seconds\n", + "Generated text:\n", + "Once upon a time, the demol was very happy. 
He had a big and he was so he was very happy.\n", + "The little boy was so he could help his mom. He was so excited to be a big, he had to be a big and he could help him.\n", + "The little boy was so excited to the little boy and he was so excited to be the little boy. He was so he was so excited to the little boy.\n", + "The boy was so excited to the little boy and he was so happy that he was so happy that he was so happy that he was so he was so he could help his mom.\n", + "\n", + "\n", + "Step 2600, Loss: 4.7401885986328125, Elapsed Time: 57.33 seconds\n", + "Generated text:\n", + "Once upon a time, thestrate Dul. He had a big, and he was very happy.\n", + "The boy was very happy. He had a big, and he was very happy. He was very happy.\n", + "The boy was so happy. He had a big and he could not want to the neutrality. He was so happy.\n", + "The boy was so happy to the man. He was so happy. He was so excited that he had a big, he had a big, he had a big and he could not to the little boy.\n", + "The boy was so happy that he had to the little boy. He was so happy that he could not to the little boy.\n", + "\n", + "\n", + "Step 2800, Loss: 4.7345733642578125, Elapsed Time: 57.90 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and the little girl. She had to always a big, and the little girl. She was so she wanted to the little girl.\n", + "The little girl was so excited to the little girl. She wanted to the little girl and the little girl.\n", + "The little girl was so happy to the little girl and the little girl. She was so happy that she was so excited!\n", + "The little girl was so happy that she could help her mommy and the little girl. She was so happy that she had to the little girl and the little girl.\n", + "\n", + "\n", + "Step 3000, Loss: 4.726464748382568, Elapsed Time: 54.94 seconds\n", + "Generated text:\n", + "Once upon a time, the demol was very happy.\n", + "The next day, the little girl and the little girl. She was so happy.\n", + "The little girl was so happy that she had to the little girl.\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "\n", + "\n", + "Step 3200, Loss: 4.574701309204102, Elapsed Time: 48.22 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and the little girl. She was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "\n", + "\n", + "Step 3400, Loss: 4.61902379989624, Elapsed Time: 50.40 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and the little girl. 
She was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "\n", + "\n", + "Step 3600, Loss: 4.632308006286621, Elapsed Time: 51.33 seconds\n", + "Generated text:\n", + "Once upon a time, the diving, the divingsheet. The bird had a big and the bird. The bird and the bird. The bird was very happy. The bird and the bird and the bird. The bird and the bird. The bird and the bird and the bird. The bird and the bird and the bird. The bird and the bird and the bird viewing the bird. The bird and the bird. The bird and the bird and the bird and the bird. The bird and the bird and the bird. The bird and the bird. The bird and the bird and the bird and the bird. The bird and the bird and the bird and the bird. The bird and the bird and the bird viewing the bird. The bird and the bird and the bird. The bird and the bird and the bird. The bird and the bird and the bird. The bird and the bird and the bird. The bird and the bird and the bird and the bird. The bird and the bird and the bird and the bird. The bird. The bird and the bird and the bird. The bird and the bird and the bird and the bird. The bird and the bird viewing the bird and the bird and the bird. The bird and the bird. The bird. The bird bird bird bird\n", + "\n", + "Step 3800, Loss: 4.5787882804870605, Elapsed Time: 85.77 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and her friends. She was so happy and she could alwaysKate.\n", + "The little girl was so happy!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "The little girl was so excited!\n", + "\n", + "\n", + "Step 4000, Loss: 4.602183818817139, Elapsed Time: 40.17 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and the little girl. She was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "\n", + "\n", + "Step 4200, Loss: 4.579535961151123, Elapsed Time: 47.37 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and her friends. She was so happy to play with her friends.\n", + "The girl was so happy. She had a big, and she had a big, and she had a big, and she had a big, and she was very happy.\n", + "The girl had a big, and the little girl. She was so happy that she had a big and she had a big and she had a big, and she had a big, and the little girl.\n", + "The little girl was so happy. She had a big and the little girl smiled. She was so happy. She was so happy that she had a big and she had a big and the little girl.\n", + "\n", + "\n", + "Step 4400, Loss: 4.503463268280029, Elapsed Time: 59.19 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and the little girl. 
She was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "The little girl was so excited!\n", + "\n", + "\n", + "Step 4600, Loss: 4.479018688201904, Elapsed Time: 49.91 seconds\n", + "Generated text:\n", + "Once upon a time, the little girl and the little girl. She was very happy.\n", + "The little girl had to the little girl.\n", + "The little girl was very happy.\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so excited!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "The little girl was so happy!\n", + "\n", + "\n", + "Step 4800, Loss: 4.536823749542236, Elapsed Time: 48.80 seconds\n", + "Generated text:\n", + "Once upon a time, the diving and the diving. The little girl was very happy.\n", + "The little girl had to the little girl.\n", + "The little girl was very happy.\n", + "The little girl was very happy.\n", + "The little girl was very happy and the little girl.\n", + "The little girl was so happy that she had to the little girl had to the little girl.\n", + "The little girl was so happy that she had to the little girl was very happy.\n", + "\n", + "\n", + "Step 5000, Loss: 4.547352313995361, Elapsed Time: 49.35 seconds\n", + "Generated text:\n", + "Once upon a time, thechron and the diving. The car was very happy.\n", + "The next day, the diving, the diving. The cat and the diving. The cat had a big, the Grad. The cat was very happy. The cat and the demol had a big and the pans. The cat was very happy. The cat was happy. The cat was happy and the pans. The cat was happy to the pans. The cat was happy and the pans and the pans. The cat was happy. The cat was happy and the pans and the pans. The cat had a big, the pans and the pans and the pans. The cat was happy. The cat was happy to the demol was happy.\n", + "\n", + "\n", + "Step 5200, Loss: 4.467193126678467, Elapsed Time: 59.71 seconds\n", + "Generated text:\n", + "Once upon a time, there were Fixes.\n", + "One day, theCold and the little girl was very happy. She was so she had to her.\n", + "The little girl was so happy that she had to her. She was so happy that she had to her mom.\n", + "The little girl was so happy that she could help her. She was so happy that she could help her mom.\n", + "\n", + "\n", + "Step 5400, Loss: 4.455582618713379, Elapsed Time: 45.89 seconds\n", + "Generated text:\n", + "Once upon a time, there were SPD. \n", + "One day, the diving, the diving, and theheter. The little girl was very happy.\n", + "The little girl was so happy that she could help her. She was so happy that she could help her. \n", + "The little girl was so happy that she could help her. She was so happy to the little girl. \n", + "The little girl was so happy that she could help her. She was so happy that she had been a big hug.\n", + "\n", + "\n", + "Step 5600, Loss: 4.450286865234375, Elapsed Time: 51.68 seconds\n", + "Generated text:\n", + "Once upon a time, there were SPD.\n", + "One day, theheter and the little girl was very happy. She was very happy to the little girl.\n", + "The little girl was very happy. She was so excited!\n", + "The little girl was so excited! 
She had to her mom and she could help her a big, she could make her.\n", + "The little girl was so happy. She was so happy!\n", + "\n", + "\n", + "Step 5800, Loss: 4.39848518371582, Elapsed Time: 48.04 seconds\n", + "Generated text:\n", + "Once upon a time, there were two friends.\n", + "One day, they had a big, and a little girl. The little girl was very happy.\n", + "The little girl was so happy that she had a big, and she had a big, and she was very happy.\n", + "The little girl was so happy! She was so happy that she had a big and she had a big and she had a big and she had a big, but she had a big and she was very happy.\n", + "\n", + "\n", + "Step 6000, Loss: 4.475782871246338, Elapsed Time: 50.52 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They bystand and they were playing together.\n", + "One day they had a bigreplace. They had a bigreplace and a bigreplace. They were very happy. They were so happy.\n", + " appreciate the neutrality to play with their friends. They had a big and had a big and a bigreplace. They had so much fun. They had so much fun.\n", + "\n", + "\n", + "Step 6200, Loss: 4.4102349281311035, Elapsed Time: 45.92 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They had a big, and they were very happy.\n", + "One day, they had a big, they had a big, and a bigreplace. The next time, they had a bigreplace. The ended and the little girl and the little girl and the little girl and the little girl. The little girl was very happy.\n", + "The little girl was happy and the little girl was very happy. She said, \"I'm sorry, \"I'm sorry, I will be careful.\"\n", + "The little girl and said, \"I'm sorry, I will be careful and I can be careful.\" The little girl smiled and said, \"I'm sorry for you.\"\n", + "The little girl smiled and said, \"I'm sorry, \"I'm sorry, I can't be careful with you.\"\n", + "The little girl and said, \"I'm sorry, \"Yes, I can be careful.\"\n", + "The little girl and said, \"Yes, \"I'm sorry, you, I can be careful.\"\n", + "The little girl was so happy to the little girl and said, \"I'm sorry for you can help.\"\n", + "The little girl and said, \"I'm sorry to the little girl and the little girl.\n", + "The!!!!\n", + "\n", + "Step 6400, Loss: 4.447348117828369, Elapsed Time: 84.21 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They were playing together and they had a big house.\n", + "One day they had a big, they had a bigreplace. They were playing together. They had a bigreplace and they had a big and they had a bigreplace.\n", + "The next time. They had a big and a bigreplace. They had a bigreplace. They had a bigreplace and a bigreplace. They had so much fun. They had so much fun.\n", + "\n", + "\n", + "Step 6600, Loss: 4.448988437652588, Elapsed Time: 50.07 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They were playing together and they had a big house.\n", + "One day, they had a bigreplace. They were playing together. They were very happy. They had a big and had a bigreplace.\n", + "The next time, they had a bigreplace. They had a big and a bigreplace. They had a bigreplace. They had a bigreplace. They had a bigreplace. They had a big and a bigreplace. They had a bigreplace.\n", + "The next to the other. They had a bigreplace. They had a bigreplace. They had a bigreplace. They had a big and a bigreplace. They had a bigreplace. They had a big and a bigreplace. They had a bigreplace. They had a bigreplace. 
They had a bigreplace. They had a bigreplace. They had a big and a bigreplace. They had a bigreplace. They had a big and a bigreplace. They had a big. They had a bigreplace. They had so happy. They had a bigreplace. They had a big and their mom. They had a big. They had a big and a big and a big and a big. They had a big. They had!!!!\n", + "\n", + "Step 6800, Loss: 4.477509021759033, Elapsed Time: 84.28 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They were playing together and they had a big, and they were very happy.\n", + "One day, they had a big, they had a bigreplace. They had a big, and a bigreplace. The next time, and the uprising. The Stre they had a bigreplace and the other animals. The Stre they had a big, and the little girl and the little girl and the little girl and the little girl and the little girl.\n", + "The little girl and the little girl was very happy. They said, \"I'm sorry, I'm so much fun, but we can be careful with you, but we can be careful and the little girl.\n", + "The little girl and the little girl was very happy and the little girl was very happy. The little girl was so happy and the little girl was so happy.\n", + "\n", + "\n", + "Step 7000, Loss: 4.44502067565918, Elapsed Time: 66.80 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They were playing together and they had a bigreplace.\n", + "One day they decided to play together. They were playing together and they had a bigreplace. They had a bigreplace and they had a bigreplace.\n", + "The next day, they had a bigreplace and a bigreplace. They had a bigreplace and a bigreplace. They had a bigreplace and a bigreplace. They had so much fun.\n", + "The little girl and the little girl said, \"Let's go sending. It's Citiz and the little girl. It was so much fun and the little girl and the little girl and the little girl.\n", + "The little girl was so happy. They said, \"I'm sorry, I can be careful with you. I can be careful with you.\n", + "\n", + "\n", + "Step 7200, Loss: 4.388001441955566, Elapsed Time: 64.41 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They were playing together and they had a big, and they were playing together.\n", + "One day they had a big, they had a bigreplace. They had a bigreplace and they had a bigreplace. They had a bigreplace. They were very happy.\n", + "The next day, they had a bigreplace and they had a bigreplace. They had a bigreplace and they had a bigreplace. They had a bigreplace and they had a big and a bigreplace. They had a bigreplace.\n", + "The next time, they had a bigreplace. They had a big and a bigreplace. They had a bigreplace. They had a bigreplace and a bigreplace. They had a bigreplace. They had a big and a bigreplace. They had a bigreplace. They were happy. They had a big and a big and a bigreplace. They were happy.\n", + "\n", + "\n", + "Step 7400, Loss: 4.435971736907959, Elapsed Time: 68.92 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They had a big, they had a big, and they were very happy. \n", + "One day, they had a big, they had a big, and they had a bigreplace. They had a bigreplace and they had a bigreplace. They had a bigreplace and they had a bigreplace. \n", + "The next time, they had a big, and they had a bigreplace. They had a bigreplace and they had a bigreplace. They had a bigreplace. They had a big, and they had a bigreplace. \n", + "The next time, they had a bigreplace. 
They had a bigreplace and a big, and a big, and a bigreplace. They had a bigreplace. They had a bigreplace. They had a bigreplace and a big and a bigreplace. They had a bigreplace.\n", + "\n", + "\n", + "Step 7600, Loss: 4.371002674102783, Elapsed Time: 68.06 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They had a big, and they were very happy.\n", + "One day, they decided to play together. They had a big, and they had a bigreplace. They had a bigreplace. They had a bigreplace and a bigreplace. They were very happy.\n", + "They had a bigreplace and a bigreplace. They had a bigreplace. They had so much fun. They had so much fun. They had a bigreplace. They had a big and a bigreplace. They had so much fun. They had so much fun.\n", + "\n", + "\n", + "Step 7800, Loss: 4.364511966705322, Elapsed Time: 54.37 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They were playing together and they had a big, and they were very happy.\n", + "The next day, they had a bigreplace. They had a bigreplace and they had a bigreplace. The1980 and the sun was very happy.\n", + "The next time, they had a big, and the sun. The cat and the cat had a bigreplace. The cat said, \"I'm sorry, but it's go home. It's go home.\"\n", + "The cat and the cat said, \"I'm sorry, but it's go sending.\"\n", + "The cat and the cat and the cat said, \"I'm sorry, but it's go sending. I can't know, but it's go home.\"\n", + "The cat said, \"I'm sorry, but it's go home.\"\n", + "The cat said, \"I'm sorry, I can't be careful. I can't be careful, but it.\"\n", + "The cat and the cat and the cat and the cat and the cat and the cat.\n", + "The cat was happy to the cat and said, \"I'm sorry, newcomer. I can help you.\"\n", + "\n", + "\n", + "Step 8000, Loss: 4.309453010559082, Elapsed Time: 79.94 seconds\n", + "Generated text:\n", + "Once upon a time there were two friends. They were playing together.\n", + "The next time, they had a bigreplace. They were very happy. They were so much.\n", + "The Ammo a big, they had a bigreplace. They were so much. They were so much that they had so much fun!\n", + "\n", + "\n", + "Step 8200, Loss: 4.42409610748291, Elapsed Time: 42.27 seconds\n", + "Generated text:\n", + "Once upon a time there was a little girl. She was very happy. She was so happy that she could help her.\n", + "The little girl was so happy. She was so excited to her mom.\n", + "\n", + "\n", + "Final generated text:\n", + "Once upon a time there was very happy.\n", + ", the sun was very happy and he had a big and he was very happy.\n", + "The next day, he saw a big tree and he had a big tree. He wanted to help his friends and he could help.\n", + "The next day, he saw a little bird and he had a big tree. The bird was very happy and he had a big bird.\n", + "The bird was very happy and he had a big bird. 
He was so happy and he had a big bird.\n", + "\n" + ] + } + ], + "source": [ + "metrics = nnx.MultiMetric(\n", + " loss=nnx.metrics.Average('loss'),\n", + ")\n", + "rng = jax.random.PRNGKey(0)\n", + "\n", + "start_prompt = \"Once upon a time\"\n", + "start_tokens = tokenizer.encode(start_prompt)[:maxlen]\n", + "generated_text = lora_model.generate_text(\n", + " maxlen, start_tokens\n", + ")\n", + "print(f\"Initial generated text:\\n{generated_text}\\n\")\n", + "\n", + "\n", + "metrics_history = {\n", + " 'train_loss': [],\n", + "}\n", + "\n", + "prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))\n", + "\n", + "step = 0\n", + "for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " for batch in text_dl:\n", + " if len(batch) % len(jax.devices()) != 0:\n", + " continue\n", + " input_batch = jnp.array(jnp.array(batch).T)\n", + " target_batch = prep_target_batch(input_batch)\n", + " train_step(lora_model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))\n", + "\n", + " if (step + 1) % 200 == 0:\n", + " for metric, value in metrics.compute().items():\n", + " metrics_history[f'train_{metric}'].append(value)\n", + " metrics.reset()\n", + "\n", + " elapsed_time = time.time() - start_time\n", + " print(f\"Step {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds\")\n", + " start_time = time.time()\n", + "\n", + " generated_text = lora_model.generate_text(\n", + " maxlen, start_tokens\n", + " )\n", + " print(f\"Generated text:\\n{generated_text}\\n\")\n", + " step += 1\n", + "\n", + "generated_text = lora_model.generate_text(\n", + " maxlen, start_tokens\n", + ")\n", + "print(f\"Final generated text:\\n{generated_text}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "thaLs6TD0lt5" + }, + "source": [ + "Visualize the training loss." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 472 + }, + "execution": { + "iopub.status.busy": "2025-03-17T06:57:45.282555Z", + "iopub.status.idle": "2025-03-17T06:57:45.283541Z", + "shell.execute_reply": "2025-03-17T06:57:45.282809Z", + "shell.execute_reply.started": "2025-03-17T06:57:45.282762Z" + }, + "id": "B6Eg1Cz2y_iP", + "outputId": "227b6ad5-21de-45d8-ab0b-7dc9a931834f" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAHHCAYAAABDUnkqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAT4tJREFUeJzt3XlYVGX/BvD7wDDDvu8KCKigIm6IopmalpqVW2qmqZn2Zla2vmVlqfXm0q8yKzUrpdS0zNSyzK0wzQVR3FBRZJVV2YZ9YOb8/kCmiEXAmTkwc3+ua64Xzpxz5vt4Xp275zzPeQRRFEUQERERGQkzqQsgIiIi0iWGGyIiIjIqDDdERERkVBhuiIiIyKgw3BAREZFRYbghIiIio8JwQ0REREaF4YaIiIiMCsMNERERGRWGGyLSu5kzZ6JDhw4tOnbRokUQBEG3BRGRUWO4ITJhgiA06RUVFSV1qZKYOXMmbG1tpS6DiJpJ4NpSRKZr06ZNtX7/5ptvsH//fmzcuLHW9nvvvRceHh4t/pzKykpoNBooFIpmH1tVVYWqqipYWlq2+PNbaubMmfjhhx9QXFxs8M8mopaTSV0AEUln2rRptX4/fvw49u/fX2f7v5WWlsLa2rrJn2NhYdGi+gBAJpNBJuM/VUTUdLwtRUSNGjJkCEJCQnDq1CncfffdsLa2xuuvvw4A2LVrF0aPHg1vb28oFAoEBgbinXfegVqtrnWOf4+5SU5OhiAI+L//+z+sW7cOgYGBUCgU6Nu3L06ePFnr2PrG3AiCgGeeeQY7d+5ESEgIFAoFunXrht9++61O/VFRUQgLC4OlpSUCAwPx+eef63wcz7Zt29CnTx9YWVnB1dUV06ZNQ3p6eq19srKy8Pjjj6N9+/ZQKBTw8vLCmDFjkJycrN0nJiYGI0aMgKurK6ysrODv749Zs2bprE4iU8H/HCKi28rNzcWoUaPwyCOPYNq0adpbVJGRkbC1tcWLL74IW1tb/P7773jrrbegVCrx/vvv3/a83377LYqKivCf//wHgiBgxYoVGD9+PBITE2/b23PkyBH8+OOPePrpp2FnZ4dVq1ZhwoQJSE1NhYuLCwAgNjYWI0eOhJeXFxYvXgy1Wo0lS5bAzc3tzv9QbomMjMTjjz+Ovn37YunSpcjOzsbHH3+Mv/76C7GxsXB0dAQATJgwAXFxcXj22WfRoUMH5OTkYP/+/UhNTdX+ft9998HNzQ2vvfYaHB0dkZycjB9//FFntRKZDJGI6JZ58+aJ//5nYfDgwSIAce3atXX2Ly0trbPtP//5j2htbS2Wl5drt82YMUP08/PT/p6UlCQCEF1cXMS8vDzt9l27dokAxJ9//lm77e23365TEwBRLpeLCQkJ2m1nz54VAYiffPKJdtuDDz4oWltbi+np6dptV69eFWUyWZ1z1mfGjBmijY1Ng++rVCrR3d1dDAkJEcvKyrTbd+/eLQIQ33rrLVEURTE/P18EIL7//vsNnmvHjh0iAPHkyZO3rYuIGsfbUkR0WwqFAo8//nid7VZWVtqfi4qKcPPmTQwaNAilpaW4fPnybc87efJkODk5aX8fNGgQACAxMfG2xw4fPhyBgYHa30NDQ2Fvb689Vq1W48CBAxg7diy8vb21+3Xs2BGjRo267fmbIiYmBjk5OXj66adrDXgePXo0goOD8csvvwCo/nOSy+WIiopCfn5+veeq6eHZvXs3KisrdVIfkaliuCGi22rXrh3kcnmd7XFxcRg3bhwcHBxgb28PNzc37WDkwsLC257X19e31u81QaehANDYsTXH1xybk5ODsrIydOzYsc5+9W1riZSUFABAUFBQnfeCg4O17ysUCixfvhx79uyBh4cH7r77bqxYsQJZWVna/QcPHowJEyZg8eLFcHV1xZgxY7BhwwZUVFTopFYiU8JwQ0S39c8emhoFBQUYPHgwzp49iyVLluDnn3/G/v37sXz5cgCARqO57XnNzc3r3S424QkVd3KsFJ5//nlcuXIFS5cuhaWlJRYuXIguXbogNjYWQPUg6R9++AHHjh3DM888g/T0dMyaNQt9+vThVHSiZmK4IaIWiYqKQm5uLiIjIzF//nw88MADGD58eK3bTFJyd3eHpaUlEhIS6rxX37aW8PPzAwDEx8fXeS8+Pl77fo3AwEC89NJL2LdvHy5cuACVSoUPPvig1j79+/fH//73P8TExGDz5s2Ii4vD1q1bdVIvkalguCGiFqnpOflnT4lKpcLq1aulKqkWc3NzDB8+HDt37kRGRoZ2e0JCAvbs2aOTzwgLC4O7uzvWrl1b6/bRnj17cOnSJYwePRpA9XOBysvLax0bGBgIOzs77XH5+fl1ep169uwJALw1RdRMnApORC0yYMAAODk5YcaMGXjuuecgCAI2btzYqm4LLVq0CPv27cPAgQMxd+5cqNVqfPrppwgJCcGZM2eadI7Kykq8++67dbY7Ozvj6aefxvLly/H4449j8ODBmDJlinYqeIcOHfDCCy8AAK5cuYJhw4Zh0qRJ6Nq1K2QyGXbs2IHs7Gw88sgjAICvv/4aq1evxrhx4xAYGIiioiJ88cUXsLe3x/3336+zPxMiU8BwQ0Qt4uLigt27d+Oll17Cm2++CScnJ0ybNg3Dhg3DiBEjpC4PANCnTx/s2bMHL7/8MhYuXAgfHx8sWbIEly5datJsLqC6N2rhwoV1tgcGBuLpp5/GzJkzYW1tjWXLluHVV1+FjY0Nxo0bh+XLl2tnQPn4+GDKlCk4ePAgNm7cCJlMhuDgYHz//feYMGECgOoBxdHR0di6dSuys7Ph4OCA8PBwbN68Gf7+/jr7MyEyBVxbiohMztixYxEXF4erV69KXQoR6QHH3BCRUSsrK6v1+9WrV/Hrr79iyJAh0hRERHrHnhsiMmpeXl6YOXMmAgICkJKSgjVr1qCiogKxsbHo1KmT1OURkR5wzA0RGbWRI0diy5YtyMrKgkKhQEREBN577z0GGyIjxp4bIiIiMiocc0NERERGheGGiIiIj
IrJjbnRaDTIyMiAnZ0dBEGQuhwiIiJqAlEUUVRUBG9vb5iZNd43Y3LhJiMjAz4+PlKXQURERC2QlpaG9u3bN7qPyYUbOzs7ANV/OPb29hJXQ0RERE2hVCrh4+Oj/R5vjMmFm5pbUfb29gw3REREbUxThpRwQDEREREZFYYbIiIiMioMN0RERGRUGG6IiIjIqDDcEBERkVFhuCEiIiKjwnBDRERERoXhhoiIiIwKww0REREZFYYbIiIiMioMN0RERGRUGG6IiIjIqDDc6EiVWoMcZTlSc0ulLoWIiMikMdzoSHRyHsLfO4gnvj4pdSlEREQmTfJwk56ejmnTpsHFxQVWVlbo3r07YmJiGtw/KioKgiDUeWVlZRmw6rqcrOUAgPxSlaR1EBERmTqZlB+en5+PgQMHYujQodizZw/c3Nxw9epVODk53fbY+Ph42Nvba393d3fXZ6m35WxTE24qIYoiBEGQtB4iIiJTJWm4Wb58OXx8fLBhwwbtNn9//yYd6+7uDkdHRz1V1nyO1hYAALVGhLK8Cg5WFhJXREREZJokvS31008/ISwsDBMnToS7uzt69eqFL774oknH9uzZE15eXrj33nvx119/NbhfRUUFlEplrZc+KGTmsJGbAwDyS3hrioiISCqShpvExESsWbMGnTp1wt69ezF37lw899xz+Prrrxs8xsvLC2vXrsX27duxfft2+Pj4YMiQITh9+nS9+y9duhQODg7al4+Pj76aAycbjrshIiKSmiCKoijVh8vlcoSFheHo0aPabc899xxOnjyJY8eONfk8gwcPhq+vLzZu3FjnvYqKClRUVGh/VyqV8PHxQWFhYa0xO7rw4CdHcD69EOtnhuGeYA+dnpuIiMiUKZVKODg4NOn7W9KeGy8vL3Tt2rXWti5duiA1NbVZ5wkPD0dCQkK97ykUCtjb29d66Yu256akUm+fQURERI2TNNwMHDgQ8fHxtbZduXIFfn5+zTrPmTNn4OXlpcvSWsTp1qBi3pYiIiKSjqSzpV544QUMGDAA7733HiZNmoTo6GisW7cO69at0+6zYMECpKen45tvvgEArFy5Ev7+/ujWrRvKy8vx5Zdf4vfff8e+ffukaoYWn3VDREQkPUnDTd++fbFjxw4sWLAAS5Ysgb+/P1auXImpU6dq98nMzKx1m0qlUuGll15Ceno6rK2tERoaigMHDmDo0KFSNKGWmnCTx9tSREREkpF0QLEUmjMgqbk2HkvGwl1xGNnNE2sf66PTcxMREZmyNjOg2NhwKjgREZH0GG50iGNuiIiIpMdwo0N/hxuOuSEiIpIKw40OOdncmgpeooKJDWUiIiJqNRhudKim56ZKI6KookriaoiIiEwTw40OWVqYw/rW4pkFnA5OREQkCYYbHdM+64aDiomIiCTBcKNj2nE3DDdERESSYLjRMe2MqRKGGyIiIikw3OgYp4MTERFJi+FGx7Qrg7PnhoiISBIMNzpWswQDBxQTERFJg+FGx5xvhZsChhsiIiJJMNzomGPNVHDeliIiIpIEw42OOVvX9NxwQDEREZEUGG50zPHWgGL23BAREUmD4UbH/h5zU8nFM4mIiCTAcKNjNc+5Uak1KFGpJa6GiIjI9DDc6JiV3ByWFtV/rHzWDRERkeEx3OiBs/YpxQw3REREhsZwowecDk5ERCQdhhs9+OegYiIiIjIshhs94HRwIiIi6TDc6AGXYCAiIpIOw40eaMfcMNwQEREZHMONHjjfui2VX8IxN0RERIbGcKMHTjacCk5ERCQVhhs9cOJUcCIiIskw3OgBp4ITERFJh+FGD7RTwUtVXDyTiIjIwBhu9KCm50ZVpUFZJRfPJCIiMiSGGz2wsjCHXFb9R8txN0RERIbFcKMHgiD8vXgmp4MTEREZFMONnnA6OBERkTQYbvTEqeZBfgw3REREBsVwoyfanhuOuSEiIjIohhs9cdJOB+eYGyIiIkNiuNGTmgHFXBmciIjIsBhu9MSRSzAQERFJguFGT5w5W4qIiEgSDDd68veAYo65ISIiMiSGGz3hVHAiIiJpMNzoiZM1b0sRERFJgeFGT2puS5VXalCm4uKZREREhsJwoyc2cnPIzav/eNl7Q0REZDgMN3oiCAIcax7kx+ngREREBsNwo0ecDk5ERGR4DDd69PegYk4HJyIiMhSGGz1ysrk1HZy3pYiIiAyG4UaPOB2ciIjI8Bhu9EgbbthzQ0REZDAMN3pU86ybPI65ISIiMhiGGz2qWYKhgLeliIiIDIbhRo+0PTe8LUVERGQwDDd65HxrzE0Bb0sREREZDMONHtUMKGbPDRERkeEw3OhRzXNuyirVKK/k4plERESGwHCjR7YKGWRmAgA+64aIiMhQGG70SBAEDiomIiIyMIYbPft7OjgHFRMRERkCw42ecVAxERGRYTHc6JmzTc10cIYbIiIiQ2C40TNHbc8Nb0sREREZAsONnjnfmg7O2VJERESGIXm4SU9Px7Rp0+Di4gIrKyt0794dMTExjR4TFRWF3r17Q6FQoGPHjoiMjDRMsS2gXRmc4YaIiMggJA03+fn5GDhwICwsLLBnzx5cvHgRH3zwAZycnBo8JikpCaNHj8bQoUNx5swZPP/885g9ezb27t1rwMqbjgOKiYiIDEsm5YcvX74cPj4+2LBhg3abv79/o8esXbsW/v7++OCDDwAAXbp0wZEjR/DRRx9hxIgReq23JWqeUsyp4ERERIYhac/NTz/9hLCwMEycOBHu7u7o1asXvvjii0aPOXbsGIYPH15r24gRI3Ds2DF9ltpi7LkhIiIyLEnDTWJiItasWYNOnTph7969mDt3Lp577jl8/fXXDR6TlZUFDw+PWts8PDygVCpRVlZWZ/+KigoolcpaL0PiVHAiIiLDkvS2lEajQVhYGN577z0AQK9evXDhwgWsXbsWM2bM0MlnLF26FIsXL9bJuVqiZip4iUqNiio1FDJzyWohIiIyBZL23Hh5eaFr1661tnXp0gWpqakNHuPp6Yns7Oxa27Kzs2Fvbw8rK6s6+y9YsACFhYXaV1pamm6KbyJ7SxnMby2eyXE3RERE+idpz83AgQMRHx9fa9uVK1fg5+fX4DERERH49ddfa23bv38/IiIi6t1foVBAoVDcebEtJAgCnKwtcLNYhbwSFTzsLSWrhYiIyBRI2nPzwgsv4Pjx43jvvfeQkJCAb7/9FuvWrcO8efO0+yxYsADTp0/X/v7UU08hMTER//3vf3H58mWsXr0a33//PV544QUpmtAk2mfdcFAxERGR3kkabvr27YsdO3Zgy5YtCAkJwTvvvIOVK1di6tSp2n0yMzNr3aby9/fHL7/8gv3796NHjx744IMP8OWXX7bKaeA1/n6QH29LERER6Zukt6UA4IEHHsADDzzQ4Pv1PX14yJAhiI2N1WNVulXzrJs8zpgiIiLSO8mXXzAF2ungvC1FRESkdww3BqBdGZw9N0RERHrHcGMAztY1D/LjmBsiIiJ9Y7gxAEfrW2NueFuKiIhI7xhuDKBmzE0+b0sRERHpHcONAThaM9wQ
EREZCsONAWh7bko45oaIiEjfGG4MoGZAcXFFFVRVGomrISIiMm4MNwZgZynDrbUzUcBbU0RERHrFcGMAZmYCl2AgIiIyEIYbA+F0cCIiIsNguDEQTgcnIiIyDIYbA3HidHAiIiKDYLgxEG244W0pIiIivWK4MRAnGw4oJiIiMgSGGwNxujWgmD03RERE+sVwYyA1PTd5HHNDRESkVww3BsLn3BARERkGw42BONvwthQREZEhMNwYCKeCExERGQbDjYHUhJui8ipUqrl4JhERkb4w3BiIvZXFPxbP5LgbIiIifWG4MRBzMwEOVrfG3fDWFBERkd4w3BiQdjo4BxUTERHpDcONAdWMuylgzw0REZHeMNwYUE24ySvhmBsiIiJ9YbgxIO2zbthzQ0REpDcMNwbElcGJiIj0j+HGgLgyOBERkf4x3BiQdmVw3pYiIiLSG4YbA/p7QDHDDRERkb4w3BhQzW0pTgUnIiLSH4YbA2LPDRERkf4x3BiQ862eG2V5Faq4eCYREZFeMNwYkIOVBYSaxTPLOGOKiIhIHxhuDOifi2dy3A0REZF+MNwYGJdgICIi0i+GGwOredYNBxUTERHpB8ONgXFlcCIiIv1iuDGwmmfd5DHcEBER6QXDjYE5ax/kxzE3RERE+sBwY2COHHNDRESkVww3BubMMTdERER6xXBjYI5cgoGIiEivGG4MrGbMTT7H3BAREekFw42B1TznJp+3pYiIiPSC4cbAaqaCF5ZVQq0RJa6GiIjI+DDcGJjjrbWlRLE64BAREZFuMdwYmMzcDPaWMgAcVExERKQPDDcS+PtBfgw3REREusZwIwFOByciItIfhhsJ/D0dnOGGiIhI1xhuJOConQ7OAcVERES6xnAjgZolGPJ5W4qIiEjnGG4k4MTbUkRERHrDcCMBJ+2AYt6WIiIi0jWGGwk421SPueFUcCIiIt1juJGAdio4ww0REZHOMdxIQDsVnAOKiYiIdI7hRgI1U8G5eCYREZHuMdxIoGZAsUYElFw8k4iISKcYbiRgYW4GV1sFAODM9QJpiyEiIjIyDDcSebCHFwDg2xOpEldCRERkXBhuJDK1nx8A4OClbGQUlElcDRERkfFguJFIR3db9A9whkYEtkaz94aIiEhXJA03ixYtgiAItV7BwcEN7h8ZGVlnf0tLSwNWrFvT+lf33mw9mYZKtUbiaoiIiIyDTOoCunXrhgMHDmh/l8kaL8ne3h7x8fHa3wVB0Ftt+nZfV0+42iqQU1SBAxezMaq7l9QlERERtXmShxuZTAZPT88m7y8IQrP2b83kMjNM7tsen/1xDZtOpDDcEBER6YDkY26uXr0Kb29vBAQEYOrUqUhNbXz8SXFxMfz8/ODj44MxY8YgLi7OQJXqx5RwXwgC8FdCLhJvFEtdDhERUZsnabjp168fIiMj8dtvv2HNmjVISkrCoEGDUFRUVO/+QUFBWL9+PXbt2oVNmzZBo9FgwIABuH79eoOfUVFRAaVSWevVmrR3ssbQIHcAnBZORESkC4Iois1+/n9aWhoEQUD79u0BANHR0fj222/RtWtXPPnkky0upqCgAH5+fvjwww/xxBNP3Hb/yspKdOnSBVOmTME777xT7z6LFi3C4sWL62wvLCyEvb19i2vVpd8vZ2NWZAwcrCxw4vVhsLQwl7okIiKiVkWpVMLBwaFJ398t6rl59NFH8ccffwAAsrKycO+99yI6OhpvvPEGlixZ0pJTAgAcHR3RuXNnJCQkNGl/CwsL9OrVq9H9FyxYgMLCQu0rLS2txfXpy+DO7mjnaIXCskr8ci5T6nKIiIjatBaFmwsXLiA8PBwA8P333yMkJARHjx7F5s2bERkZ2eJiiouLce3aNXh5NW1grVqtxvnz5xvdX6FQwN7evtartTE3E/BoP18AwKYTKRJXQ0RE1La1KNxUVlZCoaheG+nAgQN46KGHAADBwcHIzGx6z8PLL7+MQ4cOITk5GUePHsW4ceNgbm6OKVOmAACmT5+OBQsWaPdfsmQJ9u3bh8TERJw+fRrTpk1DSkoKZs+e3ZJmtCqTwnwgMxMQm1qAuIxCqcshIiJqs1oUbrp164a1a9fi8OHD2L9/P0aOHAkAyMjIgIuLS5PPc/36dUyZMgVBQUGYNGkSXFxccPz4cbi5uQEAUlNTa4Wl/Px8zJkzB126dMH9998PpVKJo0ePomvXri1pRqviZqfAiJDqKe6bObCYiIioxVo0oDgqKgrjxo2DUqnEjBkzsH79egDA66+/jsuXL+PHH3/UeaG60pwBSYZ27FoupnxxHNZyc5x4fRjsLC2kLomIiKhVaM73d4se4jdkyBDcvHkTSqUSTk5O2u1PPvkkrK2tW3JKAtA/wBmBbja4dqMEO89k4LFbyzMQERFR07XotlRZWRkqKiq0wSYlJQUrV65EfHw83N3ddVqgKREEQbta+ObjKWhBpxoREZHJa1G4GTNmDL755hsA1c+m6devHz744AOMHTsWa9as0WmBpmZC7/awtDDD5awinErJl7ocIiKiNqdF4eb06dMYNGgQAOCHH36Ah4cHUlJS8M0332DVqlU6LdDUOFhb4MFQbwAcWExERNQSLQo3paWlsLOzAwDs27cP48ePh5mZGfr374+UFD6n5U5NuzXW5pdzmcgrUUlcDRERUdvSonDTsWNH7Ny5E2lpadi7dy/uu+8+AEBOTk6rm4HUFvXwcUT3dg5QqTXYFtP6nqhMRETUmrUo3Lz11lt4+eWX0aFDB4SHhyMiIgJAdS9Or169dFqgqZp664nF30anQqPhwGIiIqKmalG4efjhh5GamoqYmBjs3btXu33YsGH46KOPdFacKXuopzfsFDKk5JbiSMJNqcshIiJqM1oUbgDA09MTvXr1QkZGBq5fvw4ACA8PR3BwsM6KM2XWchnG924HANh0nOOYiIiImqpF4Uaj0WDJkiVwcHCAn58f/Pz84OjoiHfeeQcajUbXNZqsqbcGFh+8nIPMwjKJqyEiImobWhRu3njjDXz66adYtmwZYmNjERsbi/feew+ffPIJFi5cqOsaTVZnDzuE+ztDrRGxNZoDi4mIiJqiRWtLeXt7Y+3atdrVwGvs2rULTz/9NNLT03VWoK615rWl6rPrTDrmbz0DD3sFjrx6DyzMW3wnkYiIqM1qzvd3i74p8/Ly6h1bExwcjLy8vJackhowMsQTLjZyZCsrsP9ittTlEBERtXotCjc9evTAp59+Wmf7p59+itDQ0Dsuiv6mkJljSnj1tPDP/0zkelNERES30aJVwVesWIHRo0fjwIED2mfcHDt2DGlpafj11191WiABMwZ0wBeHE3E2rQDHEnMxINBV6pKIiIharRb13AwePBhXrlzBuHHjUFBQgIKCAowfPx5xcXHYuHGjrms0eW52CkwK8wEArIm6JnE1RERErVuLBhQ35OzZs+jduzfUarWuTqlzbW1AcY20vFIM+b8oqDUidj97F0LaOUhdEhERkcHofUAxGZ6PszUeCPUCAKw5xN4bIiKihjDctCFzhwQCAPacz0TSzRKJqyE
iImqdGG7akGBPe9wT7A6NCKz7k703RERE9WnWbKnx48c3+n5BQcGd1EJNMHdIIH6/nIPtp9Lx/PDO8LC3lLokIiKiVqVZ4cbBofFBrA4ODpg+ffodFUSN69vBGWF+TohJycf6I0lYcH8XqUsiIiJqVXQ6W6otaKuzpf7p4KVsPPF1DGzk5jj62jA4WFtIXRIREZFecbaUkbsn2B1BHnYoUamx6USK1OUQERG1Kgw3bZAgCNqZU+uPJKG8svU+V4iIiMjQGG7aqAdCvdDeyQq5JSpsi0mTuhwiIqJWg+GmjZKZm+HJuwMAVC+oWaXWSFwRERFR68Bw04ZN7OMDFxs5rueXYfe5TKnLISIiahUYbtowK7k5Hh/YAUD1gpomNvGNiIioXgw3bdxjER1gq5AhPrsIf8TnSF0OERGR5Bhu2jgHKwtM7ecLoLr3hoiIyNQx3BiBWXf5Q25uhpPJ+TiZnCd1OURERJJiuDECHvaWmNCnHQBgLXtviIjIxDHcGIkn7w6EmQAcvJyDy1lKqcshIiKSDMONkfB3tcGo7l4AgM8PJUpcDRERkXQYbozI3MHVSzL8dDYDaXmlEldDREQkDYYbIxLSzgGDOrlCrRHx2R8JUpdDREQkCYYbI/PcsE4AgK0n0xDDmVNERGSCGG6MTN8OzpgU1h4A8Or2c6io4orhRERkWhhujNDr93eBq60C126UYPUfnBpORESmheHGCDlay7Hooa4AgNVRCbiaXSRxRURERIbDcGOkRnf3wrBgd1SqRbz243loNFxUk4iITAPDjZESBAHvjA2Bjdwcp1LysflEitQlERERGQTDjRHzdrTCf0cGAwCW/xaPzMIyiSsiIiLSP4YbIzetvx96+TqiuKIKb+2Kgyjy9hQRERk3hhsjZ24mYNn4UFiYC9h/MRu/XciSuiQiIiK9YrgxAUGednjq1tIMb/0Uh8KySokrIiIi0h+GGxMxb2hHBLjZ4EZRBZbtuSR1OURERHrDcGMiLC3MsWx8KABgS3QajifmSlwRERGRfjDcmJBwf2dMCfcFALz+43mUV3JpBiIiMj4MNybmtVHBcLdTIPFmCVcOJyIio8RwY2IcrCywZEw3AMCaqGu4nKWUuCIiIiLdYrgxQSNDvHBfVw9UaUS8tv081FyagYiIjAjDjYlaMiYEdgoZzqQVYOOxZKnLISIi0hmGGxPl6WCJ/46qXpph2W+XcSG9UOKKiIiIdIPhxoRNDffFkCA3lFdqMOebGOQoy6UuiYiI6I4x3JgwMzMBq6b0QqCbDTILy/HkxlOcHk5ERG0ew42Js7e0wFcz+sLBygJn0grw+o/nubgmERG1aQw3hA6uNlgztTfMzQT8GJuOz/9MlLokIiKiFmO4IQDAgI6uWPRgVwDA8t8u48DFbIkrIiIiahmGG9J6LKIDpvbzhSgC87fGIj6rSOqSiIiImo3hhmpZ9FA39A9wRolKjdnfnEReiUrqkoiIiJqF4YZqsTA3w5qpfeDrbI20vDLM3XQKqiqN1GURERE1GcMN1eFkI8dXM8Jgq5DhRFIe3v4pjjOoiIiozWC4oXp18rDDqik9IQjAluhUfHMsReqSiIiImkTScLNo0SIIglDrFRwc3Ogx27ZtQ3BwMCwtLdG9e3f8+uuvBqrW9NwT7IEFt5ZoWLL7Ig5fvSFxRURERLcnec9Nt27dkJmZqX0dOXKkwX2PHj2KKVOm4IknnkBsbCzGjh2LsWPH4sKFCwas2LTMGRSA8b3bQa0RMW/zaSTeKJa6JCIiokZJHm5kMhk8PT21L1dX1wb3/fjjjzFy5Ei88sor6NKlC9555x307t0bn376qQErNi2CIOC9cd3R29cRyvIqzP46BukFZVKXRURE1CDJw83Vq1fh7e2NgIAATJ06FampqQ3ue+zYMQwfPrzWthEjRuDYsWMNHlNRUQGlUlnrRc1jaWGOtY/1gbeDJRJvluCBVYdx5OpNqcsiIiKql6Thpl+/foiMjMRvv/2GNWvWICkpCYMGDUJRUf0Pj8vKyoKHh0etbR4eHsjKymrwM5YuXQoHBwfty8fHR6dtMBXudpb47j8RCGlnj/zSSkxffwKf/ZEAjYazqIiIqHWRNNyMGjUKEydORGhoKEaMGIFff/0VBQUF+P7773X2GQsWLEBhYaH2lZaWprNzmxofZ2v88NQATA7zgUYE3t8bjyc3xqCwrFLq0oiIiLQkvy31T46OjujcuTMSEhLqfd/T0xPZ2bXXPMrOzoanp2eD51QoFLC3t6/1opaztDDH8odDsWx8d8hlZjhwKQcPfXoEFzN4u4+IiFqHVhVuiouLce3aNXh5edX7fkREBA4ePFhr2/79+xEREWGI8ugfHgn3xfanBqCdoxVScksxfs1f2H7qutRlERERSRtuXn75ZRw6dAjJyck4evQoxo0bB3Nzc0yZMgUAMH36dCxYsEC7//z58/Hbb7/hgw8+wOXLl7Fo0SLExMTgmWeekaoJJq17ewfsfvYuDO7shvJKDV7adhZv7jyPiiq11KUREZEJkzTcXL9+HVOmTEFQUBAmTZoEFxcXHD9+HG5ubgCA1NRUZGZmavcfMGAAvv32W6xbtw49evTADz/8gJ07dyIkJESqJpg8Jxs51s/si/nDOkEQgE3HUzHp8+PI4HRxIiKSiCCa2KJBSqUSDg4OKCws5PgbHfvjcg6e/+4MCssq4Wwjx6pHeuGuTg0/t4iIiKipmvP93arG3FDbNjTYHbufvQvdvO2RV6LC9PUnsPbQNS66SUREBsVwQzrl42yN7XMHYFJYe2hEYNmey5i/9QzKVByHQ0REhsFwQzpnaWGO5RNC8c6YbpCZCfjpbAYmfn6UyzYQEZFBMNyQXgiCgMciOmDT7H5wtpHjQroSYz49guikPKlLIyIiI8dwQ3rVP8AFPz0zEF287HGzWIVHvziOzSdSpC6LiIiMGMMN6V17J2tsnxuB0aFeqNKIeGPHBbyx4zxUVRqpSyMiIiPEcEMGYS2X4dMpvfDKiCAIArD5RCqmfXkCN4srpC6NiIiMDMMNGYwgCJg3tCO+mhEGO4UM0cl5eOiTI7iQXih1aUREZEQYbsjg7gn2wI55AxHgaoOMwnI8vPYofjqbIXVZRERkJBhuSBId3W2xY95ADAmqXpfquS2xeO/XS6hScxwOERHdGYYbkoyDlQW+mtEXc4cEAgDW/ZmIR788gRxlucSVERFRW8ZwQ5IyNxPw6shgrJ7aG7YKGaKT8nD/qiM4npgrdWlERNRGMdxQq3B/dy/89MxABHnY4WZxBaZ+yXWpiIioZRhuqNUIcLPFjnkDML5XO6g1IpbtuYwnN55CYVml1KUREVEbwnBDrYq1XIYPJvXA/8aFQG5uhv0Xs/HQp0cQl8Hp4kRE1DQMN9TqCIKAqf388MPcCLRztEJKbinGrT6K706mSl0aERG1AQw31GqFtnfEL8/dhaFBblBVafDq9vP47w9nUV6plro0IiJqxRhuqFVztJbjqxl98fJ9nWEmAN/HXMe41UcRl1HItamIiKhegmhi01GUSiUcHB
xQWFgIe3t7qcuhZvgr4Sae2xKL3BIVAEAQAHc7BbwdreDtaIV2jlbwdrBEOydreDtaop2jFRysLCAIgsSVExHRnWrO9zfDDbUpWYXleHX7ORxPzEVFE3purOXm6Oxhh7ce7Irevk4GqJCIiPSB4aYRDDfGQRRF5JaokFFQhoyCMqQXlCM9v/rnjMIypOeXaXt4AEBmJuDlEUF4clAAzMzYk0NE1NYw3DSC4cZ0lFeqcT2/DB8fvIqfby3MeXdnN3w4qQdcbRUSV0dERM3RnO9vDigmo2VpYY6O7rZY9UhPLBvfHZYWZvjzyg2M+vgwjibclLo8IiLSE4YbMnqCIOCRcF/89Mxd6OxhixtFFZj61Ql8sC+eq5ATERkhhhsyGZ097LBr3l2YEu4DUQQ++T0Bj35xApmFZVKXRkREOsRwQybFSm6OpeNDsWpKr+pVyJPzcP/Hh3HwUrbUpRERkY4w3JBJeqiHN3Y/exe6t3NAfmklnvg6Bkt+vsgHAxIRGQGGGzJZHVxtsH3uADxxlz8AYP1fSZiw5iiu55dKXBkREd0JhhsyaXKZGRY+0BVfzQiDo7UFzqcX4tEvTiCnqFzq0oiIqIUYbogADOvigV+fGwRfZ2uk5pVixvqTUJZXSl0WERG1AMMN0S3ejlbY+EQ4XG0VuJSpxOyvY7gCORFRG8RwQ/QPfi42iHy8L+wUMkQn5eG5LbE6eRbOlewi/Hj6On6/nI3TqflIulmCwtJKaDQm9YBwIiKD4PILRPU4npiL6eujoarSYHKYD5ZN6N6i1cXVGhGr/0jAyoNXoa4nyJgJgKO1HE7WFnCylsPRWg5nGws42cjhaW8JLwdLeDpYwcvBEq62CphzXSwiMlHN+f6WGagmojalf4ALPpnSC3M3ncJ3MWlwtpXj1ZHBzTpHZmEZnt96BieS8gAAPXwcodGIyCtRoaBUhRKVGhoRyCtRIa9EBaCk0fOZmwnwsFPA08ESXg5W8HSwhKe9JbwdrTA4yA22Cv51JiICGG6IGjSimyeWju+OV7efx5qoa3CxkWP2oIAmHbsvLgv/3X4OBaWVsJGb452xIRjfu32tfSqq1CgorUR+qepW4Kn+Ob9EhZvFKmQry5FZWI6swnLkFJVDrRGRUViOjMJyAAX/qtUDnz8WpqOWExG1bQw3RI2Y3NcXuSUqrPgtHu/+cgnONvI6IeWfyivV+N8vl7DxeAoAoHs7B6ya0gv+rjZ19lXIzOFhbw4Pe8vb1lGl1uBmsQqZhWXIKrwVepTlyCgowy/nM7E3LhtxGYXo5u3Q8sYSERkJhhui25g7OBC5xSp8dSQJr/xwDo7WFrgn2KPOfleyi/Dst7GIzy4CADx5dwBevi8Ictmdj9uXmZtV34ZyqCcIfXsau89l4tPfE7BmWp87/iwioraOs6WIbkMQBLxxfxeM69UOao2IpzefRkxynvZ9URSx+UQKHvzkCOKzi+Bqq8DXs8Lx+v1ddBJsbufZezoBAPZcyMKVW8GKiMiUMdwQNYGZmYAVD4diaJAbyis1mBV5EvFZRSgoVWHuptN4Y8cFVFRpMLizG/bMH4TBnd0MVluQpx1GdvMEAHz6e4LBPpeIqLXiVHCiZihTqTHtqxM4lZIPdzsFZGYCMgrLYWEu4NWRwZg10B9mEkzXjssoxOhVRyAIwIEXByPQzdbgNRAR6VNzvr/Zc0PUDFZyc3w1IwydPWyRU1SBjMJy+LvaYMfTAzF7UIAkwQYAunk7YHgXd4gi8Nkf7L0hItPGcEPUTI7Wcnwzqx8GdXLFY/39sPvZuxDSTvpZSjVjb3adyUBKbuPPzCEiMmYMN0Qt4OlgiY1P9MM7Y0Ng00oentfDxxGDO7vdeiryNanLISKSDMMNkRF5blh1783209eRllcqcTVERNJguCEyIn38nHBXR1dUaUSsPcTeGyIyTQw3REbm2Xs6AgC2xVxHZmGZxNUQERkeww2RkekX4IJ+/s5QqTX4/FCi1OUQERkcww2REaoZe/NtdCpylOUG+UwTe2QWEbVirWOaBxHp1IBAF/Txc8KplHys+zMRbz7QVWfnVlVpkHizGJczi3ApS4nLmUXVT2suU2HGgA546V7drKdFRNRSfEIxkZGKis/BzA0nYWVhjsOvDoWrraJZx4uiiGxlxT8CjBKXs4qQkFOMKk3D/2x0b+eAjx/piQA+JZmIdKg539/suSEyUoM7u6FHewecvV6ILw8n4bVRwU0+9mxaAd7+KQ5n0grqfd9OIUOwlx2CPe0R5GmHLl52yCqswBs7z+N8eiEe+OQIFj3UDRP7tIcgSPPUZiIyXey5ITJiBy5mY/Y3MbCRm+PIq/fAyUbe6P55JSqs+O0yvotJgygC5mYCAlxtbgUYewR72iHI0w7tHK3qDS2ZhWV44bszOJ5YvWr66FAvvDeuOxysLPTSPiIyHc35/ma4ITJioihi9KojuJipxLP3dMRL9wXVu59aI+LbEyn4v31XUFhWCQAY36sdXhsVDHd7y2Z9ploj4vM/r+HDfVdQpRHRztEKKx/pib4dnJtdf36JClFXciCKwN2d3Zp9a42IjAfDTSMYbsjU/HYhE09tOg07hQxHXrunTi/KqZQ8LNwZh4uZSgBAFy97LBnTrUVh5J/OpBVg/tZYpOSWwkwAnrmnE567pyNk5o0PNr6eX4r9F7OxLy4b0cl5UN8a3yMIQG9fJwzv4oF7u7oj0M2Wt7yITAjDTSMYbsjUaDQiRn18GPHZRXhheGfMH149TfxGUQWW7bmM7aevAwDsLWV4eUQQHg33vW0Aaariiiq8tesCfjydDqD6CcorJ/eEj7O1dh9RFHEluxh747Kw72IWLqQra52ji5c9zAQgLqP29g4u1hjexQPDu3ogzM9JZzUTUevEcNMIhhsyRT+fzcCzW2LhYGWBQ68MwY+n0/HR/isoqqgCAEwO88ErI4P0dttn15l0vLnjAooqqmCnkOHdcSFo52iFfRezsTcuCym5f6+DZSYAYR2ccV9XD4zo5qkNQpmFZThwKQcHLmbj2LVcqNQa7TEOVhYYGuSG4V09MLizG+wsOcaHyNgw3DSC4YZMkVoj4r6PDuHajRLYW8qgLK8ONd3bOWDJmG7o5euk9xrS8koxf2ssTqcW1HlPLjPDoI6uGNHNE8O6uMPlNiGruKIKh6/cwP5L2fjjcg7ySyu179nIzfHZ1N4YEuSu6yYQkYQYbhrBcEOmakfsdbzw3VkAgKO1Bf47IhiT+/rA3Mxw41aq1Bqs+j0Bn/2RAGu5OYYFu+O+bp4Y3NkNNoqWPZmiSq3B6dQCHLj0dy+QXGaG9TP64q5OrjpuARFJheGmEQw3ZKqq1Bos3XMZ5mYC5g4OvO20cH0qr1TDTBB0/iRjVZUGT28+jQOXsmFpYYYNM8MREeii088gImkw3DSC4YbIuFVUqfHUxlP4I/4GrOXm+HpW+B3P/CIi6TXn+5vTC4jIqChk5lgzrQ8GdXJFqUqNxzecxOnUfKnLalBhWSUyC8ukLoPIqLDnhoiMUplKjVmRJ3EsMRd2Chk2z+mH0
PaOktRSUKpCcm4pkm+WIDm3BCm5pdr/zStRAQDeHN0FswcFSFIfUVvA21KNYLghMh2lqirMXH8S0cl5sLeU4ds5/RHSzkGvnymKInafy8T+i9lIyS1Bcm6p9qnPt7Pi4VBMCvPRa31EbRXDTSMYbohMS3FFFaZ/dQKnUwvgZG2BLU/2R7Cnfv7ul1eqsXDnBWw7db3Oex72Cvi52KCDizU6uNqgg4sN/Fys4edig1UHr2Ldn4kwE4DVU/tgZIinXuojassYbhrBcENkepTllXjsyxM4e70QLjZybH2yPzp52On0M1JzSzF38ynEZShhJgBP3OWPPn5O8LsVYqzlDU91F0URr24/h+9jrkNuboYNj/fFwI6cxk70T21yQPGyZcsgCAKef/75BveJjIyEIAi1XpaWzVvUj4hMj72lBb6Z1Q8h7eyRW6LCo1+ewLUbxTo7/++Xs/HAJ4cRl6GEi40cG5/ohzdGd8XIEC908bJvNNgAgCAIeG9cd4zs5gmVWoMnv4nB2bQCndVHZGpaRbg5efIkPv/8c4SGht52X3t7e2RmZmpfKSkpBqiQiNo6B2sLbJzVD8GedrhRVIFHvziO5Jsld3ROtUbEh/viMSsyBsryKvT0ccTu5+5qUa+LzNwMH0/piYEdXVCiUmPmhmhczS66o/qITJXk4aa4uBhTp07FF198ASen2z8CXhAEeHp6al8eHh4GqJKIjIGTjRybZ/dDZw9bZCurA87hqzdQ9Y91qpoqr0SFmRuiser3BADAY/398N1/+sPLwarF9Slk5lj3WBh6+Dgiv7QSj30VjbS80tsfSES1SB5u5s2bh9GjR2P48OFN2r+4uBh+fn7w8fHBmDFjEBcX1+j+FRUVUCqVtV5EZLpcbBXYPLs/At1skFFYjse+ikbf/x3Aa9vP4dCVG6hsQtA5m1aABz85gsNXb8LSwgwfTe6Bd8aGQCEzv+P6bBQyRM7si07utshSluOxr07gRlHFHZ+XyJRIGm62bt2K06dPY+nSpU3aPygoCOvXr8euXbuwadMmaDQaDBgwANev152ZUGPp0qVwcHDQvnx8OM2SyNS52Smw5cn+mBLuC2cbOfJLK7H1ZBpmrI9G2LsH8PK2s/j9cjYqqtS1jhNFEZtPpGDi2mNILyhDBxdr7Jw3EON6tddpfU63xu20d7JCcm4pZqyPbvJ0ciKScLZUWloawsLCsH//fu1YmyFDhqBnz55YuXJlk85RWVmJLl26YMqUKXjnnXfq3aeiogIVFX//V49SqYSPjw9nSxERgOo1t6KT8vDrhUz8diEbN4v//vfCTiHD8K4eGBXiiX7+Lliy+yK2n67+j6n7unrg/yb1gL2lhd5qS75ZgofXHsPN4gr07eCEb2b1g5X8znuHiNqiNjEVfOfOnRg3bhzMzf/+i6pWqyEIAszMzFBRUVHrvYZMnDgRMpkMW7ZsadLncio4ETVErRERk5yHPReysOdCJrKVfwcdQQBEETATgFdGBOOpwQEQBP2vqH4xQ4nJ646hqLwKQ4PcsG56GCzMJR9R0GSiKOJ8eiH8nG3gYK2/IKhLCTnF2BuXBXc7BR7s4Q1LCwbK1qBNhJuioqI6M50ef/xxBAcH49VXX0VISMhtz6FWq9GtWzfcf//9+PDDD5v0uQw3RNQUGo2I06n5+PV8ddDJLCyHi40cn0zphQEGfgZNTHIepn11AuWVGozp6Y2PJvWEmZn+g5UufLgvHqt+T4CnvSUiZ/XV2wMU71RxRRV+OZeB706m4XRqgXa7s40cU/v5Ylp/P3jY89EjUmoT4aY+/74tNX36dLRr1047JmfJkiXo378/OnbsiIKCArz//vvYuXMnTp06ha5duzbpMxhuiKi5NBoRl7KU8HawgpONXJIa/ojPwZyvY1ClEeFqq0BHdxsEutlWv9xtEehmA28Hq1YVej77IwHv743X/m6nkGHtY31azQMKRVHEyeR8fB+Thl/OZaKssnqMlbmZgEGdXHE1uxjpBdWLmlqYCxjd3Quz7vKXbI0yU9ec7+/GnywlsdTUVJiZ/d39mp+fjzlz5iArKwtOTk7o06cPjh492uRgQ0TUEmZmArp563dNqtsZGuSOjyb3xCs/nMXN4grcLK7A8cS8WvtYWpghwNUWAW422tDTztESDlZyOFpbwNHKAjID3dL66kiSNtg8N6wTjifmIjopDzM3RGPFw6E6H4TdHFmF5dh++jq2xaQhOffvqfYBbjaYFOaD8b3bwd3OElVqDfZfzMb6v5JwMjkfO89kYOeZDIT5OWHWXf64r6uHwf48qXlaVc+NIbDnhojaspKKKly7UVz9yinR/px8sxSqJkxjt1PI4GBtAUdrCzhZy+FgZXEr+MjR0d0WD/bwhvkd9v5sPpGCN3ZcAADMH9YJL9zbGRVVarz0/VnsPpcJAHhlRBCeHhJ4x+OWKtUaqKo0qFKLUKk1qFQ3/HO2shw7Y9Nx6MoNaG5989nIzfFAqDcm9W2P3r5ODdZz/nohNvyVhJ/PZaBSXX1wO0crzBjgh8lhvm1mPFFb1mZvSxkCww0RGaMqtQbX88vqBJ8bxRXIL1FBWV7VpPOE+Tnhw0k94eti3aI6tp+6jpd/OAtRBP4zOACvjQzWBgaNRsSy3y5j3Z+JAICp/Xyx+KFuze79EEURxxJz8fGBqziRlHf7A+oR3sEZE8Pa4/7uXrBRNP0mRo6yHJuOp2DTiVTklagAAFYW5pgS7otXRgRxNpseMdw0guGGiEyRWiNCWVaJgrJK5JeqUFhaiYIyFQpKK1FQWom8EhV2xKajuKIKNnJzvP1gN0wMa9+snpXd5zLw3JZYaERgRoQfFj3Urd7jI/9KwuLdFyGKwPAu7lg1pddt198C/g41Kw9cRXQDoUZubgaZuQALczNYaP+3epulzByDg9wwsU97BLjZNrld9SmvVOOnsxlYfyQJl7Oql8kI8rDDmmm97/jcNTSa6plmQZ52nLEFhptGMdwQEdUvLa8UL31/FtHJ1cHhvq4eWDq+O1xsFbc9dv/FbMzddApVGhGP9PXBe+O6Nzq4+bcLWZi/NRYVVRr0aO+Ar2b2hWsDnyOKIo5duxVqbtUmNzfDI+E+mH1XANzsFJCZC5CZCQaZnv/v2qLib+CVH87hZnEFbBUyrHg4FPd397qj8ybeKMaCH8/jRFIeOnvYYsPj4Wjn2PKlPYwBw00jGG6IiBqm1oj44nAiPtgXj0q1CFdbOZZPCMWwLg2v43foyg3M+ToGKrUGY3t644NJPZs0budUSj5mf30S+aWV8HW2xtezwuHvaqN9v6FQMyXcB08NCbyjdbx0LUdZjme2xGp7lGYN9Mdro4IhlzXvllulWoMvDidi5YGrUFX9PYbK3U6B9TP7IqSdtAPbb6dSrcFL35/FE3f5o4ePo07PzXDTCIYbIqLbu5ihxAvfnUH8rZXJp4T74s3RXeqMTzl2LRczN0SjokqDUSGe+GRKr2aNoUm8UYyZG04iNa8UTtYW+HJGX/T2dcTRa7lYeeAKTibnAwDkMjM8Gu6LpwYHwtOhdT5v
pkqtwfv74vH5oeoxRb19HfHZ1N5NDmEX0gvx6vZziMuoXgNxUCdXPHtPJ7y58zyuZBfDRm6OT6f2xtAgd7214U59+vtV/N++K3C1lePIq/fo9HYaw00jGG6IiJqmvFKN/9sbjy+PJAEAOrhY48PJPdHb1wlAdc/LY1+dQKlKjXuC3bF2Wp9m91QAwM3iCjwReRJnrxdCITNDFy97nEkrANA2Qs2/7YvLwkvbzqKovArONnJ8/EhPDOrk1uD+5ZVqrDxwFV8cToRaI8LBygILH+iKCb3bQRAEFJZVYu6mUzh6LRfmZgLeHRuCKeG+BmxR0yTkFOH+j49ApdZg5eSeGNurnU7Pz3DTCIYbIqLmOZpwEy9vO4uMwnKYCcAzQztiaLA7pn8VjaKKKtzV0RVfzgi7o/9KL1VV4dlvY3Hwcg6Av0PN3CGBbfLJwKm5pZi7+RTiMpQQBOD5YZ3x7D0d64xDOp6YiwU/nkfSzRIAwOhQLyx6sBvc7GqPP1JVafDaj+fw4+l0AMC8oYF4+b4gg48xaohaI2Li2qM4nVqAe4Ld8dWMMJ3XxnDTCIYbIqLmKyyrxNu7LmDnmYxa28M7OCNyVt8mzXa6nSq1BqujrqFUpcbjAzu0yVDzT+WVaiz+OQ5botMAAHd3dsPKyT3hbCOHsrwSS3+9jC3RqQAAD3sF3h3bHfd2bXhskyiK+OjAVaw6eBUAMLanN5Y/HAqFTPqZVJF/JWHRzxdhq5Bh3wt3w1sPg58ZbhrBcENE1HI/n83AmzsvoLCsEj18HLHpiXDY6XFldGPww6nreHPneZRXauDlYIk5gwLw+Z/XtAuzPtrPF6+NCm7yCvPfn0zD6zvOo0ojon+AMz6fFibpQwTT8koxYuWfKFWp8e7YEEzr76eXz2G4aQTDDRHRnclWluNQ/A2M6u7JYNNEl7OUmLvptPb2EwD4u9pg6fju6B/g0uzzHb56A3M3nUZxRRU6uttiw8y+8HFu2YMX74Qoipi+PhqHr95EuL8zts7pr7f1zZrz/c1FMYiIqFk87C0xqa8Pg00zBHva46dnBuKBUC/IZWaYOyQQe+YPalGwAYBBndyw7akIeNpbIiGnGONWH8X564U6rvr2tp9Ox+GrN6GQmWHZ+MafbWRI7LkhIiIyILVGvOP1u2pkFpbh8Q0ncTmrCFYW5njx3s7wcbaGq60crrYKuNjKYauQ6WXgcU5ROe798E8UllXitVHBeGpwoM4/45+MZlVwIiIiY6OrYAMAXg5W2PZUBJ7efBqHr97E/369VGcfhcxMG3RcbGpCjwJBnrZ4qEe7Ftfz9q44FJZVIqSdPWbf5X+nTdEphhsiIqI2zM7SAutn9sW6PxMRm1qA3JIK5BarcLO4AqUqNSqqNEgvKEN6QVmdY7fFXMfKyT3h3syZaXvOZ2LPhSzIzASsmNCj2Yuf6hvDDRERURtnYW6GeUM71tleqqrSBp3cYhVySypws1iFbGU5fjh1HUev5eL+VYfx4aSeuLtzww8a/KfC0kos3BUHAHhqcCC6ere+IR4MN0REREbKWi6DtbOs3plU0yM64JlvT+NyVhGmr4/G00MC8eK9nW/bC/PuLxdxs7gCgW42eHZY3UDVGrSufiQiIiIyiI7uttg5byCm9a9eymF11DVMXne83ttXNQ5fvYFtp65DEIAVreQBgvVhuCEiIjJRlhbmeHdsd6ye2ht2ChlOpeTj/o8PY//F7Dr7llRUYcGP5wEAMyI6oI+fs6HLbTKGGyIiIhN3f3cv/Dp/EHr4OKKwrBJzvonB4p/jUFGl1u7z/t54XM8vQztHK7wyIkjCam+P4YaIiIjg42yNbf+JwJxB1dO6N/yVjAlrjiL5ZglOpeTh62PJAICl47vDRtG6h+y27uqIiIjIYOQyM7wxuisiAl3w0vdncSFdiQc+OQIHKwuIIvBwn/ZNnlUlJfbcEBERUS33BHvg1/mDEN7BGcUVVUgvKIOrrQJvju4idWlNwnBDREREdXg5WOHbOf0wf1gndHCxxgeTesDRWi51WU3CtaWIiIio1eOq4ERERGSyGG6IiIjIqDDcEBERkVFhuCEiIiKjwnBDRERERoXhhoiIiIwKww0REREZFYYbIiIiMioMN0RERGRUGG6IiIjIqDDcEBERkVFhuCEiIiKjwnBDRERERoXhhoiIiIyKTOoCDE0URQDVS6cTERFR21DzvV3zPd4Ykws3RUVFAAAfHx+JKyEiIqLmKioqgoODQ6P7CGJTIpAR0Wg0yMjIgJ2dHQRB0Om5lUolfHx8kJaWBnt7e52euzUw9vYBxt9Gtq/tM/Y2sn1tn77aKIoiioqK4O3tDTOzxkfVmFzPjZmZGdq3b6/Xz7C3tzfa/9MCxt8+wPjbyPa1fcbeRrav7dNHG2/XY1ODA4qJiIjIqDDcEBERkVFhuNEhhUKBt99+GwqFQupS9MLY2wcYfxvZvrbP2NvI9rV9raGNJjegmIiIiIwbe26IiIjIqDDcEBERkVFhuCEiIiKjwnBDRERERoXhRkc+++wzdOjQAZaWlujXrx+io6OlLklnFi1aBEEQar2Cg4OlLqvF/vzzTzz44IPw9vaGIAjYuXNnrfdFUcRbb70FLy8vWFlZYfjw4bh69ao0xbbQ7do4c+bMOtd05MiR0hTbAkuXLkXfvn1hZ2cHd3d3jB07FvHx8bX2KS8vx7x58+Di4gJbW1tMmDAB2dnZElXcPE1p35AhQ+pcw6eeekqiiptnzZo1CA0N1T7kLSIiAnv27NG+35avXY3btbEtX7/6LFu2DIIg4Pnnn9duk/I6MtzowHfffYcXX3wRb7/9Nk6fPo0ePXpgxIgRyMnJkbo0nenWrRsyMzO1ryNHjkhdUouVlJSgR48e+Oyzz+p9f8WKFVi1ahXWrl2LEydOwMbGBiNGjEB5ebmBK22527URAEaOHFnrmm7ZssWAFd6ZQ4cOYd68eTh+/Dj279+PyspK3HfffSgpKdHu88ILL+Dnn3/Gtm3bcOjQIWRkZGD8+PESVt10TWkfAMyZM6fWNVyxYoVEFTdP+/btsWzZMpw6dQoxMTG45557MGbMGMTFxQFo29euxu3aCLTd6/dvJ0+exOeff47Q0NBa2yW9jiLdsfDwcHHevHna39Vqtejt7S0uXbpUwqp05+233xZ79OghdRl6AUDcsWOH9neNRiN6enqK77//vnZbQUGBqFAoxC1btkhQ4Z37dxtFURRnzJghjhkzRpJ69CEnJ0cEIB46dEgUxeprZmFhIW7btk27z6VLl0QA4rFjx6Qqs8X+3T5RFMXBgweL8+fPl64oHXNychK//PJLo7t2/1TTRlE0nutXVFQkdurUSdy/f3+tNkl9Hdlzc4dUKhVOnTqF4cOHa7eZmZlh+PDhOHbsmISV6dbVq1fh7e2NgIAATJ06FampqVKXpBdJSUnIysqqdT0dHBzQr18/o7qeABAVFQV3d3cEBQVh7ty5yM3NlbqkFissLAQAODs7AwBOnTqFysrKWtcxODgYvr6+bfI6/rt9NTZv3gxXV1eEhIRgwYIFKC0
tlaK8O6JWq7F161aUlJQgIiLC6K4dULeNNYzh+s2bNw+jR4+udb0A6f8OmtzCmbp28+ZNqNVqeHh41Nru4eGBy5cvS1SVbvXr1w+RkZEICgpCZmYmFi9ejEGDBuHChQuws7OTujydysrKAoB6r2fNe8Zg5MiRGD9+PPz9/XHt2jW8/vrrGDVqFI4dOwZzc3Opy2sWjUaD559/HgMHDkRISAiA6usol8vh6OhYa9+2eB3rax8APProo/Dz84O3tzfOnTuHV199FfHx8fjxxx8lrLbpzp8/j4iICJSXl8PW1hY7duxA165dcebMGaO5dg21EWj71w8Atm7ditOnT+PkyZN13pP67yDDDd3WqFGjtD+HhoaiX79+8PPzw/fff48nnnhCwsqopR555BHtz927d0doaCgCAwMRFRWFYcOGSVhZ882bNw8XLlxo0+PAGtNQ+5588kntz927d4eXlxeGDRuGa9euITAw0NBlNltQUBDOnDmDwsJC/PDDD5gxYwYOHTokdVk61VAbu3bt2uavX1paGubPn4/9+/fD0tJS6nLq4G2pO+Tq6gpzc/M6I8Czs7Ph6ekpUVX65ejoiM6dOyMhIUHqUnSu5pqZ0vUEgICAALi6ura5a/rMM89g9+7d+OOPP9C+fXvtdk9PT6hUKhQUFNTav61dx4baV59+/foBQJu5hnK5HB07dkSfPn2wdOlS9OjRAx9//LHRXDug4TbWp61dv1OnTiEnJwe9e/eGTCaDTCbDoUOHsGrVKshkMnh4eEh6HRlu7pBcLkefPn1w8OBB7TaNRoODBw/WurdqTIqLi3Ht2jV4eXlJXYrO+fv7w9PTs9b1VCqVOHHihNFeTwC4fv06cnNz28w1FUURzzzzDHbs2IHff/8d/v7+td7v06cPLCwsal3H+Ph4pKamtonreLv21efMmTMA0Gau4b9pNBpUVFS0+WvXmJo21qetXb9hw4bh/PnzOHPmjPYVFhaGqVOnan+W9DrqfciyCdi6dauoUCjEyMhI8eLFi+KTTz4pOjo6illZWVKXphMvvfSSGBUVJSYlJYl//fWXOHz4cNHV1VXMycmRurQWKSoqEmNjY8XY2FgRgPjhhx+KsbGxYkpKiiiKorhs2TLR0dFR3LVrl3ju3DlxzJgxor+/v1hWViZx5U3XWBuLiorEl19+WTx27JiYlJQkHjhwQOzdu7fYqVMnsby8XOrSm2Tu3Lmig4ODGBUVJWZmZmpfpaWl2n2eeuop0dfXV/z999/FmJgYMSIiQoyIiJCw6qa7XfsSEhLEJUuWiDExMWJSUpK4a9cuMSAgQLz77rslrrxpXnvtNfHQoUNiUlKSeO7cOfG1114TBUEQ9+3bJ4pi2752NRprY1u/fg359wwwKa8jw42OfPLJJ6Kvr68ol8vF8PBw8fjx41KXpDOTJ08Wvby8RLlcLrZr106cPHmymJCQIHVZLfbHH3+IAOq8ZsyYIYpi9XTwhQsXih4eHqJCoRCHDRsmxsfHS1t0MzXWxtLSUvG+++4T3dzcRAsLC9HPz0+cM2dOmwrj9bUNgLhhwwbtPmVlZeLTTz8tOjk5idbW1uK4cePEzMxM6Ypuhtu1LzU1Vbz77rtFZ2dnUaFQiB07dhRfeeUVsbCwUNrCm2jWrFmin5+fKJfLRTc3N3HYsGHaYCOKbfva1WisjW39+jXk3+FGyusoiKIo6r9/iIiIiMgwOOaGiIiIjArDDRERERkVhhsiIiIyKgw3REREZFQYboiIiMioMNwQERGRUWG4ISIiIqPCcENERERGheGGiFqlGzduYO7cufD19YVCoYCnpydGjBiBv/76CwAgCAJ27twpbZFE1CrJpC6AiKg+EyZMgEqlwtdff42AgABkZ2fj4MGDyM3Nlbo0ImrluPwCEbU6BQUFcHJyQlRUFAYPHlzn/Q4dOiAlJUX7u5+fH5KTkwEAu3btwuLFi3Hx4kV4e3tjxowZeOONNyCTVf+3nCAIWL16NX766SdERUXBy8sLK1aswMMPP2yQthGR/vG2FBG1Ora2trC1tcXOnTtRUVFR5/2TJ08CADZs2IDMzEzt74cPH8b06dMxf/58XLx4EZ9//jkiIyPxv//9r9bxCxcuxIQJE3D27FlMnToVjzzyCC5duqT/hhGRQbDnhohape3bt2POnDkoKytD7969MXjwYDzyyCMIDQ0FUN0Ds2PHDowdO1Z7zPDhwzFs2DAsWLBAu23Tpk3473//i4yMDO1xTz31FNasWaPdp3///ujduzdWr15tmMYRkV6x54aIWqUJEyYgIyMDP/30E0aOHImoqCj07t0bkZGRDR5z9uxZLFmyRNvzY2trizlz5iAzMxOlpaXa/SIiImodFxERwZ4bIiPCAcVE1GpZWlri3nvvxb333ouFCxdi9uzZePvttzFz5sx69y8uLsbixYsxfvz4es9FRKaBPTdE1GZ07doVJSUlAAALCwuo1epa7/fu3Rvx8fHo2LFjnZeZ2d//3B0/frzWccePH0eXLl303wAiMgj23BBRq5Obm4uJEydi1qxZCA0NhZ2dHWJiYrBixQqMGTMGQPWMqYMHD2LgwIFQKBRwcnLCW2+9hQceeAC+vr54+OGHYWZmhrNnz+LChQt49913tefftm0bwsLCcNddd2Hz5s2Ijo7GV199JVVziUjHOKCYiFqdiooKLFq0CPv27cO1a9dQWVkJHx8fTJw4Ea+//jqsrKzw888/48UXX0RycjLatWunnQq+d+9eLFmyBLGxsbCwsEBwcDBmz56NOXPmAKgeUPzZZ59h586d+PPPP+Hl5YXly5dj0qRJEraYiHSJ4YaITEp9s6yIyLhwzA0REREZFYYbIiIiMiocUExEJoV34omMH3tuiIiIyKgw3BAREZFRYbghIiIio8JwQ0REREaF4YaIiIiMCsMNERERGRWGGyIiIjIqDDdERERkVBhuiIiIyKj8Pw2PG/uKu49DAAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.plot(metrics_history['train_loss'])\n", + "plt.title('Training Loss')\n", + "plt.xlabel('Step')\n", + "plt.ylabel('Loss')\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Sy74Hsey8bBR", + "outputId": "78d9ce88-54aa-4501-fca9-42b4df44b466" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total model parameters: 34,420,433\n", + "Frozen parameters: 27,430,481 (79.69%)\n", + "Trainable LoRA parameters: 6,989,952 (20.31%)\n", + " - LoRA A matrices: 294,912\n", + " - LoRA B matrices: 6,695,040\n" + ] + } + ], + "source": [ + "# Analysis of LoRA Parameter Efficiency using proper module iteration\n", + "import jax.numpy as jnp\n", + "import matplotlib.pyplot as plt\n", + "\n", + "# Initialize counters\n", + "total_params = 0\n", + "lora_a_params = 0\n", + "lora_b_params = 0\n", + "other_params = 0\n", + "layer_counts = {}\n", + "\n", + "# Iterate through all modules and count parameters\n", + "for path, module in lora_model.iter_modules():\n", + " module_name = '.'.join(str(p) for p in path) if path else \"root\"\n", + "\n", + " module_params = 0\n", + " module_lora_params = 0\n", + "\n", + " for name, attr in vars(module).items():\n", + " if name.startswith('_') or not hasattr(attr, 'value'):\n", + " continue\n", + "\n", + " # Get the parameter array\n", + " param_array = attr.value\n", + " if not isinstance(param_array, jnp.ndarray):\n", + " continue\n", + "\n", + " param_count = param_array.size\n", + " module_params += param_count\n", + " total_params += param_count\n", + "\n", + " # Check if this is a LoRA parameter\n", + " if name == 'lora_a':\n", + " lora_a_params += param_count\n", + " module_lora_params += param_count\n", + " elif name == 'lora_b':\n", + " lora_b_params += param_count\n", + " module_lora_params += param_count\n", + " else:\n", + " other_params += param_count\n", + "\n", + "# Calculate total LoRA parameters and ratios\n", + "lora_params = lora_a_params + lora_b_params\n", + "trainable_ratio = lora_params / total_params\n", + "frozen_ratio = other_params / total_params\n", + "\n", + "print(f\"Total model parameters: {total_params:,}\")\n", + "print(f\"Frozen parameters: {other_params:,} ({frozen_ratio:.2%})\")\n", + "print(f\"Trainable LoRA parameters: {lora_params:,} ({trainable_ratio:.2%})\")\n", + "print(f\" - LoRA A matrices: {lora_a_params:,}\")\n", + "print(f\" - LoRA B matrices: {lora_b_params:,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bOl4CY61Llcp" + }, + "source": [ + "This example has demonstrated the implementation of LoRA in JAX, showing how to replace standard linear layers with LoRA-enabled versions and train only these adapter parameters while keeping the base model frozen.\n", + "\n", + "As we've seen from our experiment results, this approach of applying LoRA to a model trained from scratch produced limited generation quality. The text outputs were repetitive and lacked coherence.\n", + "\n", + "That is because LoRA is designed to make incremental adaptations to already capable models, not to carry the full burden of learning language structure from scratch. 
The small parameter space of the LoRA matrices (even with rank=128) simply cannot capture the full complexity of language when starting from random initialization.\n", + "\n", + "\n", + "###Next Steps\n", + "In a subsequent chapter, we'll explore how to integrate LoRA into existing pre-trained language models rather than training from scratch. If you want to stop here and save your progress you can run the follwing cells:" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "soPqiR1JNmjf" + }, + "source": [ + "### Saving" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.status.busy": "2025-03-17T06:57:45.284249Z", + "iopub.status.idle": "2025-03-17T06:57:45.285697Z", + "shell.execute_reply": "2025-03-17T06:57:45.284666Z", + "shell.execute_reply.started": "2025-03-17T06:57:45.284618Z" + }, + "id": "EkoFGCgSZ1yz" + }, + "outputs": [], + "source": [ + "import orbax.checkpoint as orbax\n", + "\n", + "state = nnx.state(lora_model)\n", + "\n", + "checkpointer = orbax.PyTreeCheckpointer()\n", + "checkpointer.save('/content/save', state)\n", + "\n", + "# Make sure the files are there\n", + "!ls /content/save/" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jCApVd7671c1" + }, + "source": [ + "### Disconnect the Colab runtime" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.status.busy": "2025-03-17T06:57:45.300296Z", + "iopub.status.idle": "2025-03-17T06:57:45.301123Z", + "shell.execute_reply": "2025-03-17T06:57:45.300683Z", + "shell.execute_reply.started": "2025-03-17T06:57:45.300636Z" + }, + "id": "NsqYdbrDVKSq" + }, + "outputs": [], + "source": [ + "from google.colab import runtime\n", + "runtime.unassign()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BRE3MBAfRS7i" + }, + "source": [ + "# 2. Fine-tuning a pre-trained LLM with LoRA" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kaggle": { + "accelerator": "nvidiaTeslaT4", + "dataSources": [], + "dockerImageVersionId": 30919, + "isGpuEnabled": true, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/source/JAX_using_LoRA.md b/docs/source/JAX_using_LoRA.md new file mode 100644 index 0000000..998aaf2 --- /dev/null +++ b/docs/source/JAX_using_LoRA.md @@ -0,0 +1,574 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.7 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "QEhawzCcCcFR"} + +#Using LoRA in Jax + +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_using_LoRA.ipynb) + + +This tutorial demonstrates how to implement LoRA for efficient fine-tuning of language models in JAX. 
+It builds upon the [JAX for LLM pretraining](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html) tutorial by showing how to replace standard linear
+layers with LoRA-enabled linear layers to significantly reduce the number of trainable parameters.
+
+LoRA (Low-Rank Adaptation) is a parameter-efficient fine-tuning technique that:
+- Keeps pre-trained model weights frozen
+- Adds small trainable low-rank decomposition matrices to certain layers
+- Drastically reduces the number of trainable parameters (often by 90%+)
+
+In the first chapter we will build a LoRA-enabled model from scratch, while the next chapter, "2. Fine-tuning a pre-trained LLM with LoRA", will demonstrate the more common and practical workflow of applying LoRA to existing pre-trained models.
+
+Both chapters show how to implement these techniques using JAX and Flax's NNX library.
+
++++ {"id": "NIOXoY1xgiww"}
+
+# 1. Creating a LoRA-enabled LLM in JAX from scratch
+
+In this chapter, we'll take an unconventional approach by implementing a language model with LoRA from scratch. This is different from standard practice, where LoRA is typically applied to already pre-trained models as a fine-tuning technique.
+
+Why are we doing it this way? While this is not the optimal approach for training a model that achieves good performance (as we'll see in our results), building from scratch makes the integration of LoRA components within the model architecture clearer.
+
+If you're interested in the more practical approach of applying LoRA to an existing pre-trained model, you can skip to the next chapter where we demonstrate that workflow.
+
++++ {"id": "hTmz5Cbco7n_"}
+
+## Setup
+Install required packages
+
+```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: 6zMsOIc7ouCO
+outputId: 40d84dff-b5c6-45ed-df08-fb22d3eeb01a
+---
+!pip install -q jax-ai-stack
+!pip install -Uq tiktoken grain matplotlib
+```
+
++++ {"id": "Rcji_799n4eA"}
+
+Confirm we have TPUs set up.
+
+```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: LS9sQEY3n0mB
+outputId: b516c248-777f-4a59-a550-26e12bc2e2fc
+---
+import jax
+jax.devices()
+```
+
++++ {"id": "OHzJ_bokoovZ"}
+
+Get the [TinyStories dataset from Hugging Face](https://huggingface.co/datasets/roneneldan/TinyStories). We only use the training split.
+
+```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: wUjQsgQEmI1N
+outputId: 90fc683c-696f-4f25-a75c-6a2a5b032cef
+---
+!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories-train.txt?download=true -O TinyStories-train.txt
+```
+
++++ {"id": "sKE2uUafLobI"}
+
+Import necessary libraries
+
+```{code-cell} ipython3
+:id: MKYFNOhdLq98
+
+import jax
+import jax.numpy as jnp
+import flax.nnx as nnx
+from flax.nnx.nn.lora import LoRALinear # Import LoRALinear
+import optax
+from dataclasses import dataclass
+import grain.python as pygrain
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
+import pandas as pd
+import tiktoken
+import time
+```
+
++++ {"id": "rPyt7MV6prz1"}
+
+## Building the Model with LoRA
+
+We'll use the same tokenizer and parallelism strategy as in the [pre-training tutorial](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html).
+The mesh defines how our computation will be distributed across TPU cores.
+ +```{code-cell} ipython3 +:id: xuMlCK3Q8WJD + +tokenizer = tiktoken.get_encoding("gpt2") +mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model')) +``` + ++++ {"id": "0XHQ0BQ9-KIj"} + +The key difference from the original pre-training model is that we replace standard +`nnx.Linear` layers with `LoRALinear` layers from [Flax](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/lora.html). + + +This way, only the small rank decomposition matrices need to be trained. + +```{code-cell} ipython3 +:id: z0p-IHurrB9i + +def causal_attention_mask(seq_len): + return jnp.tril(jnp.ones((seq_len, seq_len))) + +class TransformerBlock(nnx.Module): + # update the __init__ function arguments to include lora_rank + def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1, lora_rank=8): + self.mha = nnx.MultiHeadAttention(num_heads=num_heads, + in_features=embed_dim, + kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))), + bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))), + rngs=rngs) + self.dropout1 = nnx.Dropout(rate=rate) + self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6, + num_features=embed_dim, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))), + bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))), + rngs=rngs) + # here we replace the regular linear layer with the LoRALinea layer + self.linear1 = LoRALinear( + in_features=embed_dim, + out_features=ff_dim, + lora_rank=lora_rank, # set the rank for the low-rank matrices + kernel_init=nnx.with_partitioning(nnx.initializers.normal(0.02), P('model', None)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, None), + rngs=rngs + ) + # here we replace the regular linear layer with the LoRALinea layer + self.linear2 = LoRALinear( + in_features=ff_dim, + out_features=embed_dim, + lora_rank=lora_rank, + kernel_init=nnx.with_partitioning(nnx.initializers.normal(0.02), P('model', None)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, None), + rngs=rngs + ) + self.dropout2 = nnx.Dropout(rate=rate) + self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6, + num_features=embed_dim, + scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))), + bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P(None, 'model'))), + rngs=rngs) + + + def __call__(self, inputs, training: bool = False): + input_shape = inputs.shape + _, seq_len, _ = input_shape + mask = causal_attention_mask(seq_len) + attention_output = self.mha( + inputs_q=inputs, + mask=mask, + decode=False + ) + attention_output = self.dropout1(attention_output, deterministic=not training) + out1 = self.layer_norm1(inputs + attention_output) + # feed-forward network with LoRA layer + ffn_output = self.linear1(out1) + ffn_output = nnx.relu(ffn_output) + ffn_output = self.linear2(ffn_output) + ffn_output = self.dropout2(ffn_output, deterministic=not training) + + return self.layer_norm2(out1 + ffn_output) + + +class TokenAndPositionEmbedding(nnx.Module): + + def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, *, rngs: nnx.Rngs): + self.token_emb = nnx.Embed(num_embeddings=vocab_size, features=embed_dim, rngs=rngs) + self.pos_emb = nnx.Embed(num_embeddings=maxlen, features=embed_dim, rngs=rngs) + + def __call__(self, x): + positions = jnp.arange(0, 
x.shape[1])[None, :] + position_embedding = self.pos_emb(positions) + token_embedding = self.token_emb(x) + return token_embedding + position_embedding + + +class MiniGPT(nnx.Module): + # update the __init__ function arguments to include lora_rank + def __init__(self, maxlen: int, vocab_size: int, embed_dim: int, num_heads: int, feed_forward_dim: int, num_transformer_blocks: int, rngs: nnx.Rngs, lora_rank=8): + self.embedding_layer = TokenAndPositionEmbedding( + maxlen, vocab_size, embed_dim, rngs=rngs + ) + # create transformer blocks with LoRA + self.transformer_blocks = [TransformerBlock( + embed_dim, num_heads, feed_forward_dim, rngs=rngs, lora_rank=lora_rank + ) for _ in range(num_transformer_blocks)] + + # modify the output layer to use LoRALinear instead of regular linear layer + self.output_layer = LoRALinear( + in_features=embed_dim, + out_features=vocab_size, + lora_rank=lora_rank, + kernel_init=nnx.with_partitioning(nnx.initializers.normal(0.02), P('model', None)), + bias_init=nnx.with_partitioning(nnx.initializers.zeros, None), + rngs=rngs + ) + + + def __call__(self, inputs, training: bool = False): + x = self.embedding_layer(inputs) + for transformer_block in self.transformer_blocks: + x = transformer_block(x, training=training) + outputs = self.output_layer(x) + return outputs + + def generate_text(self, max_tokens: int, start_tokens: [int], top_k=10): + def sample_from(logits): + logits, indices = jax.lax.top_k(logits, k=top_k) + logits = nnx.softmax(logits) + return jax.random.choice(jax.random.PRNGKey(0), indices, p=logits) + + def generate_step(start_tokens): + pad_len = maxlen - len(start_tokens) + sample_index = len(start_tokens) - 1 + if pad_len < 0: + x = jnp.array(start_tokens[:maxlen]) + sample_index = maxlen - 1 + elif pad_len > 0: + x = jnp.array(start_tokens + [0] * pad_len) + else: + x = jnp.array(start_tokens) + + x = x[None, :] + logits = self(x) + next_token = sample_from(logits[0][sample_index]) + return next_token + + generated = [] + for _ in range(max_tokens): + next_token = generate_step(start_tokens + generated) + if next_token == tokenizer.encode('<|endoftext|>', allowed_special={'<|endoftext|>'})[0]: + break + generated.append(int(next_token)) + return tokenizer.decode(start_tokens + generated) + +# modify the function arguments to include lora_rank +def create_model(rngs, lora_rank=8): + return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks=4, rngs=rngs, + lora_rank=lora_rank) +``` + ++++ {"id": "igX_eoGNMTGR"} + +## Set Hyperparameters + +We'll use the same hyperparameters as in the [pre-training tutorial](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html) for consistency. + +```{code-cell} ipython3 +:id: GRhiDsCrMZRp + +vocab_size = tokenizer.n_vocab +num_transformer_blocks = 8 +maxlen = 256 +embed_dim = 256 +num_heads = 8 +feed_forward_dim = 256 +batch_size = 256 # You can adjust batch size based on your TP +num_epochs = 1 +lora_rank = 128 # A higher rank will capture more complex patterns in the LLM, and will also increase the number of trainable parameters +``` + ++++ {"id": "mI1ci-HyMspJ"} + +## Prepare data + +Data loading and preprocessing remains the same as in the [pre-training tutorial](https://docs.jaxstack.ai/en/latest/JAX_for_LLM_pretraining.html). +We create a TextDataset class to handle tokenization and padding. 
+
+```{code-cell} ipython3
+:id: rGUFsn1GMuzh
+
+@dataclass
+class TextDataset:
+    data: list
+    maxlen: int
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx: int):
+        encoding = tokenizer.encode(self.data[idx], allowed_special={'<|endoftext|>'})[:self.maxlen]  # Tokenize and truncate
+        return encoding + [0] * (self.maxlen - len(encoding))  # Pad to maxlen
+
+def load_and_preprocess_data(file_path, batch_size, maxlen):
+
+    with open(file_path, 'r') as f:
+        text = f.read()
+
+    stories = text.split('<|endoftext|>')
+    stories = [story+'<|endoftext|>' for story in stories if story.strip()]
+    df = pd.DataFrame({'text': stories})
+    data = df['text'].dropna().tolist()
+    dataset = TextDataset(data, maxlen)
+
+    sampler = pygrain.IndexSampler(
+        len(dataset),
+        shuffle=False,
+        seed=42,
+        shard_options=pygrain.NoSharding(),
+        num_epochs=num_epochs,
+    )
+
+    dl = pygrain.DataLoader(
+        data_source=dataset,
+        sampler=sampler,
+        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
+    )
+
+    return dl
+
+text_dl = load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)
+```
+
++++ {"id": "BKVSD8KSM1um"}
+
+## Train the model with LoRA
+
++++ {"id": "WbSt_MuyaG48"}
+
+LoRA's efficiency lies in how we train only the small adapter matrices while keeping the rest of the model frozen. Let's look at how we implement this in JAX:
+
+```{code-cell} ipython3
+:id: h9hXS0NngSAw
+
+# Create the model with LoRA
+lora_model = create_model(rngs=nnx.Rngs(0), lora_rank=lora_rank)
+# Filter for LoRA parameters only (match lora_a or lora_b in the parameter path)
+lora_params = nnx.All(nnx.Param, nnx.Any(nnx.PathContains('lora_a'), nnx.PathContains('lora_b')))
+# Create optimizer to only update LoRA parameters
+optimizer = nnx.Optimizer(lora_model, optax.adam(1e-3), wrt=lora_params)
+```
+
++++ {"id": "e5hooDhBadPb"}
+
+Using `nnx.All` and `nnx.Any`, we create a mask that identifies only our LoRA parameters by looking for `lora_a` or `lora_b` in the parameter paths. Then we:
+
+- Configure the optimizer to only update these selected parameters using the `wrt` argument
+- Create a special `diff_state` that directs gradient computation to flow only to these parameters
+
+Now we can use this `diff_state` when computing gradients in our training step:
+
+```{code-cell} ipython3
+:id: reUqnpEtiy0e
+
+def loss_fn(model, batch):
+    logits = model(batch[0])
+    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
+    return loss, logits
+
+@nnx.jit
+def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
+    # Create differentiable state that only includes LoRA parameters
+    diff_state = nnx.DiffState(0, lora_params)
+    grad_fn = nnx.value_and_grad(loss_fn, argnums=diff_state, has_aux=True)
+    (loss, logits), grads = grad_fn(model, batch)
+    metrics.update(loss=loss, logits=logits, labels=batch[1])
+    optimizer.update(grads)
+```
+
+```{code-cell} ipython3
+---
+colab:
+  base_uri: https://localhost:8080/
+id: Ysl6CsfENeJN
+outputId: 3236e7a1-4d6b-4378-b580-65a509236b23
+---
+metrics = nnx.MultiMetric(
+    loss=nnx.metrics.Average('loss'),
+)
+rng = jax.random.PRNGKey(0)
+
+start_prompt = "Once upon a time"
+start_tokens = tokenizer.encode(start_prompt)[:maxlen]
+generated_text = lora_model.generate_text(
+    maxlen, start_tokens
+)
+print(f"Initial generated text:\n{generated_text}\n")
+
+
+metrics_history = {
+    'train_loss': [],
+}
+
+prep_target_batch = jax.vmap(lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0]))))
+
+step = 0
+for epoch in range(num_epochs):
+    start_time = time.time()
+    for batch in text_dl:
+        if len(batch) % len(jax.devices()) != 0:
+            continue
+        input_batch = jnp.array(jnp.array(batch).T)
+        target_batch = prep_target_batch(input_batch)
+        train_step(lora_model, optimizer, metrics, jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None))))
+
+        if (step + 1) % 200 == 0:
+            for metric, value in metrics.compute().items():
+                metrics_history[f'train_{metric}'].append(value)
+            metrics.reset()
+
+            elapsed_time = time.time() - start_time
+            print(f"Step {step + 1}, Loss: {metrics_history['train_loss'][-1]}, Elapsed Time: {elapsed_time:.2f} seconds")
+            start_time = time.time()
+
+            generated_text = lora_model.generate_text(
+                maxlen, start_tokens
+            )
+            print(f"Generated text:\n{generated_text}\n")
+        step += 1
+
+generated_text = lora_model.generate_text(
+    maxlen, start_tokens
+)
+print(f"Final generated text:\n{generated_text}")
+```
+
++++ {"id": "thaLs6TD0lt5"}
+
+Visualize the training loss.
+ +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ + height: 472 +id: B6Eg1Cz2y_iP +outputId: 227b6ad5-21de-45d8-ab0b-7dc9a931834f +--- +import matplotlib.pyplot as plt +plt.plot(metrics_history['train_loss']) +plt.title('Training Loss') +plt.xlabel('Step') +plt.ylabel('Loss') +plt.show() +``` + +```{code-cell} ipython3 +--- +colab: + base_uri: https://localhost:8080/ +id: Sy74Hsey8bBR +outputId: 78d9ce88-54aa-4501-fca9-42b4df44b466 +--- +# Analysis of LoRA Parameter Efficiency using proper module iteration +import jax.numpy as jnp +import matplotlib.pyplot as plt + +# Initialize counters +total_params = 0 +lora_a_params = 0 +lora_b_params = 0 +other_params = 0 +layer_counts = {} + +# Iterate through all modules and count parameters +for path, module in lora_model.iter_modules(): + module_name = '.'.join(str(p) for p in path) if path else "root" + + module_params = 0 + module_lora_params = 0 + + for name, attr in vars(module).items(): + if name.startswith('_') or not hasattr(attr, 'value'): + continue + + # Get the parameter array + param_array = attr.value + if not isinstance(param_array, jnp.ndarray): + continue + + param_count = param_array.size + module_params += param_count + total_params += param_count + + # Check if this is a LoRA parameter + if name == 'lora_a': + lora_a_params += param_count + module_lora_params += param_count + elif name == 'lora_b': + lora_b_params += param_count + module_lora_params += param_count + else: + other_params += param_count + +# Calculate total LoRA parameters and ratios +lora_params = lora_a_params + lora_b_params +trainable_ratio = lora_params / total_params +frozen_ratio = other_params / total_params + +print(f"Total model parameters: {total_params:,}") +print(f"Frozen parameters: {other_params:,} ({frozen_ratio:.2%})") +print(f"Trainable LoRA parameters: {lora_params:,} ({trainable_ratio:.2%})") +print(f" - LoRA A matrices: {lora_a_params:,}") +print(f" - LoRA B matrices: {lora_b_params:,}") +``` + ++++ {"id": "bOl4CY61Llcp"} + +This example has demonstrated the implementation of LoRA in JAX, showing how to replace standard linear layers with LoRA-enabled versions and train only these adapter parameters while keeping the base model frozen. + +As we've seen from our experiment results, this approach of applying LoRA to a model trained from scratch produced limited generation quality. The text outputs were repetitive and lacked coherence. + +That is because LoRA is designed to make incremental adaptations to already capable models, not to carry the full burden of learning language structure from scratch. The small parameter space of the LoRA matrices (even with rank=128) simply cannot capture the full complexity of language when starting from random initialization. + + +###Next Steps +In a subsequent chapter, we'll explore how to integrate LoRA into existing pre-trained language models rather than training from scratch. If you want to stop here and save your progress you can run the follwing cells: + ++++ {"id": "soPqiR1JNmjf"} + +### Saving + +```{code-cell} ipython3 +:id: EkoFGCgSZ1yz + +import orbax.checkpoint as orbax + +state = nnx.state(lora_model) + +checkpointer = orbax.PyTreeCheckpointer() +checkpointer.save('/content/save', state) + +# Make sure the files are there +!ls /content/save/ +``` + ++++ {"id": "jCApVd7671c1"} + +### Disconnect the Colab runtime + +```{code-cell} ipython3 +:id: NsqYdbrDVKSq + +from google.colab import runtime +runtime.unassign() +``` + ++++ {"id": "BRE3MBAfRS7i"} + +# 2. 
Fine-tuning a pre-trained LLM with LoRA
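+
+As a preview of this chapter's workflow, here is a minimal, hypothetical sketch of the setup it works toward: loading existing weights into a LoRA-enabled model and then training only the adapter matrices. The checkpoint path, the reuse of chapter 1's `create_model`, and the orbax restore call are illustrative assumptions, not the exact model and loading code used in this chapter.
+
+```{code-cell} ipython3
+# Sketch only: assumes the chapter 1 architecture and the checkpoint saved above.
+import orbax.checkpoint as orbax
+
+# Rebuild the LoRA-enabled model and restore previously trained weights into it.
+model = create_model(rngs=nnx.Rngs(0), lora_rank=lora_rank)
+checkpointer = orbax.PyTreeCheckpointer()
+restored_state = checkpointer.restore('/content/save', item=nnx.state(model))  # assumed path
+nnx.update(model, restored_state)
+
+# As in chapter 1, only the LoRA adapter matrices are trainable; the restored
+# base weights stay frozen because the optimizer is restricted to this filter.
+lora_filter = nnx.All(nnx.Param, nnx.Any(nnx.PathContains('lora_a'), nnx.PathContains('lora_b')))
+optimizer = nnx.Optimizer(model, optax.adam(1e-4), wrt=lora_filter)
+```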