From 3a86ffa41fb0b0d86e60ab195db9d5c1b8a673bb Mon Sep 17 00:00:00 2001 From: marlinfiggins Date: Tue, 9 Sep 2025 11:39:59 -0700 Subject: [PATCH 1/3] End-to-end test for Kimi-K2 --- .../tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh diff --git a/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh b/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh new file mode 100644 index 0000000000..6ca578e7f4 --- /dev/null +++ b/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh @@ -0,0 +1,60 @@ +#!/bin/bash +set -euo pipefail + +# This file tests the implementation of Kimi-K2. + +# The flow of this file is as follows: +# 1. Convert the checkpoint downloaded from HuggingFace to make it compatible with MaxText. +# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. +# 3. Run logits check test between Huggingface and MaxText. + +export MODEL_NAME='kimi-k2-1t.yml' +export TOKENIZER_PATH='moonshotai/Kimi-K2-Instruct' + +export CHKPT_BUCKET='gs://maxtext-deepseek/kimi-k2-1t/hf' +export MODEL_BUCKET='gs://maxtext-deepseek/kimi-k2-1t' +export idx=0 + +export BASE_CFG='src/MaxText/configs/base.yml' + +# Environment / deps +echo "[setup] Installing minimal torch wheel for forward_pass_logit_checker deps..." +python3 -m pip install -q --disable-pip-version-check torch --index-url https://download.pytorch.org/whl/cpu + +# Step 1: +echo "[convert] Converting HF checkpoint to MaxText scanned Orbax..." +JAX_PLATFORMS=cpu python3 -m MaxText.convert_deepseek_family_ckpt \ + --base_model_path "${CHKPT_BUCKET}" \ + --maxtext_model_path "${MODEL_BUCKET}/${idx}" \ + --model_size "${MODEL_NAME}" + +# Step 2: +echo "[convert] Creating unscanned Orbax (optional)..." 
+JAX_PLATFORMS=cpu python3 -m MaxText.convert_deepseek_family_unscanned_ckpt \ + --base_model_path "${CHKPT_BUCKET}" \ + --maxtext_model_path "${MODEL_BUCKET}/${idx}/unscanned" \ + --model_size "${MODEL_NAME}" + +# Step 3: +export SCANNED_CKPT_PATH="${MODEL_BUCKET}/${idx}/0/items" + +echo "[check] Running forward_pass_logit_checker" +python3 -m tests.forward_pass_logit_checker \ + "${BASE_CFG}" \ + tokenizer_type=huggingface \ + tokenizer_path="${TOKENIZER_PATH}" \ + load_parameters_path="${SCANNED_CKPT_PATH}" \ + run_name="forward_pass_test_${MODEL_NAME}_hf_live" \ + per_device_batch_size=1 \ + model_name="${MODEL_NAME}" \ + max_prefill_predict_length=16 \ + max_target_length=16 \ + dataset_type=synthetic \ + scan_layers=false \ + sparse_matmul=False \ + dtype=float32 \ + activations_in_float32=true \ + matmul_precision=high \ + --run_hf_model=true \ + --hf_model_path="${TOKENIZER_PATH}" \ + --max_kl_div=2e-4 From 2c97354ac8b08437aebe666a4bf50f1958b95e6e Mon Sep 17 00:00:00 2001 From: marlinfiggins Date: Tue, 9 Sep 2025 11:47:31 -0700 Subject: [PATCH 2/3] Model name corrections; passing unscanned checkpoint for efficient decoding. --- end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh b/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh index 6ca578e7f4..00dcbab3b4 100644 --- a/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh +++ b/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh @@ -8,7 +8,7 @@ set -euo pipefail # 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. # 3. Run logits check test between Huggingface and MaxText. 
-export MODEL_NAME='kimi-k2-1t.yml' +export MODEL_NAME='kimi-k2-1t' export TOKENIZER_PATH='moonshotai/Kimi-K2-Instruct' export CHKPT_BUCKET='gs://maxtext-deepseek/kimi-k2-1t/hf' @@ -29,7 +29,7 @@ JAX_PLATFORMS=cpu python3 -m MaxText.convert_deepseek_family_ckpt \ --model_size "${MODEL_NAME}" # Step 2: -echo "[convert] Creating unscanned Orbax (optional)..." +echo "[convert] Creating unscanned Orbax..." JAX_PLATFORMS=cpu python3 -m MaxText.convert_deepseek_family_unscanned_ckpt \ --base_model_path "${CHKPT_BUCKET}" \ --maxtext_model_path "${MODEL_BUCKET}/${idx}/unscanned" \ @@ -43,7 +43,7 @@ python3 -m tests.forward_pass_logit_checker \ "${BASE_CFG}" \ tokenizer_type=huggingface \ tokenizer_path="${TOKENIZER_PATH}" \ - load_parameters_path="${SCANNED_CKPT_PATH}" \ + load_parameters_path="${UNSCANNED_CKPT_PATH}" \ run_name="forward_pass_test_${MODEL_NAME}_hf_live" \ per_device_batch_size=1 \ model_name="${MODEL_NAME}" \ From 1cad4253365050f6f4c2a0a5dedd7b9fea07e9e6 Mon Sep 17 00:00:00 2001 From: marlinfiggins Date: Tue, 23 Sep 2025 10:22:23 -0700 Subject: [PATCH 3/3] Adding steps for downloading and converting checkpoint to bf16 --- .../tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh | 76 ++++++++++++++----- 1 file changed, 58 insertions(+), 18 deletions(-) diff --git a/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh b/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh index 00dcbab3b4..5c779bdfa6 100644 --- a/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh +++ b/end_to_end/tpu/deepseek/kimi-k2-1t/test_kimi-k2.sh @@ -4,47 +4,85 @@ set -euo pipefail # This file tests the implementation of Kimi-K2. # The flow of this file is as follows: -# 1. Convert the checkpoint downloaded from HuggingFace to make it compatible with MaxText. -# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding. -# 3. Run logits check test between Huggingface and MaxText. +# 1. Download the checkpoint from HuggingFace (fp8 weights). +# 2. 
Convert the checkpoint from FP8 to BF16 in HuggingFace format.
+# 3. Upload the BF16 HuggingFace checkpoint to your GCS bucket.
+# 4. Convert the BF16 HuggingFace checkpoint to a MaxText scanned Orbax checkpoint.
+# 5. Convert the scanned checkpoint to an unscanned checkpoint for efficient decoding.
+# 6. Run logits check test between HuggingFace and MaxText using the unscanned checkpoint.
 
 export MODEL_NAME='kimi-k2-1t'
 export TOKENIZER_PATH='moonshotai/Kimi-K2-Instruct'
 
 export CHKPT_BUCKET='gs://maxtext-deepseek/kimi-k2-1t/hf'
 export MODEL_BUCKET='gs://maxtext-deepseek/kimi-k2-1t'
