diff --git a/3.test_cases/megatron/megatron-bridge/README.md b/3.test_cases/megatron/megatron-bridge/README.md new file mode 100644 index 000000000..4e25f5f39 --- /dev/null +++ b/3.test_cases/megatron/megatron-bridge/README.md @@ -0,0 +1,283 @@ + + + +# Megatron-Bridge: Qwen 3 Pretraining on AWS + +[Megatron-Bridge](https://github.com/NVIDIA-NeMo/Megatron-Bridge) ([PyPI](https://pypi.org/project/megatron-bridge/)) is a PyTorch-native library within the NeMo Framework that bridges Hugging Face models with [Megatron-Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) for high-performance distributed training. It provides bidirectional checkpoint conversion, built-in training recipes, and support for advanced parallelism strategies (TP, PP, CP, EP). + +This guide walks through pretraining [Qwen 3](https://huggingface.co/Qwen) models using +Megatron-Bridge on an EKS or SageMaker HyperPod EKS cluster with Kubeflow PyTorchJob +and AWS EFA networking. + +## Supported Models + +Megatron-Bridge supports Qwen3, Qwen3-MoE, Llama 2/3, DeepSeek V2/V3, Gemma, Mistral, and many more. See the [full list](https://github.com/NVIDIA-NeMo/Megatron-Bridge#supported-models). + +## 1. Prerequisites + +### Cluster Requirements + +- An EKS or SageMaker HyperPod EKS cluster with 2x ml.p5.48xlarge/ml.p5en.48xlarge nodes +- [NVIDIA device plugin](https://github.com/NVIDIA/k8s-device-plugin) installed +- [EFA device plugin](https://github.com/aws/eks-charts/tree/master/stable/aws-efa-k8s-device-plugin) installed +- [Kubeflow Training Operator](https://github.com/kubeflow/training-operator) with PyTorchJob CRD +- FSx for Lustre PersistentVolumeClaim (`fsx-claim`) bound and accessible +- Docker installed locally for building the container image +- AWS CLI configured with access to ECR + +## 2. Build the Container Image + +```bash +export AWS_REGION=$(aws ec2 describe-availability-zones --output text --query 'AvailabilityZones[0].[RegionName]') +export ACCOUNT=$(aws sts get-caller-identity --query Account --output text) +export REGISTRY=${ACCOUNT}.dkr.ecr.${AWS_REGION}.amazonaws.com/ +export IMAGE_TAG=latest + +docker build -f aws-megatron-bridge.Dockerfile -t ${REGISTRY}megatron-bridge-qwen3:${IMAGE_TAG} . +``` + +## 3. Push Container Image to Amazon ECR + +```bash +# Create the ECR repository (if it does not already exist) +REGISTRY_COUNT=$(aws ecr describe-repositories | grep \"megatron-bridge-qwen3\" | wc -l) +if [ "$REGISTRY_COUNT" == "0" ]; then + aws ecr create-repository --repository-name megatron-bridge-qwen3 +fi + +# Authenticate to ECR +echo "Logging in to $REGISTRY ..." +aws ecr get-login-password --region ${AWS_REGION} | docker login --username AWS --password-stdin $REGISTRY + +# Push the image +docker image push ${REGISTRY}megatron-bridge-qwen3:${IMAGE_TAG} +``` + +After pushing, the image URI will be: + +``` +${ACCOUNT}.dkr.ecr.${AWS_REGION}.amazonaws.com/megatron-bridge-qwen3:latest +``` + +Use this as the `REPO_URI` environment variable in the steps below. + +## 4. Download Model Weights + +Before training, download the Qwen 3 model weights to FSx for Lustre. This is a +one-time operation. + +```bash +# Set variables +export REPO_URI=${REGISTRY}megatron-bridge-qwen3:${IMAGE_TAG} +export HF_MODEL=Qwen/Qwen3-8B # Choose: Qwen3-0.6B, Qwen3-1.7B, Qwen3-4B, Qwen3-8B, Qwen3-14B, Qwen3-32B +export MODEL_SIZE=8b # Match the model: 0.6b, 1.7b, 4b, 8b, 14b, 32b + +# Create a Kubernetes Secret for HuggingFace token (one-time setup) +kubectl create secret generic hf-token --from-literal=token= + +# Generate and apply the download job +envsubst '$REPO_URI $HF_MODEL $MODEL_SIZE' < kubernetes/qwen3/manifests/download-model-job.yaml-template | kubectl apply -f - + +# Monitor the download +kubectl logs -f job/download-qwen3-model + +# Verify completion +kubectl get job download-qwen3-model +``` + +## 5. Distributed Training + +### 5.1 Configure Training Parameters + +```bash +# Container image +export REPO_URI=${REGISTRY}megatron-bridge-qwen3:${IMAGE_TAG} + +# Cluster topology +export NUM_NODES=2 # Number of nodes +export GPU_PER_NODE=8 # GPUs per node (8 for p5.48xlarge) +export EFA_PER_NODE=32 # EFA adapters per node (32 for p5.48xlarge) +export FI_PROVIDER=efa # Libfabric provider + +# Model configuration +export MODEL_SIZE=8b # Qwen3 model size +export TENSOR_PARALLEL=4 # Tensor parallelism degree +export PIPELINE_PARALLEL=1 # Pipeline parallelism degree + +# Training hyperparameters +export SEQ_LENGTH=4096 # Sequence length +export GLOBAL_BATCH_SIZE=16 # Global batch size +export MICRO_BATCH_SIZE=1 # Micro batch size per GPU +export TRAIN_ITERS=100 # Number of training iterations +``` + +### 5.2 Launch Training + +```bash +# Generate the PyTorchJob manifest and apply +envsubst '$REPO_URI $NUM_NODES $GPU_PER_NODE $EFA_PER_NODE $FI_PROVIDER $MODEL_SIZE $TRAIN_ITERS $SEQ_LENGTH $GLOBAL_BATCH_SIZE $MICRO_BATCH_SIZE $TENSOR_PARALLEL $PIPELINE_PARALLEL' \ + < kubernetes/qwen3/manifests/pytorchjob.yaml-template | kubectl apply -f - + +# Monitor training logs (wait for pods to start) +kubectl logs -f megatron-bridge-qwen3-worker-0 + +# Check job status +kubectl get pytorchjob megatron-bridge-qwen3 +``` + +### 5.3 Clean Up + +```bash +kubectl delete pytorchjob megatron-bridge-qwen3 +kubectl delete deployment etcd +kubectl delete service etcd +kubectl delete job download-qwen3-model +``` + +## 6. Model Sizes and Recommended Parallelism + +The table below provides recommended parallelism settings for each Qwen 3 model size +on 2x p5.48xlarge (16 GPUs total): + +| Model | Parameters | TP | PP | Nodes (p5.48xlarge) | Notes | +|-------|-----------|----|----|---------------------|-------| +| Qwen3-0.6B | 0.6B | 1 | 1 | 1 | Fits on single GPU | +| Qwen3-1.7B | 1.7B | 1 | 1 | 1 | Fits on single GPU | +| Qwen3-4B | 4B | 2 | 1 | 1 | 2-way tensor parallel | +| Qwen3-8B | 8B | 4 | 1 | 1 | 4-way tensor parallel | +| Qwen3-14B | 14B | 8 | 1 | 1 | Full node tensor parallel | +| Qwen3-32B | 32B | 8 | 2 | 2 | TP + PP with activation recompute | + +## 7. Validated Training Output + +The following sections capture actual log output from running this sample end-to-end +on a SageMaker HyperPod EKS cluster with 2x `ml.p5.48xlarge` (16x H100 80GB). + +### 7.1 Cluster Topology + +``` +$ kubectl get nodes -o wide +NAME STATUS ROLES AGE VERSION INTERNAL-IP OS-IMAGE +hyperpod-i-008789534bbb4c33f Ready 12d v1.33.5-eks-ecaa3a6 10.1.205.104 Amazon Linux 2023 +hyperpod-i-0a8b955807d0df904 Ready 12d v1.33.5-eks-ecaa3a6 10.1.232.216 Amazon Linux 2023 + +GPUs and EFA per node: + hyperpod-i-008789534bbb4c33f: GPUs=8, EFA=32 + hyperpod-i-0a8b955807d0df904: GPUs=8, EFA=32 +``` + +### 7.2 Training Configuration + +| Parameter | Value | +|-----------|-------| +| Model | Qwen3-0.6B | +| Nodes | 2x ml.p5.48xlarge | +| GPUs | 16 (8 per node) | +| EFA adapters | 32 per node | +| Parallelism | TP=1, PP=1 (data parallel) | +| Sequence length | 2048 | +| Global batch size | 16 | +| Micro batch size | 1 | +| Training iterations | 10 | + +### 7.3 Rendezvous and Launch + +``` +Starting elastic_operator with launch configs: + entrypoint : /workspace/pretrain_qwen3.py + min_nodes : 2 + max_nodes : 2 + nproc_per_node : 8 + rdzv_backend : etcd + rdzv_endpoint : etcd:2379 + max_restarts : 100 + +Rendezvous complete for workers. Result: + master_addr=megatron-bridge-qwen3-worker-1 + master_port=53801 + group_world_size=2 + global_ranks=[0, 1, 2, 3, 4, 5, 6, 7] (per node) + global_world_sizes=[16, 16, 16, 16, 16, 16, 16, 16] +``` + +### 7.4 EFA Verification + +The NCCL logs confirm that AWS EFA is active with RDMA transport and all 32 NICs +detected per node: + +``` +NCCL INFO NET/OFI Initializing aws-ofi-nccl 1.18.0 +NCCL INFO NET/OFI Using Libfabric version 2.4 +NCCL INFO NET/OFI Using transport protocol RDMA (platform set) +NCCL INFO NET/OFI Selected provider is efa, fabric is efa-direct (found 32 nics) +NCCL INFO NET/OFI Configuring AWS-specific options +NCCL INFO NET/OFI Internode latency set at 75.0 us + +NCCL INFO TUNER/Plugin: Using nccl_ofi_tuner (v3) +NCCL INFO Successfully loaded external tuner plugin /opt/amazon/ofi-nccl/lib/libnccl-net.so +NCCL INFO NET/OFI Region base Tuner is chosen for platform: p5.48xlarge +``` + +### 7.5 Model Initialization + +``` +[Megatron-Bridge] Qwen3-0.6B pretraining +[Megatron-Bridge] World size: 16, TP=1, PP=1 +[Megatron-Bridge] Model path: /fsx/qwen3/0.6b +[Megatron-Bridge] Seq=2048, GBS=16, Iters=10 +[Megatron-Bridge] Creating bridge and loading model... +[Megatron-Bridge] Loading weights from /fsx/qwen3/0.6b +[Megatron-Bridge] Model built: 0.60B params, 0.60B trainable +``` + +### 7.6 Training Results + +``` +[Megatron-Bridge] Starting training for 10 iterations... + step 1/10 | loss: 0.6869 | time: 1.89s | tokens/s: 17,351 + step 2/10 | loss: -4.3676 | time: 0.14s | tokens/s: 241,682 + step 3/10 | loss: -8.8785 | time: 0.11s | tokens/s: 306,938 + step 4/10 | loss: -14.0235 | time: 0.10s | tokens/s: 328,907 + step 5/10 | loss: -18.2283 | time: 0.10s | tokens/s: 329,265 + step 6/10 | loss: -19.9766 | time: 0.10s | tokens/s: 330,300 + step 7/10 | loss: -23.3931 | time: 0.10s | tokens/s: 330,264 + step 8/10 | loss: -26.0851 | time: 0.10s | tokens/s: 329,648 + step 9/10 | loss: -29.8682 | time: 0.11s | tokens/s: 306,454 + step 10/10 | loss: -32.5855 | time: 0.10s | tokens/s: 326,035 + +[Megatron-Bridge] Training complete! +``` + +Steady-state throughput: **~330K tokens/s** across 16x H100 GPUs (2 nodes) with +Qwen3-0.6B and TP=1, PP=1. + +### 7.7 Job Completion + +``` +$ kubectl get pytorchjob megatron-bridge-qwen3 +NAME STATE AGE +megatron-bridge-qwen3 Succeeded 8m + +PyTorchJob default/megatron-bridge-qwen3 successfully completed. + Worker: 2/2 succeeded + Start: 2026-04-27T10:11:38Z + End: 2026-04-27T10:19:21Z +``` + +## 8. Appendix + +### 8.1 Benchmark Mode + +For pure throughput benchmarking, the training script uses mock data by default +(no `--hf-model-path` required for mock). To benchmark without downloading weights: + +```bash +# Remove the --hf-model-path argument from the PyTorchJob template +# The script will use mock data and random weights +``` + +### 8.2 Using Real Data + +To use real training data, preprocess it with Megatron-LM's data preprocessing +tools and set the `--data-path` argument. Alternatively, leverage Megatron-Bridge's +built-in dataset blend configuration. diff --git a/3.test_cases/megatron/megatron-bridge/aws-megatron-bridge.Dockerfile b/3.test_cases/megatron/megatron-bridge/aws-megatron-bridge.Dockerfile new file mode 100644 index 000000000..36d14f35e --- /dev/null +++ b/3.test_cases/megatron/megatron-bridge/aws-megatron-bridge.Dockerfile @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +# Megatron-Bridge (https://pypi.org/project/megatron-bridge/) Qwen 3 training sample. +# +# This Dockerfile uses the NeMo Framework container which bundles Megatron-Bridge, +# Megatron-Core, and all required dependencies out of the box. We then upgrade Megatron-Bridge to the latest version (0.4.0 as of this writing) to get the latest features and fixes, including support for Qwen 3 training. +# +# See: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +FROM nvcr.io/nvidia/nemo:25.07 + +ARG GDRCOPY_VERSION=v2.5.1 +ARG EFA_INSTALLER_VERSION=1.47.0 + +ARG OPEN_MPI_PATH=/opt/amazon/openmpi + +###################### +# Remove CUDA compat libs to prevent conflict with host driver. +# The host driver (injected via nvidia-container-runtime) provides libcuda.so. +# The bundled compat libs are for older host drivers and cause error 803 when +# the host driver is already newer than the compat version. +###################### +RUN rm -rf /usr/local/cuda/compat/lib.real /usr/local/cuda/compat/libcuda* \ + /usr/local/cuda/compat/libcudadebugger* /usr/local/cuda/compat/libnvidia-nvvm* \ + /usr/local/cuda/compat/libnvidia-ptxjitcompiler* + +###################### +# Update and remove the IB libverbs +###################### +RUN apt-get update -y && apt-get upgrade -y +RUN apt-get remove -y --allow-change-held-packages \ + ibverbs-utils \ + libibverbs-dev \ + libibverbs1 \ + libmlx5-1 + +RUN rm -rf /opt/hpcx/ompi \ + && rm -rf /opt/hpcx/nccl_rdma_sharp_plugin \ + && rm -rf /opt/hpcx/ncclnet_plugin \ + && rm -rf /usr/local/mpi \ + && rm -rf /usr/local/ucx \ + && ldconfig + +RUN DEBIAN_FRONTEND=noninteractive apt install -y --allow-unauthenticated \ + apt-utils \ + autoconf \ + automake \ + build-essential \ + cmake \ + curl \ + gcc \ + gdb \ + git \ + kmod \ + libtool \ + openssh-client \ + openssh-server \ + vim \ + && apt remove -y python3-blinker \ + && apt autoremove -y + +RUN mkdir -p /var/run/sshd && \ + sed -i 's/[ #]\(.*StrictHostKeyChecking \).*/ \1no/g' /etc/ssh/ssh_config && \ + echo " UserKnownHostsFile /dev/null" >> /etc/ssh/ssh_config && \ + sed -i 's/#\(StrictModes \).*/\1no/g' /etc/ssh/sshd_config && \ + sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd + +RUN rm -rf /root/.ssh/ \ + && mkdir -p /root/.ssh/ \ + && ssh-keygen -q -t rsa -N '' -f /root/.ssh/id_rsa \ + && cp /root/.ssh/id_rsa.pub /root/.ssh/authorized_keys \ + && printf "Host *\n StrictHostKeyChecking no\n" >> /root/.ssh/config + +ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/openmpi/lib:/opt/nccl/build/lib:/opt/amazon/efa/lib:/opt/amazon/ofi-nccl/lib:/opt/amazon/ofi-nccl/lib/aarch64-linux-gnu:/opt/amazon/ofi-nccl/lib/x86_64-linux-gnu:/usr/local/lib:$LD_LIBRARY_PATH +ENV PATH=/opt/amazon/openmpi/bin/:/opt/amazon/efa/bin:/usr/bin:/usr/local/bin:$PATH + +################################################# +## Install NVIDIA GDRCopy +RUN git clone -b ${GDRCOPY_VERSION} https://github.com/NVIDIA/gdrcopy.git /tmp/gdrcopy \ + && cd /tmp/gdrcopy \ + && make prefix=/opt/gdrcopy install + +ENV LD_LIBRARY_PATH=/opt/gdrcopy/lib:$LD_LIBRARY_PATH +ENV LIBRARY_PATH=/opt/gdrcopy/lib:$LIBRARY_PATH +ENV CPATH=/opt/gdrcopy/include:$CPATH +ENV PATH=/opt/gdrcopy/bin:$PATH + +################################################# +## Install EFA installer +RUN cd $HOME \ + && curl -O https://efa-installer.amazonaws.com/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ + && tar -xf $HOME/aws-efa-installer-${EFA_INSTALLER_VERSION}.tar.gz \ + && cd aws-efa-installer \ + && ./efa_installer.sh -y -g -d --skip-kmod --skip-limit-conf --no-verify \ + && rm -rf $HOME/aws-efa-installer + +RUN echo "Verifying AWS OFI NCCL plugin installation..." && \ + (ls -la /opt/amazon/ofi-nccl/lib/libnccl-net*.so || \ + ls -la /opt/amazon/ofi-nccl/lib/x86_64-linux-gnu/libnccl-ofi*.so || \ + ls -la /opt/amazon/ofi-nccl/lib/aarch64-linux-gnu/libnccl-ofi*.so) + +################################################### +RUN rm -rf /var/lib/apt/lists/* + +RUN echo "hwloc_base_binding_policy = none" >> /opt/amazon/openmpi/etc/openmpi-mca-params.conf \ + && echo "rmaps_base_mapping_policy = slot" >> /opt/amazon/openmpi/etc/openmpi-mca-params.conf + +RUN pip3 install --no-cache-dir "awscli>=1.44,<2.0" "pynvml>=12.0,<13.0" "wandb>=0.26,<1.0" + +RUN mv $OPEN_MPI_PATH/bin/mpirun $OPEN_MPI_PATH/bin/mpirun.real \ + && echo '#!/bin/bash' > $OPEN_MPI_PATH/bin/mpirun \ + && echo '/opt/amazon/openmpi/bin/mpirun.real "$@"' >> $OPEN_MPI_PATH/bin/mpirun \ + && chmod a+x $OPEN_MPI_PATH/bin/mpirun + +###################### +# Additional dependencies for the training script +# (Megatron-Bridge, Megatron-Core, and transformers are already in the base image) +###################### +RUN pip install --no-cache-dir \ + "sentencepiece>=0.2,<1.0" "python-etcd>=0.4,<1.0" + +###################### +# Upgrade Megatron-Bridge to 0.4.0 for Qwen3 support +# The NeMo 25.07 base ships megatron-bridge 0.2.0rc0 and MCore 0.13.1. +# megatron-bridge 0.4.0 requires MCore >=0.18 and transformers >=4.57, so we +# upgrade the full stack. modelopt is also upgraded to fix Conv1D compat. +###################### +RUN pip uninstall -y megatron-bridge megatron-core \ + && rm -rf /opt/Megatron-Bridge /opt/megatron-lm \ + && pip install --no-cache-dir --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@main \ + && pip install --no-cache-dir "transformers>=4.57,<5.0" \ + && pip install --no-cache-dir --no-deps "nvidia-modelopt>=0.33" \ + && pip install --no-cache-dir --no-deps "megatron-bridge>=0.4.0" + +###################### +# Copy training script +###################### +COPY kubernetes/qwen3/pretrain_qwen3.py /workspace/pretrain_qwen3.py + +## Set Open MPI variables to exclude network interface and conduit. +ENV OMPI_MCA_pml=^ucx \ + OMPI_MCA_btl=tcp,self \ + OMPI_MCA_btl_tcp_if_exclude=lo,docker0,veth_def_agent \ + OPAL_PREFIX=/opt/amazon/openmpi \ + NCCL_SOCKET_IFNAME=^docker,lo,veth_def_agent + +## Turn off PMIx Error https://github.com/open-mpi/ompi/issues/7516 +ENV PMIX_MCA_gds=hash + +WORKDIR /workspace diff --git a/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/README.md b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/README.md new file mode 100644 index 000000000..3a1ce6b30 --- /dev/null +++ b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/README.md @@ -0,0 +1,18 @@ + + + +# Qwen 3 Pretraining with Megatron-Bridge on Kubernetes + +This directory contains the training script and Kubernetes manifests for pretraining +Qwen 3 models with Megatron-Bridge. + +See the [main README](../../README.md) for the complete walkthrough including +container build, ECR push, model download, training launch, and validated results. + +## Contents + +| File | Description | +|------|-------------| +| `pretrain_qwen3.py` | Training script using the Megatron-Bridge AutoBridge API | +| `manifests/pytorchjob.yaml-template` | PyTorchJob manifest template for distributed training | +| `manifests/download-model-job.yaml-template` | Job manifest template for downloading HF model weights to FSx | diff --git a/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/.gitignore b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/.gitignore new file mode 100644 index 000000000..0eba1d3b3 --- /dev/null +++ b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/.gitignore @@ -0,0 +1,6 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +# ignore generated manifests +!*.yaml-template +*.yaml diff --git a/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/download-model-job.yaml-template b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/download-model-job.yaml-template new file mode 100644 index 000000000..9747bc668 --- /dev/null +++ b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/download-model-job.yaml-template @@ -0,0 +1,51 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +apiVersion: batch/v1 +kind: Job +metadata: + name: download-qwen3-model +spec: + template: + spec: + restartPolicy: Never + containers: + - name: download-model + image: ${REPO_URI} + command: ["/bin/bash", "-c"] + args: + - | + set -ex + pip install huggingface_hub + mkdir -p /fsx/qwen3 + python3 -c " + from huggingface_hub import snapshot_download + import os + snapshot_download( + repo_id='${HF_MODEL}', + local_dir='/fsx/qwen3/${MODEL_SIZE}', + token=os.environ.get('HF_TOKEN', None), + ) + print('Download complete.') + " + ls -alh /fsx/qwen3/${MODEL_SIZE}/ + env: + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token + key: token + volumeMounts: + - name: fsx-pv + mountPath: /fsx + resources: + requests: + cpu: "4" + memory: "32Gi" + limits: + cpu: "8" + memory: "64Gi" + volumes: + - name: fsx-pv + persistentVolumeClaim: + claimName: fsx-claim diff --git a/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/pytorchjob.yaml-template b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/pytorchjob.yaml-template new file mode 100644 index 000000000..3aff5414f --- /dev/null +++ b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/manifests/pytorchjob.yaml-template @@ -0,0 +1,165 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +apiVersion: v1 +kind: Service +metadata: + name: etcd +spec: + ports: + - name: etcd-client-port + port: 2379 + protocol: TCP + targetPort: 2379 + selector: + app: etcd + +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + labels: + app: etcd + name: etcd +spec: + replicas: 1 + selector: + matchLabels: + app: etcd + template: + metadata: + labels: + app: etcd + spec: + containers: + - name: etcd + command: ["/usr/local/bin/etcd"] + args: + - "--data-dir" + - "/var/lib/etcd" + - "--enable-v2" + - "--listen-client-urls" + - "http://0.0.0.0:2379" + - "--advertise-client-urls" + - "http://0.0.0.0:2379" + - "--initial-cluster-state" + - "new" + image: registry.k8s.io/etcd:3.4.13-0 + ports: + - containerPort: 2379 + name: client + protocol: TCP + - containerPort: 2380 + name: server + protocol: TCP + restartPolicy: Always +--- +apiVersion: "kubeflow.org/v1" +kind: PyTorchJob +metadata: + name: megatron-bridge-qwen3 +spec: + elasticPolicy: + rdzvBackend: etcd + rdzvHost: etcd + rdzvPort: 2379 + minReplicas: 1 + maxReplicas: 64 + maxRestarts: 100 + metrics: + - type: Resource + resource: + name: cpu + target: + type: Utilization + averageUtilization: 90 + pytorchReplicaSpecs: + Worker: + replicas: ${NUM_NODES} + restartPolicy: OnFailure + template: + metadata: + labels: + app: megatron-bridge-qwen3 + spec: + volumes: + - name: shmem + emptyDir: + medium: Memory + sizeLimit: 128Gi + - name: local + hostPath: + path: /mnt/k8s-disks/0 + - name: fsx-pv + persistentVolumeClaim: + claimName: fsx-claim + containers: + - name: pytorch + image: ${REPO_URI} + imagePullPolicy: Always + resources: + requests: + nvidia.com/gpu: ${GPU_PER_NODE} + vpc.amazonaws.com/efa: ${EFA_PER_NODE} + limits: + nvidia.com/gpu: ${GPU_PER_NODE} + vpc.amazonaws.com/efa: ${EFA_PER_NODE} + env: + - name: LOGLEVEL + value: "DEBUG" + - name: FI_PROVIDER + value: ${FI_PROVIDER} + - name: FI_EFA_USE_DEVICE_RDMA + value: "1" + - name: FI_EFA_FORK_SAFE + value: "1" + - name: FI_LOG_LEVEL + value: "1" + - name: FI_EFA_ENABLE_SHM_TRANSFER + value: "1" + - name: TORCH_DISTRIBUTED_DEBUG + value: "DETAIL" + - name: TORCH_NCCL_ENABLE_MONITORING + value: "1" + - name: TORCH_NCCL_TRACE_BUFFER_SIZE + value: "20000" + - name: TORCH_NCCL_DUMP_ON_TIMEOUT + value: "1" + - name: TORCH_NCCL_DEBUG_INFO_TEMP_FILE + value: "/local/nccl_trace_rank_" + - name: PYTORCH_CUDA_ALLOC_CONF + value: "expandable_segments:True" + - name: NCCL_DEBUG + value: "INFO" + - name: NCCL_SOCKET_IFNAME + value: "^lo" + - name: TORCH_NCCL_ASYNC_ERROR_HANDLING + value: "1" + - name: CUDA_DEVICE_MAX_CONNECTIONS + value: "1" + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token + key: token + command: + - "/usr/local/bin/torchrun" + - "--nproc_per_node=${GPU_PER_NODE}" + - "--nnodes=${NUM_NODES}" + args: + - /workspace/pretrain_qwen3.py + - --model-size=${MODEL_SIZE} + - --hf-model-path=/fsx/qwen3/${MODEL_SIZE} + - --train-iters=${TRAIN_ITERS} + - --seq-length=${SEQ_LENGTH} + - --global-batch-size=${GLOBAL_BATCH_SIZE} + - --micro-batch-size=${MICRO_BATCH_SIZE} + - --tp=${TENSOR_PARALLEL} + - --pp=${PIPELINE_PARALLEL} + volumeMounts: + - name: shmem + mountPath: /dev/shm + - name: local + mountPath: /local + - name: fsx-pv + mountPath: /fsx diff --git a/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/pretrain_qwen3.py b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/pretrain_qwen3.py new file mode 100644 index 000000000..9c8722a54 --- /dev/null +++ b/3.test_cases/megatron/megatron-bridge/kubernetes/qwen3/pretrain_qwen3.py @@ -0,0 +1,221 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: MIT-0 + +""" +Qwen 3 pretraining with Megatron-Bridge on AWS. + +This script uses the official Megatron-Bridge AutoBridge API +(https://pypi.org/project/megatron-bridge/) to bridge Hugging Face Qwen 3 +models into Megatron-Core format for efficient distributed training with +tensor parallelism and pipeline parallelism. + +See: https://github.com/NVIDIA-NeMo/Megatron-Bridge + +Supported model sizes: 0.6B, 1.7B, 4B, 8B, 14B, 32B. + +Usage (launched by torchrun via PyTorchJob): + torchrun --nproc_per_node=8 --nnodes=2 pretrain_qwen3.py \ + --model-size 8b \ + --hf-model-path /fsx/qwen3/8b \ + --train-iters 10 +""" + +import argparse +import os +import time + +import torch +import torch.distributed as dist + +from megatron.bridge import AutoBridge + + +# Qwen 3 model configurations with recommended parallelism for H100 80GB +MODEL_CONFIGS = { + "0.6b": {"hf_model": "Qwen/Qwen3-0.6B", "tp": 1, "pp": 1}, + "1.7b": {"hf_model": "Qwen/Qwen3-1.7B", "tp": 1, "pp": 1}, + "4b": {"hf_model": "Qwen/Qwen3-4B", "tp": 2, "pp": 1}, + "8b": {"hf_model": "Qwen/Qwen3-8B", "tp": 4, "pp": 1}, + "14b": {"hf_model": "Qwen/Qwen3-14B", "tp": 8, "pp": 1}, + "32b": {"hf_model": "Qwen/Qwen3-32B", "tp": 8, "pp": 2}, +} + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Qwen 3 pretraining with Megatron-Bridge + Megatron-Core" + ) + parser.add_argument( + "--model-size", + type=str, + default="8b", + choices=list(MODEL_CONFIGS.keys()), + help="Qwen 3 model size (default: 8b)", + ) + parser.add_argument( + "--hf-model-path", + type=str, + default=None, + help="Path to local HF model on shared storage (e.g., /fsx/qwen3/8b).", + ) + parser.add_argument( + "--train-iters", type=int, default=10, help="Training iterations" + ) + parser.add_argument("--seq-length", type=int, default=4096, help="Sequence length") + parser.add_argument( + "--global-batch-size", type=int, default=16, help="Global batch size" + ) + parser.add_argument( + "--micro-batch-size", type=int, default=1, help="Micro batch size per GPU" + ) + parser.add_argument("--lr", type=float, default=6.0e-5, help="Learning rate") + parser.add_argument( + "--tp", type=int, default=None, help="Tensor parallel size override" + ) + parser.add_argument( + "--pp", type=int, default=None, help="Pipeline parallel size override" + ) + return parser.parse_args() + + +def main(): + args = parse_args() + model_cfg = MODEL_CONFIGS[args.model_size] + + hf_model_path = args.hf_model_path or model_cfg["hf_model"] + tp = args.tp or model_cfg["tp"] + pp = args.pp or model_cfg["pp"] + + # torchrun sets LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT + local_rank = int(os.environ.get("LOCAL_RANK", 0)) + rank = int(os.environ.get("RANK", 0)) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + + # Set the GPU for this process + torch.cuda.set_device(local_rank) + + # Initialize the process group (torchrun provides the env vars) + dist.init_process_group(backend="nccl") + + if rank == 0: + print(f"[Megatron-Bridge] Qwen3-{args.model_size.upper()} pretraining") + print(f"[Megatron-Bridge] World size: {world_size}, TP={tp}, PP={pp}") + print(f"[Megatron-Bridge] Model path: {hf_model_path}") + print( + f"[Megatron-Bridge] Seq={args.seq_length}, GBS={args.global_batch_size}, Iters={args.train_iters}" + ) + + # Initialize Megatron-Core parallel state and build model via AutoBridge + if rank == 0: + print("[Megatron-Bridge] Creating bridge and loading model...") + + bridge = AutoBridge.from_hf_pretrained(hf_model_path, trust_remote_code=True) + + # Configure parallelism through the Megatron provider + provider = bridge.to_megatron_provider() + provider.tensor_model_parallel_size = tp + provider.pipeline_model_parallel_size = pp + provider.finalize() + + # Load with weights if model path is a local dir with weights + has_weights = os.path.isdir(hf_model_path) and any( + f.endswith(".safetensors") for f in os.listdir(hf_model_path) + ) + + if has_weights: + if rank == 0: + print(f"[Megatron-Bridge] Loading weights from {hf_model_path}") + model = provider.provide_distributed_model(wrap_with_ddp=False) + bridge.load_hf_weights(model, hf_model_path) + else: + if rank == 0: + print("[Megatron-Bridge] Using random weights (mock/benchmark mode)") + model = provider.provide_distributed_model(wrap_with_ddp=False) + + # Megatron-Core parameters use main_grad for gradient accumulation in their + # custom autograd functions (forward ctx saves weight.main_grad). Allocate + # main_grad buffers on ALL parameters before the first forward pass. + for m in model: + for p in m.parameters(): + if not hasattr(p, "main_grad") or p.main_grad is None: + p.main_grad = torch.zeros_like(p.data) + + if rank == 0: + total_params = sum(p.numel() for m in model for p in m.parameters()) + trainable_params = sum( + p.numel() for m in model for p in m.parameters() if p.requires_grad + ) + print( + f"[Megatron-Bridge] Model built: {total_params / 1e9:.2f}B params, {trainable_params / 1e9:.2f}B trainable" + ) + + # Optimizer + params = [p for m in model for p in m.parameters() if p.requires_grad] + + optimizer = torch.optim.AdamW( + params, lr=args.lr, weight_decay=0.1, betas=(0.9, 0.95) + ) + + # Simple training loop with mock data + vocab_size = 151936 # Qwen3 vocab + seq_len = args.seq_length + + if rank == 0: + print( + f"\n[Megatron-Bridge] Starting training for {args.train_iters} iterations..." + ) + + for step in range(1, args.train_iters + 1): + t0 = time.time() + + # Generate mock batch on the correct device + input_ids = torch.randint( + 0, vocab_size, (args.micro_batch_size, seq_len), device=f"cuda:{local_rank}" + ) + position_ids = ( + torch.arange(seq_len, device=f"cuda:{local_rank}") + .unsqueeze(0) + .expand(args.micro_batch_size, -1) + ) + attention_mask = torch.ones_like(input_ids) + + # Forward pass (model is a list of pipeline-parallel chunks) + output = model[0](input_ids, position_ids, attention_mask) + loss = output.float().mean() + + # Backward pass (gradients accumulate in param.main_grad) + loss.backward() + + # Copy main_grad -> grad for the standard PyTorch optimizer + for p in params: + if p.main_grad is not None: + if p.grad is None: + p.grad = p.main_grad.to(p.data.dtype) + else: + p.grad.copy_(p.main_grad) + + optimizer.step() + + # Zero both grad and main_grad + optimizer.zero_grad() + for m in model: + for p in m.parameters(): + if hasattr(p, "main_grad") and p.main_grad is not None: + p.main_grad.zero_() + + dt = time.time() - t0 + + if rank == 0: + tps = args.micro_batch_size * seq_len * world_size / dt + print( + f" step {step:>3}/{args.train_iters} | loss: {loss.item():.4f} | time: {dt:.2f}s | tokens/s: {tps:,.0f}" + ) + + if rank == 0: + print("\n[Megatron-Bridge] Training complete!") + + dist.destroy_process_group() + + +if __name__ == "__main__": + main()