
Commit a55e18a

Merge pull request #2278 from AI-Hypercomputer:collabs-examples-sft
PiperOrigin-RevId: 807495969
2 parents 9c13728 + 63fb47b commit a55e18a

File tree

1 file changed: +288 −0

Lines changed: 288 additions & 0 deletions
@@ -0,0 +1,288 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run SFT on Llama3.1-8B-Instruct model\n",
    "\n",
    "This notebook demonstrates how to perform Supervised Fine-Tuning (SFT) on Llama3.1-8B-Instruct using the Hugging Face ultrachat_200k dataset with Tunix integration for efficient training.\n",
    "\n",
    "## Dataset Overview\n",
    "https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k\n",
    "\n",
    "**Dataset Information:**\n",
    "- **Name**: HuggingFaceH4/ultrachat_200k\n",
    "- **Type**: Supervised Fine-Tuning dataset\n",
    "- **Size**: ~200k conversations\n",
    "- **Format**: Chat conversations with human-AI pairs\n",
    "- **Splits**: train_sft, test_sft\n",
    "- **Data columns**: ['messages']\n",
    "\n",
    "**Dataset Structure:**\n",
    "Each example contains a 'messages' field with:\n",
    "- role: 'user' or 'assistant'\n",
    "- content: The actual message text\n",
    "\n",
    "**Example data format:**\n",
    "```json\n",
    "{\n",
    "  \"messages\": [\n",
    "    {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n",
    "    {\"role\": \"assistant\", \"content\": \"The capital of France is Paris.\"}\n",
    "  ]\n",
    "}\n",
    "```\n",
    "\n",
    "## Key Features\n",
    "- **MaxText Llama3.1-8B-Instruct model**\n",
    "- **Tunix integration** for optimized training\n",
    "- **UltraChat-200k dataset** from HuggingFace\n",
    "- Tokenizes with meta-llama/Llama-3.1-8B-Instruct\n",
    "\n",
    "## Prerequisites\n",
    "- MaxText environment with all dependencies\n",
    "- Tunix installation\n",
    "- HuggingFace access token for dataset download\n",
    "- Sufficient compute resources (TPU/GPU)\n"
   ]
  },
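  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell is an optional, minimal sketch for peeking at the dataset before training. It assumes the `datasets` library is installed and that HuggingFaceH4/ultrachat_200k is reachable from this environment; it is not required for the MaxText/Tunix flow below.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: inspect one UltraChat-200k example (sketch; assumes `datasets` is installed)\n",
    "from datasets import load_dataset\n",
    "\n",
    "# Stream a single example from the train_sft split to see the 'messages' structure\n",
    "ds = load_dataset(\"HuggingFaceH4/ultrachat_200k\", split=\"train_sft\", streaming=True)\n",
    "example = next(iter(ds))\n",
    "for turn in example[\"messages\"]:\n",
    "    print(f\"{turn['role']}: {turn['content'][:80]}\")\n"
   ]
  },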
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### (Optional) Run this only if you have this file and nothing else\n",
    "\n",
    "# 1. Clone the MaxText repository (from AI-Hypercomputer)\n",
    "!git clone https://github.com/AI-Hypercomputer/maxtext.git\n",
    "\n",
    "# 2. Navigate into the cloned directory\n",
    "%cd maxtext"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### (Optional) Skip this if you have already installed the dependencies\n",
    "\n",
    "# 3. Ensure setup.sh is executable\n",
    "!chmod +x setup.sh\n",
    "\n",
    "# 4. Execute the setup script\n",
    "!./setup.sh\n",
    "\n",
    "# Pin the numpy version\n",
    "!pip install --force-reinstall numpy==2.1.2\n",
    "# Install nest_asyncio\n",
    "!pip install nest_asyncio\n",
    "\n",
    "# nest_asyncio works around the \"This event loop is already running\" error in Colab\n",
    "import nest_asyncio\n",
    "nest_asyncio.apply()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Set the home directory. Change this to the directory where maxtext is cloned.\n",
    "MAXTEXT_HOME = os.path.expanduser(\"~\") + \"/maxtext\"\n",
    "print(f\"Home directory (from Python): {MAXTEXT_HOME}\")\n",
    "\n",
    "# Set the path to the Llama3.1-8B-Instruct checkpoint to load; gs://<bucket> paths are supported\n",
    "MODEL_CHECKPOINT_PATH = \"path/to/scanned/checkpoint\""
   ]
  },
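  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, the cell below is a minimal sketch for sanity-checking `MODEL_CHECKPOINT_PATH` before training. It assumes `gsutil` is available and authenticated for gs:// paths; for local paths it simply checks that the path exists.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check on the checkpoint path (sketch)\n",
    "if MODEL_CHECKPOINT_PATH.startswith(\"gs://\"):\n",
    "    # List the checkpoint directory in GCS (assumes gsutil is installed and authenticated)\n",
    "    !gsutil ls -d {MODEL_CHECKPOINT_PATH}\n",
    "else:\n",
    "    print(f\"Checkpoint path exists: {os.path.exists(MODEL_CHECKPOINT_PATH)}\")\n"
   ]
  },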
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Find the MaxText directory and change the working directory to it\n",
    "current_dir = Path.cwd()\n",
    "if current_dir.name == 'examples':\n",
    "    # We're in the examples folder; go up two levels\n",
    "    maxtext_path = current_dir.parent.parent\n",
    "else:\n",
    "    # We're in the repository root; use the MaxText package directory\n",
    "    maxtext_path = Path(MAXTEXT_HOME) / 'src' / 'MaxText'\n",
    "\n",
    "# Change the working directory to the MaxText directory and make it importable\n",
    "os.chdir(maxtext_path)\n",
    "sys.path.insert(0, str(maxtext_path))\n",
    "\n",
    "print(f\"✓ Changed working directory to: {os.getcwd()}\")\n",
    "print(f\"✓ MaxText path: {maxtext_path}\")\n",
    "print(f\"✓ Added to Python path: {maxtext_path}\")\n",
    "\n",
    "import jax\n",
    "if not jax.distributed.is_initialized():\n",
    "    jax.distributed.initialize()\n",
    "print(f\"JAX version: {jax.__version__}\")\n",
    "print(f\"JAX devices: {jax.devices()}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hugging Face Authentication Setup\n",
    "\n",
    "If you encounter 401 unauthorized errors when loading datasets, you need to authenticate with Hugging Face. Set your token below:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hugging Face Authentication Setup\n",
    "from huggingface_hub import login\n",
    "\n",
    "# Set your Hugging Face token here\n",
    "HF_TOKEN = \"hf_your_token_here\"  # Replace with your actual token\n",
    "login(token=HF_TOKEN)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# MaxText imports\n",
    "try:\n",
    "    from MaxText import pyconfig\n",
    "    from MaxText.sft.sft_trainer import train as sft_train\n",
    "\n",
    "    MAXTEXT_AVAILABLE = True\n",
    "    print(\"✓ MaxText imports successful\")\n",
    "except ImportError as e:\n",
    "    print(f\"⚠️ MaxText not available: {e}\")\n",
    "    MAXTEXT_AVAILABLE = False\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration Setup\n",
    "\n",
    "### Notes\n",
    "- By default, SFT trains on completions only (sft_train_on_completion_only=True).\n",
    "- Set sft_train_on_completion_only=False to train on both prompts and completions (see the optional sketch after the configuration cell below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration setup\n",
    "if MAXTEXT_AVAILABLE:\n",
    "    # Build the argument list for MaxText's config system\n",
    "    config_argv = [\n",
    "        \"\",\n",
    "        f\"{MAXTEXT_HOME}/src/MaxText/configs/sft.yml\",  # SFT config\n",
    "        f\"load_parameters_path={MODEL_CHECKPOINT_PATH}\",\n",
    "        \"model_name=llama3.1-8b\",\n",
    "        \"steps=100\",\n",
    "        \"per_device_batch_size=1\",\n",
    "        \"max_target_length=1024\",\n",
    "        \"learning_rate=2.0e-5\",\n",
    "        \"eval_steps=5\",\n",
    "        \"weight_dtype=bfloat16\",\n",
    "        \"dtype=bfloat16\",\n",
    "        \"hf_path=HuggingFaceH4/ultrachat_200k\",\n",
    "        f\"hf_access_token={HF_TOKEN}\",\n",
    "        \"base_output_directory=/tmp/maxtext_output\",\n",
    "        \"run_name=sft_llama3_demo\",\n",
    "        \"tokenizer_path=meta-llama/Llama-3.1-8B-Instruct\",\n",
    "        \"eval_interval=10\",\n",
    "        \"profiler=xplane\",\n",
    "    ]\n",
    "\n",
    "    # Initialize the configuration using MaxText's pyconfig\n",
    "    config = pyconfig.initialize(config_argv)\n",
    "\n",
    "    print(\"✓ Configuration loaded:\")\n",
    "    print(f\"  - Model: {config.model_name}\")\n",
    "    print(f\"  - Dataset: {config.hf_path}\")\n",
    "    print(f\"  - Steps: {config.steps}\")\n",
    "    print(f\"  - Use SFT: {config.use_sft}\")\n",
    "    print(f\"  - Learning Rate: {config.learning_rate}\")\n",
    "else:\n",
    "    print(\"MaxText not available - cannot load configuration\")\n"
   ]
  },
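  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cell below is an optional, minimal sketch of the override mentioned in the Notes above: re-initializing the configuration with sft_train_on_completion_only=False so training covers both prompts and completions. It assumes the config_argv list from the previous cell and that the flag is exposed on the resulting config object.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: train on both prompts and completions\n",
    "# (assumes config_argv from the previous cell; sft_train_on_completion_only is the\n",
    "# flag described in the Notes section above)\n",
    "if MAXTEXT_AVAILABLE:\n",
    "    config_argv_full = config_argv + [\"sft_train_on_completion_only=False\"]\n",
    "    config_full = pyconfig.initialize(config_argv_full)\n",
    "    print(f\"Train on completion only: {config_full.sft_train_on_completion_only}\")\n"
   ]
  },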
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Execute Training\n",
    "\n",
    "Run the training using the MaxText SFT trainer's `train()` function.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute the training using the MaxText SFT trainer's train() function\n",
    "if MAXTEXT_AVAILABLE:\n",
    "    print(\"=\" * 60)\n",
    "    print(\"EXECUTING TRAINING\")\n",
    "    print(\"=\" * 60)\n",
    "\n",
    "    sft_train(config)\n",
    "\n",
    "    print(\"\\n✅ Training completed successfully!\")\n",
    "else:\n",
    "    print(\"MaxText not available - cannot execute training\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary\n",
    "\n",
    "This notebook demonstrated the complete MaxText & Tunix integration for SFT training.\n",
    "\n",
    "The integration provides the best of both worlds: MaxText's high-performance LLM training and Tunix's optimized training infrastructure, making it ideal for production SFT training on large datasets like UltraChat-200k.\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