-export idx=0
+export idx=0
+# Local working dirs for HF weights
+export HF_LOCAL_FP8_DIR="${PWD}/kimi-k2-fp8"
+export HF_LOCAL_BF16_DIR="${PWD}/kimi-k2-bf16"
 
 export BASE_CFG='src/MaxText/configs/base.yml'
 
 # Environment / deps
-echo "[setup] Installing minimal torch wheel for forward_pass_logit_checker deps..."
-python3 -m pip install -q --disable-pip-version-check torch --index-url https://download.pytorch.org/whl/cpu
+echo "[setup] Installing dependencies..."
+python3 -m pip install -q --disable-pip-version-check \
+  torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu \
+  safetensors==0.4.5 \
+  transformers \
+  huggingface_hub \
+  jsonlines \
+  google-cloud-storage
+
+# Step 1: Download FP8 weights from Hugging Face
+if [[ ! -d "${HF_LOCAL_FP8_DIR}" ]]; then
+  echo "[step 1] Downloading ${TOKENIZER_PATH} into ${HF_LOCAL_FP8_DIR}"
+  huggingface-cli download "${TOKENIZER_PATH}" \
+    --local-dir "${HF_LOCAL_FP8_DIR}" \
+    --local-dir-use-symlinks False
+else
+  echo "[step 1] Skipping download; ${HF_LOCAL_FP8_DIR} already exists"
+fi
+
+# Step 2: Convert FP8 -> BF16 in HuggingFace format
+if [[ ! 
-d "${HF_LOCAL_BF16_DIR}" ]]; then + echo "[step 2] Converting FP8 -> BF16 HF checkpoint" + python3 -m MaxText.deepseek_fp8_to_bf16 \ + --input-fp8-hf-path "${HF_LOCAL_FP8_DIR}" \ + --output-bf16-hf-path "${HF_LOCAL_BF16_DIR}" +else + echo "[step 2] Skipping FP8->BF16; ${HF_LOCAL_BF16_DIR} already exists" +fi -# Step 1: -echo "[convert] Converting HF checkpoint to MaxText scanned Orbax..." +# Step 3: Upload BF16 HF weights to GCS +# After downloading and converting checkpoints, copy them to GCS bucket at $CHKPT_BUCKET \ +# Non-Googlers please remember to use separate GCS paths for uploading model weights from HuggingFace ($CHKPT_BUCKET) and MaxText compatible weights ($MODEL_BUCKET). +# Non-Googlers please remember to point these variables to GCS buckets that you own, this script uses internal buckets for testing. +echo "[step 3] Syncing BF16 HF checkpoint to ${CHKPT_BUCKET}" +gsutil -m rsync -r "${HF_LOCAL_BF16_DIR}" "${CHKPT_BUCKET}" + +# Step 4: Convert HF (BF16) -> MaxText scanned Orbax +echo "[step 4] HF BF16 -> MaxText scanned Orbax" JAX_PLATFORMS=cpu python3 -m MaxText.convert_deepseek_family_ckpt \ - --base_model_path "${CHKPT_BUCKET}" \ + --base_model_path "${CHKPT_BUCKET}" \ --maxtext_model_path "${MODEL_BUCKET}/${idx}" \ - --model_size "${MODEL_NAME}" + --model_size "${MODEL_NAME}" -# Step 2: -echo "[convert] Creating unscanned Orbax..." 
+# Step 5: Convert scanned -> unscanned Orbax
 JAX_PLATFORMS=cpu python3 -m MaxText.convert_deepseek_family_unscanned_ckpt \
