
Commit 3b27bed

Merge pull request #2298 from AI-Hypercomputer:ckpt_conversion
PiperOrigin-RevId: 805869447
2 parents 3fd9e49 + 3f91bde commit 3b27bed

6 files changed: +74 −56 lines changed

end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh

Lines changed: 10 additions & 7 deletions

@@ -1,17 +1,20 @@
 #!/bin/bash
 
-# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Qwen3-4B.
+# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Gemma2-2B.
 
-# The flow of this file is as follows:
-# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
-# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and orginal golden HF model
+# The flow of this script is as follows:
+# 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint.
+# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and orginal golden HF model.
+
+# Pre-requisites:
+# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions
+# export HF_TOKEN=<Hugging Face access token>
 
 
 set -ex
 idx=$(date +%Y-%m-%d-%H-%M)
 MODEL_NAME='gemma2-2b'
 export MODEL_VARIATION='2b'
-HF_TOKEN='' # Important!!! Save your hf access token here
 TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_REPO_ROOT:-$PWD}/assets}"'/tokenizer.gemma'
 
 # Installing torch for deps in forward_pass_logit_checker.py
@@ -33,7 +36,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA
 hf_access_token=${HF_TOKEN} \
 load_parameters_path=${CKPT_PATH} \
 base_output_directory=${LOCAL_PATH} \
-scan_layers=false
+scan_layers=false
 
 # Alternatively, if uploaded the converted ckpt, HF requires local storage of model
 # mkdir -p "${LOCAL_PATH}"
@@ -48,4 +51,4 @@ python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_R
 scan_layers=false \
 --hf_model_path=${LOCAL_PATH} \
 --max_kl_div=0.015 \
---run_hf_model=true
+--run_hf_model=true
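All six scripts in this commit drop the in-file HF_TOKEN='' assignment in favor of an exported environment variable, which keeps credentials out of the checked-in scripts. A minimal setup sketch, assuming a read-scoped token created at https://huggingface.co/settings/tokens (the token value below is a placeholder, not a real token):

# Export a read-scoped Hugging Face token once per shell session (placeholder value).
export HF_TOKEN=hf_xxxxxxxxxxxxxxxx
# Then run any of the updated test scripts, e.g.:
bash end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh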

end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh

Lines changed: 10 additions & 7 deletions

@@ -1,19 +1,22 @@
 #!/bin/bash
 
-# This file contains an end-to-end Airflow nightly test, designed to run once a day on a v4-8, along with documentation to guide users in getting started with Gemma2-2B.
+# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Gemma2-2B.
 
-# The flow of this file is as follows:
-# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
-# 2. Run a forward pass logits check to compare with the original HF golden model
-# 2. Run decoding, finetuning of Gemma2-2B. with the converted checkpoint.
-# 3. Run decoding from the finetuned checkpoint from step 2
+# The flow of this script is as follows:
+# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText.
+# 2. Run a forward pass logits check to compare with the original HF golden model.
+# 3. Run decoding, finetuning of Gemma2-2B. with the converted checkpoint.
+# 4. Run decoding from the finetuned checkpoint from step 3.
+
+# Pre-requisites:
+# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions
+# export HF_TOKEN=<Hugging Face access token>
 
 
 set -ex
 idx=$(date +%Y-%m-%d-%H-%M)
 MODEL_NAME='gemma2-2b'
 export MODEL_VARIATION='2b'
-HF_TOKEN='' # Important!!! Save your hf access token here
 HF_GOLDEN_MODEL='google/gemma-2-2b'
 TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_REPO_ROOT:-$PWD}/assets}"'/tokenizer.gemma'
 
end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh

Lines changed: 11 additions & 8 deletions

@@ -1,19 +1,22 @@
 #!/bin/bash
 
-# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Qwen3-4B.
+# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4B.
 
-# The flow of this file is as follows:
-# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
-# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and orginal golden HF model
+# The flow of this script is as follows:
+# 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint.
+# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and orginal golden HF model.
+
+# Pre-requisites:
+# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions
+# export HF_TOKEN=<Hugging Face access token>
 
 set -ex
 idx=$(date +%Y-%m-%d-%H-%M)
 MODEL_NAME='gemma3-4b'
 export MODEL_VARIATION='4b'
-HF_TOKEN='' # Important!!! Save your hf access token here
 TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_REPO_ROOT:-$PWD}/assets}"'/tokenizer.gemma3'
 # To convert the multimodal model, make sure the use_multimodal is set to be true
-USE_MULTIMODAL=true
+USE_MULTIMODAL=false
 
 # Installing torch for deps in forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
@@ -35,7 +38,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA
 load_parameters_path=${CKPT_PATH} \
 base_output_directory=${LOCAL_PATH} \
 use_multimodal=${USE_MULTIMODAL} \
-scan_layers=false
+scan_layers=false
 
 # Alternatively, if uploaded the converted ckpt, HF requires local storage of model
 # mkdir -p "${LOCAL_PATH}"
@@ -51,4 +54,4 @@ python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_R
 scan_layers=false \
 --hf_model_path=${LOCAL_PATH} \
 --max_kl_div=0.015 \
---run_hf_model=true
+--run_hf_model=true
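The to_hf scripts note that the logit checker loads the converted model from local storage. A hypothetical sketch of pulling a previously uploaded checkpoint back down from GCS, assuming gcloud is installed and GCS_CKPT_DIR is a placeholder for wherever to_huggingface wrote its output:

# Copy a converted checkpoint from GCS to the local path the checker reads.
mkdir -p "${LOCAL_PATH}"
gcloud storage cp -r "${GCS_CKPT_DIR}/*" "${LOCAL_PATH}/"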

end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh

Lines changed: 17 additions & 14 deletions

@@ -1,23 +1,26 @@
 #!/bin/bash
 
-# This file contains an end-to-end Airflow nightly test, designed to run once a day on a v4-8, along with documentation to guide users in getting started with Gemma3-4B.
+# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Gemma3-4B.
 
-# The flow of this file is as follows:
-# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
-# 2. Run a forward pass logits check to compare with the original HF golden model
-# 2. Run decoding, finetuning of Gemma3-4B. with the converted checkpoint.
-# 3. Run decoding from the finetuned checkpoint from step 2
+# The flow of this script is as follows:
+# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText.
+# 2. Run a forward pass logits check to compare with the original HF golden model.
+# 3. Run decoding, finetuning of Gemma3-4B. with the converted checkpoint.
+# 4. Run decoding from the finetuned checkpoint from step 3.
+
+# Pre-requisites:
+# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions
+# export HF_TOKEN=<Hugging Face access token>
 
 
 set -ex
 idx=$(date +%Y-%m-%d-%H-%M)
 MODEL_NAME='gemma3-4b'
 export MODEL_VARIATION='4b'
-HF_TOKEN='' # Important!!! Save your hf access token here
 HF_GOLDEN_MODEL='google/gemma-3-4b-it'
 TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_REPO_ROOT:-$PWD}/assets}"'/tokenizer.gemma3'
 # To convert the multimodal model, make sure the use_multimodal is set to be true
-USE_MULTIMODAL=true
+USE_MULTIMODAL=false
 
 # Installing torch for deps in forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
@@ -33,18 +36,18 @@ python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEX
 hf_access_token=${HF_TOKEN} \
 base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \
 use_multimodal=${USE_MULTIMODAL} \
-scan_layers=false
+scan_layers=false
 
 export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items
 
-# # To get scanned ckpt, flip the scan_layers.
+# # To get scanned ckpt, flip the scan_layers.
 # ToDo: gemma3 multimodal scanned ckpt
 # python3 -m MaxText.utils.ckpt_conversion.to_maxtext src/MaxText/configs/base.yml \
 # model_name=${MODEL_NAME} \
 # hf_access_token=${HF_TOKEN} \
 # base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx} \
 # use_multimodal=${USE_MULTIMODAL} \
-# scan_layers=true
+# scan_layers=true
 
 # export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx}/0/items
 
@@ -53,14 +56,14 @@ export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0
 
 # ToDo: improve forward_pass_logit_checker to test multi-modal prompt
 python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
-tokenizer_path==${TOKENIZER_PATH} \
+tokenizer_path=${TOKENIZER_PATH} \
 load_parameters_path=${UNSCANNED_CKPT_PATH} \
 model_name=${MODEL_NAME} \
 use_multimodal=${USE_MULTIMODAL} \
 scan_layers=false \
 --hf_model_path=${HF_GOLDEN_MODEL} \
 --max_kl_div=0.015 \
---run_hf_model=true
+--run_hf_model=true
 
 # We can run decoding for unscanned checkpoints.
 if [ ${USE_MULTIMODAL} == true ]; then
@@ -84,4 +87,4 @@ if [ ${USE_MULTIMODAL} == true ]; then
 python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=272 max_target_length=300 steps=1 async_checkpointing=false scan_layers=false use_multimodal=${USE_MULTIMODAL} prompt=\'Describe\ image\ \<start_of_image\>\' image_path=\'src/MaxText/test_assets/test_image.jpg\' attention=\'dot_product\'
 else
 python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml model_name=${MODEL_NAME} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=ht_test max_prefill_predict_length=8 max_target_length=16 steps=1 async_checkpointing=false scan_layers=false prompt='I love to' attention=\'dot_product\'
-fi
+fi
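The tokenizer_path== fix above is easy to miss but real: in key=value style overrides, a doubled equals sign leaves a stray '=' at the front of the value. A minimal sketch of the failure mode, assuming the config parser splits each override on the first '=':

# Demonstrates why 'key==value' corrupts the value in key=value style overrides.
arg='tokenizer_path==assets/tokenizer.gemma3'
key="${arg%%=*}"    # -> tokenizer_path
value="${arg#*=}"   # -> =assets/tokenizer.gemma3 (leading '=' makes the path invalid)
echo "key=${key} value=${value}"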

end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh

Lines changed: 10 additions & 7 deletions

@@ -1,17 +1,20 @@
 #!/bin/bash
 
-# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Qwen3-4B.
+# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Qwen3-4B.
 
-# The flow of this file is as follows:
-# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
-# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and orginal golden HF model
+# The flow of this script is as follows:
+# 1. Convert a MaxText checkpoint to a Hugging Face model checkpoint.
+# 2. Run a forward pass check to compare the logits and KL divergence between the converted ckpt and orginal golden HF model.
+
+# Pre-requisites:
+# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions
+# export HF_TOKEN=<Hugging Face access token>
 
 
 set -ex
 idx=$(date +%Y-%m-%d-%H-%M)
 MODEL_NAME='qwen3-4b'
 export MODEL_VARIATION='4b'
-HF_TOKEN='' # Important!!! Save your hf access token here
 
 # Installing torch for deps in forward_pass_logit_checker.py
 python3 -m pip install torch --index-url https://download.pytorch.org/whl/cpu
@@ -32,7 +35,7 @@ python3 -m MaxText.utils.ckpt_conversion.to_huggingface "${MAXTEXT_PKG_DIR:-${MA
 hf_access_token=${HF_TOKEN} \
 load_parameters_path=${CKPT_PATH} \
 base_output_directory=${LOCAL_PATH} \
-scan_layers=false
+scan_layers=false
 
 # Alternatively, if uploaded the converted ckpt, HF requires local storage of model
 # mkdir -p "${LOCAL_PATH}"
@@ -47,4 +50,4 @@ python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_R
 scan_layers=false \
 --hf_model_path=${LOCAL_PATH} \
 --max_kl_div=0.015 \
---run_hf_model=true
+--run_hf_model=true
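In each forward pass check, --max_kl_div=0.015 is the tolerance on how far the converted checkpoint's next-token distribution may drift from the golden HF model's. For softmaxed logits p (golden) and q (converted) over vocabulary V, the quantity being bounded is presumably the standard divergence

D_{\mathrm{KL}}(p \,\|\, q) = \sum_{v \in V} p_v \log \frac{p_v}{q_v}

evaluated per token position, with the check failing if it exceeds 0.015 at any checked position.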
end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh

Lines changed: 16 additions & 13 deletions

@@ -1,18 +1,21 @@
 #!/bin/bash
 
-# This file contains an end-to-end Airflow nightly test, designed to run once a day on a v4-8, along with documentation to guide users in getting started with Gemma2-2B.
+# This script is both an end-to-end test that runs once a day on a v4-8 and documentation for how to get started with Qwen3-4B.
 
-# The flow of this file is as follows:
-# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText
-# 2. Run a forward pass logits check to compare with the original HF golden model
-# 2. Run decoding, finetuning of Qwen3-4B. with the converted checkpoint.
-# 3. Run decoding from the finetuned checkpoint from step 2
+# The flow of this script is as follows:
+# 1. Convert the checkpoint downloaded from Hugging Face to make it compatible with MaxText.
+# 2. Run a forward pass logits check to compare with the original HF golden model.
+# 3. Run decoding, finetuning of Qwen3-4B. with the converted checkpoint.
+# 4. Run decoding from the finetuned checkpoint from step 3.
+
+# Pre-requisites:
+# 1. Set HF_TOKEN environment variable to your Hugging Face access token with read permissions
+# export HF_TOKEN=<Hugging Face access token>
 
 set -ex
 idx=$(date +%Y-%m-%d-%H-%M)
 MODEL_NAME='qwen3-4b'
 export MODEL_VARIATION='4b'
-HF_TOKEN='' # Important!!! Save your hf access token here
 HF_GOLDEN_MODEL='Qwen/Qwen3-4B'
 TOKENIZER_PATH="${MAXTEXT_ASSETS_ROOT:-${MAXTEXT_REPO_ROOT:-$PWD}/assets}"'/qwen3-tokenizer'
 
@@ -29,23 +32,23 @@ python -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT
 model_name=${MODEL_NAME} \
 hf_access_token=${HF_TOKEN} \
 base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx} \
-scan_layers=false
+scan_layers=false
 
 export UNSCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/unscanned/${idx}/0/items
 
 # We also test whether the forward pass logits match the original HF model
 # to get higher precision (eg. float32) run on CPU with `JAX_PLATFORMS=cpu`
 python3 -m tests.forward_pass_logit_checker "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
-tokenizer_path==${TOKENIZER_PATH}\
+tokenizer_path=${TOKENIZER_PATH}\
 load_parameters_path=${UNSCANNED_CKPT_PATH} \
 model_name=${MODEL_NAME} \
 scan_layers=false \
 --hf_model_path=${HF_GOLDEN_MODEL} \
 --max_kl_div=0.015 \
---run_hf_model=True
+--run_hf_model=True
 
 # We can run decoding for unscanned checkpoints.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path==${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
 
 # # Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
 export DATASET_PATH=gs://maxtext-dataset
@@ -55,7 +58,7 @@ export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
 # We can also run finetuning by using the scanned converted checkpoint.
 # Note that scanned checkpoint helps with efficient finetuning
 export FINETUNE_RUN_NAME=runner_finetune_${idx}
-python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path==${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} checkpoint_period=5
+python3 -m MaxText.train "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} checkpoint_period=5
 
 # Now, run decoding on the checkpoint generated from our finetune run.
-python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path==${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
+python3 -m MaxText.decode "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml tokenizer_path=${TOKENIZER_PATH} load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=${MODEL_NAME} prompt="I love to"
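The finetuning command above loads the unscanned checkpoint even though its comment recommends a scanned one for efficiency. A sketch of producing the scanned variant, mirroring the commented-out block in test_gemma3_to_mt.sh (same placeholder buckets, minus the multimodal flag, which qwen3-4b does not use):

python3 -m MaxText.utils.ckpt_conversion.to_maxtext "${MAXTEXT_PKG_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/MaxText}/"configs/base.yml \
  model_name=${MODEL_NAME} \
  hf_access_token=${HF_TOKEN} \
  base_output_directory=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx} \
  scan_layers=true
export SCANNED_CKPT_PATH=${MODEL_BUCKET}/${MODEL_VARIATION}/scanned/${idx}/0/items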
