diff --git a/docs/tutorials/grpo.md b/docs/tutorials/grpo.md
index 3030b1714b..cbf615f25b 100644
--- a/docs/tutorials/grpo.md
+++ b/docs/tutorials/grpo.md
@@ -45,7 +45,7 @@ Next, run the following bash script to get all the necessary installations insid
 This will take few minutes. Follow along the installation logs and look out for any issues!
 ```
-bash ~/maxtext/MaxText/examples/install_tunix_vllm_requirement.sh
+bash ~/maxtext/src/MaxText/examples/install_tunix_vllm_requirement.sh
 ```
 1. It installs `pip install keyring keyrings.google-artifactregistry-auth` which enables pip to authenticate with Google Artifact Registry automatically.
diff --git a/src/MaxText/examples/README.md b/src/MaxText/examples/README.md
new file mode 100644
index 0000000000..d0e2dbfe12
--- /dev/null
+++ b/src/MaxText/examples/README.md
@@ -0,0 +1,158 @@
# MaxText Examples - Setting Up Jupyter Lab or Colab to Run Them on TPU

This guide provides comprehensive instructions for setting up Jupyter Lab on TPU and connecting it to Google Colab for running the MaxText examples.

## 📑 Table of Contents

- [Prerequisites](#prerequisites)
- [Method 1: Google Colab with TPU (Recommended)](#method-1-google-colab-with-tpu-recommended)
- [Method 2: Local Jupyter Lab with TPU](#method-2-local-jupyter-lab-with-tpu)
- [Method 3: Colab + Local Jupyter Lab Hybrid](#method-3-colab--local-jupyter-lab-hybrid)
- [Available Examples](#available-examples)
- [Common Pitfalls & Debugging](#common-pitfalls--debugging)
- [Support & Resources](#support--resources)
- [Contributing](#contributing)

## Prerequisites

Before starting, make sure you have:

- ✅ A Google Cloud Platform (GCP) account with billing enabled
- ✅ TPU quota available in your region (check under IAM & Admin → Quotas)
- ✅ Basic familiarity with Jupyter, Python, and Git
- ✅ gcloud CLI installed locally if you plan to use Method 2 or 3
- ✅ Firewall rules open for port 8888 (Jupyter) if accessing directly (see the example rule below)
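For the firewall prerequisite, a minimal sketch of a matching rule (the rule name, network, and source range here are illustrative; scope `--source-ranges` to your own IP rather than opening the port to the world):

```bash
# Allow inbound Jupyter traffic on port 8888 from your IP only (illustrative values)
gcloud compute firewall-rules create allow-jupyter-8888 \
  --network=default \
  --allow=tcp:8888 \
  --source-ranges=YOUR_IP/32
```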
## Method 1: Google Colab with TPU (Recommended)

This is the fastest way to run MaxText without managing infrastructure.

### Step 1: Open Google Colab

1. Go to [Google Colab](https://colab.research.google.com/)
2. Sign in → New Notebook

### Step 2: Enable TPU Runtime

1. **Runtime** → **Change runtime type**
2. Set **Hardware accelerator** → **TPU**
3. Select TPU version:
   - **v5e-8** → recommended for most MaxText examples, but it's a paid option
   - **v5e-1** → free-tier option (slower, but works for the Qwen3-0.6B demos)
4. Click **Save**

### Step 3: Upload & Prepare MaxText

Upload the notebooks or clone the GitHub repo.

> **Note:** In Colab, the repo root will usually be `/content/maxtext`

**Example:**
```python
!git clone https://github.com/AI-Hypercomputer/maxtext.git
%cd maxtext
```

### Step 4: Run Examples

1. Open `src/MaxText/examples/`
2. Try:
   - `sft_qwen3_demo.ipynb`
   - `sft_llama3_demo.ipynb`
   - `grpo_llama3_demo.ipynb`

> ⚡ **Tip:** If Colab disconnects, re-enable the TPU and re-run the setup cells. Save checkpoints to GCS or Drive.
>
> ⚡ **Tip:** If Colab asks to restart the session, do so and continue running the cells.

## Method 2: Local Jupyter Lab with TPU

This method gives you more control and is better for long training runs.

### Step 1: Set Up TPU VM

In Google Cloud Console:

1. **Compute Engine** → **TPUs** → **Create TPU Node**
2. Example config (see the equivalent `gcloud` command below):
   - **Name:** `maxtext-tpu-node`
   - **TPU type:** `v5e-8` (or `v6e-8` for newer hardware)
   - **Runtime Version:** `tpu-ubuntu-alpha-*` (matches your VM image)
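If you prefer the CLI, a minimal sketch of the equivalent creation command (the accelerator type and runtime version are illustrative; check which combinations are available in your zone and quota):

```bash
# Create a v5e-8 TPU VM (illustrative values; adjust zone/type/version to your region)
gcloud compute tpus tpu-vm create maxtext-tpu-node \
  --zone=YOUR_ZONE \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite
```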
### Step 2: Connect to TPU VM

```bash
gcloud compute tpus tpu-vm ssh maxtext-tpu-node --zone=YOUR_ZONE
```

### Step 3: Install Dependencies

```bash
sudo apt update && sudo apt upgrade -y
sudo apt install python3-pip python3-dev git -y
pip3 install jupyterlab
```

### Step 4: Start Jupyter Lab

```bash
jupyter lab --ip=0.0.0.0 --port=8888 --no-browser --allow-root
```

Copy the URL with the token from the terminal.

### Step 5: Secure Access

#### Option A: SSH Tunnel (Recommended)

```bash
gcloud compute tpus tpu-vm ssh maxtext-tpu-node --zone=YOUR_ZONE -- -L 8888:localhost:8888
```

Then open → `http://localhost:8888`

## Method 3: Colab + Local Jupyter Lab Hybrid

Set up Jupyter Lab as in Method 2. Then, in Colab, open the runtime dropdown, choose **Connect to a local runtime**, and paste the Jupyter Lab URL (including its token). A compatible launch command is sketched below.
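For Colab to accept the connection, the Jupyter server generally has to allow Colab's origin. A minimal sketch based on Colab's documented local-runtime setup (flag names vary across Jupyter versions; newer servers spell these `--ServerApp.*`):

```bash
# WebSocket bridge used by Colab's local-runtime flow
pip install jupyter_http_over_ws
jupyter serverextension enable --py jupyter_http_over_ws

# Start Jupyter so Colab's origin is allowed to connect
jupyter notebook \
  --NotebookApp.allow_origin='https://colab.research.google.com' \
  --port=8888 \
  --NotebookApp.port_retries=0
```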
## Available Examples

### Supervised Fine-Tuning (SFT)

- **`sft_qwen3_demo.ipynb`** → Qwen3-0.6B with the Hugging Face ultrachat_200k dataset
- **`sft_llama3_demo.ipynb`** → Llama3.1-8B with the Hugging Face ultrachat_200k dataset

### GRPO Training

- **`grpo_llama3_demo.ipynb`** → GRPO training on a math dataset
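To peek at the data the SFT notebooks train on, a minimal sketch using the `datasets` library (assumes `pip install datasets`; the dataset is public on the Hugging Face Hub):

```python
# Preview the ultrachat_200k SFT dataset (columns: ['messages'])
from datasets import load_dataset

ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
print(ds)  # number of rows and column names
print(ds[0]["messages"][0])  # first turn of the first conversation
```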
## Common Pitfalls & Debugging

| Issue | Solution |
|-------|----------|
| ❌ TPU runtime mismatch | Check that the TPU runtime version matches the VM image (`tpu-ubuntu-alpha-*`) |
| ❌ Colab disconnects | Save checkpoints to GCS or Drive regularly (see the snippet below) |
| ❌ "RESOURCE_EXHAUSTED" errors | Use a smaller batch size, or v5e-8 instead of v5e-1 |
| ❌ Firewall blocked | Ensure port 8888 is open, or always use SSH tunneling |
| ❌ Path confusion | In Colab use `/content/maxtext`; in a TPU VM use `~/maxtext` |
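For the disconnect pitfall, a minimal sketch of copying outputs off the machine (the bucket name is illustrative; in Colab you can mount Drive instead via `google.colab.drive.mount`):

```bash
# Copy the run's output directory (e.g., checkpoints) to a bucket you own
gsutil -m cp -r /tmp/out gs://YOUR_BUCKET/maxtext-runs/
```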
f\"{MAXTEXT_REPO_ROOT}/src/MaxText/configs/sft.yml\", # base SFT config\n", + " f\"load_parameters_path={MODEL_CHECKPOINT_PATH}/0/items/\", # Load pre-trained weights!, replace with your checkpoint path\n", + " \"model_name=qwen3-0.6b\",\n", + " \"steps=20\", # very short run for testing\n", + " \"per_device_batch_size=1\", # minimal to avoid OOM\n", + " \"max_target_length=1024\", \n", + " \"learning_rate=2.0e-5\", # safe small LR\n", + " \"eval_steps=5\",\n", + " \"weight_dtype=bfloat16\",\n", + " \"dtype=bfloat16\",\n", + " \"hf_path=HuggingFaceH4/ultrachat_200k\", # HuggingFace dataset/model if needed\n", + " f\"hf_access_token={HF_TOKEN}\",\n", + " f\"base_output_directory={BASE_OUTPUT_DIRECTORY}\",\n", + " \"run_name=sft_qwen0.6b_test\",\n", + " \"tokenizer_path=Qwen/Qwen3-0.6B\", # Qwen tokenizer\n", + " \"eval_interval=10\",\n", + " \"profiler=xplane\",\n", + " ]\n", + "\n", + " # Initialize configuration using MaxText's pyconfig\n", + " config = pyconfig.initialize(config_argv)\n", + "\n", + " print(\"βœ“ Fixed configuration loaded:\")\n", + " print(f\" - Model: {config.model_name}\")\n", + " print(f\" - Dataset: {config.hf_path}\")\n", + " print(f\" - Steps: {config.steps}\")\n", + " print(f\" - Use SFT: {config.use_sft}\")\n", + " print(f\" - Learning Rate: {config.learning_rate}\")\n", + "else:\n", + " print(\"MaxText not available - cannot load configuration\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EJE1ookSAzz-" + }, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mgwpNgQYCJEd" + }, + "outputs": [], + "source": [ + "# Execute the training using MaxText SFT trainer's train() function\n", + "if MAXTEXT_AVAILABLE:\n", + " print(\"=\"*60)\n", + " print(\"EXECUTING ACTUAL TRAINING\")\n", + " print(\"=\"*60)\n", + "\n", + " sft_train(config)\n", + "\n", + "print(\"Training complete!\")\n", + "print(\"Model saved at: \", BASE_OUTPUT_DIRECTORY)" + ] + } + ], + "metadata": { + "accelerator": "TPU", + "colab": { + "gpuType": "V5E1", + "provenance": [] + }, + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}