-  --base_model_path "${CHKPT_BUCKET}" \
+  --base_model_path "${CHKPT_BUCKET}" \
   --maxtext_model_path "${MODEL_BUCKET}/${idx}/unscanned" \
-  --model_size "${MODEL_NAME}"
+  --model_size "${MODEL_NAME}"
 
-# Step 3:
-export SCANNED_CKPT_PATH="${MODEL_BUCKET}/${idx}/0/items"
+export UNSCANNED_CKPT_PATH="${MODEL_BUCKET}/${idx}/unscanned/0/items"
 
-echo "[check] Running forward_pass_logit_checker"
+# Step 6: Logit check (MaxText vs HF) using unscanned checkpoint
+echo "[step 6] Running forward_pass_logit_checker (unscanned ckpt, float32 dtype)"
 python3 -m tests.forward_pass_logit_checker \
   "${BASE_CFG}" \
   tokenizer_type=huggingface \
   tokenizer_path="${TOKENIZER_PATH}" \
   load_parameters_path="${UNSCANNED_CKPT_PATH}" \
-  run_name="forward_pass_test_${MODEL_NAME}_hf_live" \
+  run_name="forward_pass_test_${MODEL_NAME}_hf_live_unscanned" \
   per_device_batch_size=1 \
   model_name="${MODEL_NAME}" \
   max_prefill_predict_length=16 \
@@ -58,3 +96,5 @@ python3 -m tests.forward_pass_logit_checker \
   --run_hf_model=true \
   --hf_model_path="${TOKENIZER_PATH}" \
   --max_kl_div=2e-4
+
+echo "[done] Cross-check completed."