diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml
new file mode 100644
index 000000000..59e177b74
--- /dev/null
+++ b/.github/workflows/rocm-ci.yml
@@ -0,0 +1,293 @@
+# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+
+name: TransformerEngine CI
+
+on:
+ push:
+ branches:
+ - 'dev'
+ - 'release_v1.*_rocm'
+ - 'release_v2.*_rocm'
+ pull_request:
+ branches:
+ - 'dev'
+ - 'release_v1.**_rocm'
+ - 'release_v2.**_rocm'
+ workflow_dispatch:
+ inputs:
+ test_level:
+ description: 'Test Level (1-3)'
+ required: true
+ default: '1'
+ skip_dev_merge:
+ description: 'Skip merging dev branch'
+ type: boolean
+ default: false
+
+concurrency:
+ group: ${{ github.workflow }}-${{ github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ build_and_test:
+ name: Build and Test on GPU
+ timeout-minutes: 720
+ runs-on: linux-mi325-4
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ submodules: 'recursive'
+
+ - name: Merge origin/dev
+ # Only run on PRs targeting dev, or manual runs where we didn't skip it
+ if: |
+ (github.event_name == 'pull_request' && github.base_ref == 'dev') ||
+ (github.event_name == 'workflow_dispatch' && inputs.skip_dev_merge != 'true' && github.ref == 'refs/heads/dev')
+ run: |
+ echo "Attempting to merge origin/dev..."
+ git config --global user.email "amd@amd.com"
+ git config --global user.name "AMD CI"
+
+ # Fetch dev specifically
+ git fetch origin dev
+
+ # Attempt merge; this will exit with error code 1 if there is a conflict, failing the job
+ git merge origin/dev
+ echo "Merge successful."
+
+ - name: Select Docker Image Tag
+ id: select-image
+ env:
+ DEV_IMAGE: ${{ vars.DEV_DOCKER_IMAGE }}
+ REL_IMAGE: ${{ vars.REL613_DOCKER_IMAGE }}
+ run: |
+ BRANCH_NAME="${{ github.base_ref || github.ref_name }}"
+ echo "Determining image for branch: $BRANCH_NAME"
+ DEV_DOCKER_IMAGE="$DEV_IMAGE"
+ REL613_DOCKER_IMAGE="$REL_IMAGE"
+ IMAGE_TO_USE="$DEV_DOCKER_IMAGE"
+ if [[ $BRANCH_NAME =~ ^release_v([0-9]+)\.([0-9]+)_rocm$ ]]; then
+ MAJOR_VERSION=${BASH_REMATCH[1]}
+ MINOR_VERSION=${BASH_REMATCH[2]}
+ if (( MAJOR_VERSION == 1 )); then
+ if (( MINOR_VERSION == 13 || MINOR_VERSION == 14 )); then IMAGE_TO_USE="$REL613_DOCKER_IMAGE"; fi
+ fi
+ fi
+ echo "Selected image: $IMAGE_TO_USE"
+ echo "image-tag=$IMAGE_TO_USE" >> $GITHUB_OUTPUT
+
+ - name: Pull Docker Image
+ run: |
+ docker pull ${{ steps.select-image.outputs.image-tag }}
+
+ - name: Run Container
+ run: |
+ docker run -dt \
+ --name te-runner \
+ --network=host \
+ --device=/dev/dri --device=/dev/kfd \
+ --shm-size=16G \
+ --pid=host \
+ --group-add $(getent group render | cut -d: -f3) \
+ --group-add $(getent group video | cut -d: -f3) \
+ -v "${{ github.workspace }}:/workspace" \
+ -w /workspace \
+ ${{ steps.select-image.outputs.image-tag}}
+
+ - name: Determine GPU Architecture via rocminfo
+ id: gpu-arch
+ run: |
+ # Run rocminfo inside the container and capture the output
+ ARCH=$(docker exec te-runner bash -c "rocminfo | grep -m 1 -oP 'gfx[0-9a-fA-F]+'")
+ if [ -z "$ARCH" ]; then
+ echo "::error::Could not determine GPU architecture using rocminfo inside the container."
+ # Optional: Print full rocminfo output for debugging
+ docker exec te-runner rocminfo
+ exit 1
+ fi
+ echo "Detected GPU Arch: $ARCH"
+ echo "arch=$ARCH" >> $GITHUB_OUTPUT
+
+ - name: Build Project
+ run: |
+ docker exec \
+ -e GPU_ARCH=${{ steps.gpu-arch.outputs.arch }} \
+ te-runner bash -c "$(cat <<'EOF'
+ set -ex
+
+ export HIP_PATH=""
+ export PYTORCH_ROCM_ARCH=$GPU_ARCH
+ export NVTE_ROCM_ARCH=$GPU_ARCH
+ export NVTE_AITER_PREBUILT_BASE_URL=https://compute-artifactory.amd.com:5000/artifactory/rocm-generic-local/te-ci/aiter-prebuilts
+ pip install ninja
+ pip install -v . 2>&1
+ EOF
+ )"
+
+ - name: Run sGPU tests
+ id: sgpu-tests
+ continue-on-error: true
+ run: |
+ docker exec te-runner bash -c "$(cat <<'EOF'
+ #!/usr/bin/bash
+ set -x -o pipefail
+ ulimit -c 0 # Disable core dumps
+
+ # debug output
+ ls -d /opt/rocm*
+ python --version
+ pip list | egrep "transformer_e|torch|jax|numpy|ml_dtypes|typing_ext"
+
+ HIP_VISIBLE_DEVICES=1 ci/pytorch.sh > /workspace/torch_sgpu.log 2>&1 &
+ torch_pid=$!; echo Pytorch test pid $!
+
+ HIP_VISIBLE_DEVICES=2 ci/jax.sh > /workspace/jax_sgpu.log 2>&1 &
+ jax_pid=$!; echo JAX test pid $!
+
+ HIP_VISIBLE_DEVICES=3 ci/core.sh > /workspace/core_sgpu.log 2>&1 &
+ core_pid=$!; echo Core test pid $!
+
+ wait $core_pid; core_rc=$?
+ wait $jax_pid; jax_rc=$?
+ wait $torch_pid; torch_rc=$?
+
+ # Check PyTorch
+ if [ $torch_rc -ne 0 ]; then
+ echo "::group::[FAILED] PyTorch sGPU Log"
+ cat /workspace/torch_sgpu.log
+ echo "::endgroup::"
+ echo "::error::Pytorch sGPU test FAILED."
+ fi
+
+ # Check JAX
+ if [ $jax_rc -ne 0 ]; then
+ echo "::group::[FAILED] JAX sGPU Log"
+ cat /workspace/jax_sgpu.log
+ echo "::endgroup::"
+ echo "::error::JAX sGPU test FAILED."
+ fi
+
+ # Check Core
+ if [ $core_rc -ne 0 ]; then
+ echo "::group::[FAILED] Core sGPU Log"
+ cat /workspace/core_sgpu.log
+ echo "::endgroup::"
+ echo "::error::Core sGPU test FAILED."
+ fi
+
+ test $torch_rc -eq 0 -a $jax_rc -eq 0 -a $core_rc -eq 0
+ EOF
+ )"
+
+ - name: Run mGPU tests
+ id: mgpu-tests
+ continue-on-error: true
+ run: |
+ docker exec te-runner bash -c "$(cat <<'EOF'
+ #!/usr/bin/bash
+ set -x -o pipefail
+ ulimit -c 0 # Disable core dumps
+
+ # Run PyTorch
+ ci/pytorch.sh > /workspace/torch_mgpu.log 2>&1
+ torch_rc=$?
+
+ # Run JAX
+ ci/jax.sh > /workspace/jax_mgpu.log 2>&1
+ jax_rc=$?
+
+ if [ $torch_rc -ne 0 ]; then
+ echo "::group::[FAILED] PyTorch mGPU Log"
+ cat /workspace/torch_mgpu.log
+ echo "::endgroup::"
+ echo "::error::Pytorch mGPU test FAILED."
+ fi
+
+ if [ $jax_rc -ne 0 ]; then
+ echo "::group::[FAILED] JAX mGPU Log"
+ cat /workspace/jax_mgpu.log
+ echo "::endgroup::"
+ echo "::error::JAX mGPU test FAILED."
+ fi
+
+ test $torch_rc -eq 0 -a $jax_rc -eq 0
+ EOF
+ )"
+
+ - name: Run Examples
+ id: examples-tests
+ continue-on-error: true
+ run: |
+ docker exec te-runner bash -c "$(cat <<'EOF'
+ #!/usr/bin/bash
+ set -ex -o pipefail
+ ulimit -c 0 # Disable core dumps
+
+ cd /workspace/examples/pytorch/mnist
+ python main.py 2>&1 | tee /workspace/examples.log
+ python main.py --use-te 2>&1 | tee -a /workspace/examples.log
+ python main.py --use-fp8 2>&1 | tee -a /workspace/examples.log
+
+ cd /workspace/examples/jax/mnist
+ pip3 install -r requirements.txt
+ python test_single_gpu_mnist.py 2>&1 | tee -a /workspace/examples.log
+ python test_single_gpu_mnist.py --use-te 2>&1 | tee -a /workspace/examples.log
+ python test_single_gpu_mnist.py --use-fp8 2>&1 | tee -a /workspace/examples.log
+
+ cd /workspace/examples/jax/encoder
+ pip3 install -r requirements.txt
+ python test_single_gpu_encoder.py 2>&1 | tee -a /workspace/examples.log
+ python test_single_gpu_encoder.py --use-fp8 2>&1 | tee -a /workspace/examples.log
+ EOF
+ )"
+
+ - name: Check Test Failure Status
+ if: always()
+ run: |
+ # Check outcomes of the specific test steps
+ # "outcome" will be 'failure' even if continue-on-error was true
+ if [[ "${{ steps.sgpu-tests.outcome }}" == "failure" ]]; then
+ echo "::error::sGPU Tests Failed."
+ EXIT_STATUS=1
+ fi
+
+ if [[ "${{ steps.mgpu-tests.outcome }}" == "failure" ]]; then
+ echo "::error::mGPU Tests Failed."
+ EXIT_STATUS=1
+ fi
+
+ if [[ "${{ steps.examples-tests.outcome }}" == "failure" ]]; then
+ echo "::error::Example Tests Failed."
+ EXIT_STATUS=1
+ fi
+
+ # Fail the job if any errors were detected
+ if [[ "$EXIT_STATUS" == "1" ]]; then
+ exit 1
+ fi
+
+ - name: Copy logs and reports from container
+ if: always()
+ run: |
+ docker cp te-runner:/workspace/torch_sgpu.log ./torch_sgpu.log || true
+ docker cp te-runner:/workspace/jax_sgpu.log ./jax_sgpu.log || true
+ docker cp te-runner:/workspace/core_sgpu.log ./core_sgpu.log || true
+ docker cp te-runner:/workspace/torch_mgpu.log ./torch_mgpu.log || true
+ docker cp te-runner:/workspace/jax_mgpu.log ./jax_mgpu.log || true
+
+ - name: Upload logs and test reports
+ if: always()
+ uses: actions/upload-artifact@v4
+ with:
+ name: logs-and-reports
+ path: |
+ *.log
+ if-no-files-found: ignore
+ retention-days: 5
+
+ - name: Cleanup container
+ if: always()
+ run: docker rm -f te-runner || true
diff --git a/3rdparty/aiter b/3rdparty/aiter
index 1b00a0e8a..7a41cca67 160000
--- a/3rdparty/aiter
+++ b/3rdparty/aiter
@@ -1 +1 @@
-Subproject commit 1b00a0e8a54be0411490a69a5d7042abd33a56d9
+Subproject commit 7a41cca67187bd5f77c337765a1a289337901cef
diff --git a/3rdparty/hipify_torch b/3rdparty/hipify_torch
index 12ac3f401..3456cd19d 160000
--- a/3rdparty/hipify_torch
+++ b/3rdparty/hipify_torch
@@ -1 +1 @@
-Subproject commit 12ac3f401261ffa331a4000626a333727f06a0d8
+Subproject commit 3456cd19d4eb5e469317bfcfae1a89b7ab70f6c2
diff --git a/README.rst b/README.rst
index 74e72efde..b7afcbd2a 100644
--- a/README.rst
+++ b/README.rst
@@ -264,15 +264,15 @@ Note that when using `THD` format tensors with CK Fused Attention, one should pa
to indicate that there is no padding between sequences. Otherwise, passing proper tensors will indicate padding between sequences. This is the case
for both the `FusedAttention` and `DotProductAttention` modules.
-FA v3 Kernels in CK Backend
+AITER FA v3 Kernels
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-ROCm TE provides experimental support for flash-attention v3 fwd/bwd kernels using the ck backend for limited fused attention configs.
-To enable FA v3 kernels, the following environment variables can be used:
+ROCm TE supports flash-attention v3 fwd/bwd kernels on gfx942 and gfx950 using AITER backend.
+This functionality can be controlled by the following environment variables:
-* NVTE_CK_USES_FWD_V3 - by default 0, if set to 1, some cases will call the fwd v3 kernel, only applicable to the gfx942 architecture;
-* NVTE_CK_USES_BWD_V3 - by default 0, if set to 1, some cases will call the bwd v3 dqdkdv kernel;
-* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when NVTE_CK_USES_BWD_V3 is set to 1;
-* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when bwd_v3 is set to 1, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture.
+* NVTE_CK_USES_FWD_V3 - by default 1, if set to 0, v3 kernels will not be used for fwd pass;
+* NVTE_CK_USES_BWD_V3 - by default 1, if set to 0, v3 kernels will not be used for bwd pass;
+* NVTE_CK_IS_V3_ATOMIC_FP32 - by default 1, if set to 0 will use atomic fp16/bf16(w/o convert_dq kernel) in bwd pass when v3 is enabled;
+* NVTE_CK_HOW_V3_BF16_CVT - by default 1, float to bf16 convert type when v3 is enabled, 0:RTNE; 1:RTNA; 2:RTZ, only applicable to the gfx942 architecture.
Float to BFloat16 Conversion in CK Backend (gfx942 only)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/benchmarks/attention/benchmark_attention_rocm.py b/benchmarks/attention/benchmark_attention_rocm.py
index fdae8555b..b126fb022 100644
--- a/benchmarks/attention/benchmark_attention_rocm.py
+++ b/benchmarks/attention/benchmark_attention_rocm.py
@@ -27,19 +27,22 @@
pd.set_option("display.precision", 4)
-# data type
+# -------------------- Benchmark Settings --------------------
+# Data type
dtype = torch.bfloat16
-# number of iterations after 3 warmup iterations
-num_iters = 3
-# checkpointing
+# Number of warmup iterations before profiling
+warmup_iters = 20
+# Number of iterations after warmup iterations
+num_iters = 10
+# Checkpointing attention
ckpt_attn = False
-# workspace optimization path for cuDNN attention
+# Workspace optimization for attention
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
-# padding between sequences for qkv_format=thd
+# Padding between sequences for qkv_format=thd
pad_between_seqs = False
-# training mode
+# Training mode
is_training = True
model_configs = {
@@ -48,37 +51,96 @@
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
+ "test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias")
}
-# Define DataFrame indices and columns
+# DataFrame indices and columns for results
indices = [model for model in model_configs.keys()]
columns = [
- "FusedAttention Module",
- "FusedAttention Kernels (fwd)",
- "FusedAttention Kernels (bwd)",
- "FusedAttention Kernels (fwd+bwd)",
- "FlashAttention Module",
- "FlashAttention Kernels (fwd)",
- "FlashAttention Kernels (bwd)",
- "FlashAttention Kernels (fwd+bwd)",
- "Fused vs Flash Kernels Speedup (fwd+bwd)",
- "FusedAttention CK Module",
- "FusedAttention CK Kernels (fwd)",
- "FusedAttention CK Kernels (bwd)",
- "FusedAttention CK Kernels (fwd+bwd)",
- "FusedAttention AOTriton Module",
- "FusedAttention AOTriton Kernels (fwd)",
- "FusedAttention AOTriton Kernels (bwd)",
- "FusedAttention AOTriton Kernels (fwd+bwd)",
- ]
-
-output_csv="times.csv"
+ "FusedAttention CK Module",
+ "FusedAttention CK Kernels (fwd)",
+ "FusedAttention CK Kernels (bwd)",
+ "FusedAttention CK Kernels (fwd+bwd)",
+ "FusedAttention CK TFLOPs (fwd)",
+ "FusedAttention CK TFLOPs (bwd)",
+
+ "FlashAttention Module",
+ "FlashAttention Kernels (fwd)",
+ "FlashAttention Kernels (bwd)",
+ "FlashAttention Kernels (fwd+bwd)",
+ "FlashAttention TFLOPs (fwd)",
+ "FlashAttention TFLOPs (bwd)",
+ "Fused vs Flash Kernels Speedup (fwd+bwd)",
+
+ "FusedAttention AOTriton Module",
+ "FusedAttention AOTriton Kernels (fwd)",
+ "FusedAttention AOTriton Kernels (bwd)",
+ "FusedAttention AOTriton Kernels (fwd+bwd)",
+ "FusedAttention AOTriton TFLOPs (fwd)",
+ "FusedAttention AOTriton TFLOPs (bwd)",
+]
+
+# Output CSV filename
+output_csv = "times.csv"
+# Output directory name
+output_dir_name = "profiler_outputs"
+# Current working directory
cwd = os.getcwd()
+
+# All attention backend environment variables
+ATTENTION_ENV_VARS = [
+ "NVTE_FUSED_ATTN",
+ "NVTE_FLASH_ATTN",
+ "NVTE_FUSED_ATTN_AOTRITON",
+ "NVTE_FUSED_ATTN_CK",
+ "NVTE_UNFUSED_ATTN",
+ "NVTE_CK_USES_BWD_V3",
+ "NVTE_CK_USES_FWD_V3",
+ "NVTE_CK_IS_V3_ATOMIC_FP32",
+]
+
+def cleanup_env():
+ """Set all attention-related environment variables to 0."""
+ for var in ATTENTION_ENV_VARS:
+ os.environ[var] = "0"
+
+def setup_backend_env(backend_name, use_ck_bwd_v3=True, use_ck_fwd_v3=True, use_ck_v3_a16=False):
+ cleanup_env()
+
+ if backend_name == "flash":
+ os.environ["NVTE_FLASH_ATTN"] = "1"
+ elif backend_name == "fused_ck":
+ os.environ["NVTE_FUSED_ATTN"] = "1"
+ os.environ["NVTE_FUSED_ATTN_CK"] = "1"
+ os.environ["NVTE_CK_USES_BWD_V3"] = "1" if use_ck_bwd_v3 else "0"
+ if use_ck_bwd_v3:
+ os.environ["NVTE_CK_IS_V3_ATOMIC_FP32"] = "0" if use_ck_v3_a16 else "1"
+ os.environ["NVTE_CK_USES_FWD_V3"] = "1" if use_ck_fwd_v3 else "0"
+ elif backend_name == "fused_aotriton":
+ os.environ["NVTE_FUSED_ATTN"] = "1"
+ os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1"
+
+# Kernel name patterns for identifying kernels in profiler output
+KERNEL_PATTERNS = {
+ # Flash Attention patterns
+ "flash_fwd": "FmhaFwd",
+ "flash_bwd": "FmhaBwd",
+
+ # CK patterns (v2 and v3)
+ "ck_fwd_v2": "ck_tile::FmhaFwdKernel",
+ "ck_bwd_v2": "ck_tile::FmhaBwd",
+ "ck_fwd_v3": "aiter::fmha_fwd",
+ "ck_bwd_v3": "aiter::fmha_bwd",
+
+ # AOTriton patterns
+ "aotriton_fwd": "attn_fwd",
+ "aotriton_bwd": "bwd",
+}
+
# Runs benchmark with warmup iterations and profiles using rocprof
def benchmark_dot_product_attention(model, attention, column_name, dirname):
config = model_configs[model]
- warmup_iters = 3
for i in range(warmup_iters):
attn_fwd, attn_bwd = _run_dot_product_attention(
dtype,
@@ -90,27 +152,30 @@ def benchmark_dot_product_attention(model, attention, column_name, dirname):
pad_between_seqs,
is_training,
)
- os.makedirs(dirname)
- before_files = set(os.listdir("."))
+ os.makedirs(dirname, exist_ok=True)
+ before_files = set(os.listdir(cwd))
# Profiling command using rocprof
+ benchmark_dir = os.path.dirname(os.path.abspath(__file__))
prof_cmd = [
- "env | grep NVTE; "
"rocprof",
"--hip-trace",
"--basenames off",
"python",
"-c",
- f""" "import benchmark_attention_rocm;""",
+ f""" "import sys; sys.path.insert(0, '{benchmark_dir}'); import benchmark_attention_rocm;""",
f"""benchmark_attention_rocm.benchmark_dot_product_attention_profiler("""
f"""'{model}', '{attention}', '{column_name}')" """,
]
prof_cmd = " ".join(prof_cmd)
subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
- after_files = set(os.listdir("."))
+ after_files = set(os.listdir(cwd))
new_files = after_files - before_files
for f in new_files:
- shutil.move(f, os.path.join(dirname, f))
+ src_path = os.path.join(cwd, f)
+ dst_path = os.path.join(dirname, f)
+ if os.path.isfile(src_path): # Only move files, not directories
+ shutil.move(src_path, dst_path)
torch.cuda.empty_cache()
# Runs profiler and records timing information
@@ -133,75 +198,96 @@ def benchmark_dot_product_attention_profiler(model, attention, column_name):
torch.cuda.synchronize()
attn_time = time.time() - attn_start
- df_times = pd.read_csv(output_csv, index_col=0)
+ output_csv_path = os.path.join(cwd, output_csv)
+ df_times = pd.read_csv(output_csv_path, index_col=0)
df_times.loc[model, column_name] = attn_time * 1e3 / num_iters
- df_times.to_csv(output_csv)
+ df_times.to_csv(output_csv_path)
torch.cuda.empty_cache()
+# Calculate TFLOPs for attention operations
+def calculate_attention_tflops(batch_size, seq_len, num_heads_q, head_dim_qk, fwd_time_ms, bwd_time_ms, is_causal):
+ # Calculate total fwdFLOPs
+ fwd_flops = (0.5 if is_causal else 1.0) * 4 * batch_size * seq_len * seq_len * num_heads_q * head_dim_qk / 1e12
+ # Calculate forward TFLOPs
+ fwd_tflops = fwd_flops / (fwd_time_ms / 1000.0)
+ # Calculate backward TFLOPs
+ bwd_tflops = (fwd_flops / (bwd_time_ms / 1000.0)) * 2.5
+ return fwd_tflops, bwd_tflops
+
# Helper function to extract timing results from profiler logs
def parse_helper(model, dirname, fwd_search_pattern, bwd_search_pattern, column_name, df_times):
- df = pd.read_csv(os.path.join(dirname,"results.stats.csv"))
+ df = pd.read_csv(os.path.join(dirname, "results.stats.csv"))
- # Extract kernel timing values
- fwd_values = df[df["Name"].str.contains(fwd_search_pattern)]["AverageNs"].to_numpy()
- bwd_values = df[df["Name"].str.contains(bwd_search_pattern)]["AverageNs"].to_numpy()
+ # Extract kernel timing values
+ fwd_values = df[df["Name"].str.contains(fwd_search_pattern, regex=False)]["AverageNs"].to_numpy()
+ bwd_values = df[df["Name"].str.contains(bwd_search_pattern, regex=False)]["AverageNs"].to_numpy()
- if len(bwd_values) == 0:
- return False # CK V3 not supported or kernel_func not found
+ if len(fwd_values) == 0 or len(bwd_values) == 0:
+ return False # Kernels not found
t_attn_avg = np.empty(len(fwd_values) + len(bwd_values))
t_attn_avg[:len(fwd_values)] = fwd_values
t_attn_avg[len(fwd_values):] = bwd_values
-
- # Store results in DataFrame
- df_times.loc[model, f"{column_name} Kernels (fwd)"] = t_attn_avg[:len(fwd_values)].sum() / 1e6
- df_times.loc[model, f"{column_name} Kernels (bwd)"] = t_attn_avg[len(fwd_values):].sum() / 1e6
- df_times.loc[model, f"{column_name} Kernels (fwd+bwd)"] = t_attn_avg.sum() / 1e6
-
+
+ # Store results in DataFrame (convert from ns to ms)
+ fwd_time_ms = t_attn_avg[:len(fwd_values)].sum() / 1e6
+ bwd_time_ms = t_attn_avg[len(fwd_values):].sum() / 1e6
+
+ df_times.loc[model, f"{column_name} Kernels (fwd)"] = fwd_time_ms
+ df_times.loc[model, f"{column_name} Kernels (bwd)"] = bwd_time_ms
+ df_times.loc[model, f"{column_name} Kernels (fwd+bwd)"] = fwd_time_ms + bwd_time_ms
+
+ # Calculate TFLOPs for both forward and backward
+ config = model_configs[model]
+ is_causal = "causal" in config.attn_mask_type.lower()
+ fwd_tflops, bwd_tflops = calculate_attention_tflops(
+ config.batch_size, config.max_seqlen_q, config.num_heads,
+ config.head_dim_qk, fwd_time_ms, bwd_time_ms, is_causal
+ )
+
+ df_times.loc[model, f"{column_name} TFLOPs (fwd)"] = fwd_tflops
+ df_times.loc[model, f"{column_name} TFLOPs (bwd)"] = bwd_tflops
+
return True
# Parses profiler logs for different attention backends
-def parse_results(model, df_times, perf_dir_flash_attn, perf_dir_fused_attn, perf_dir_fused_ck, perf_dir_fused_aotriton, use_ck_bwd_v3):
+def parse_results(model, df_times, perf_dir_flash_attn, perf_dir_fused_ck, perf_dir_fused_aotriton, use_ck_bwd_v3, use_ck_fwd_v3):
+ # Parse Flash Attention
if perf_dir_flash_attn:
- parse_helper(model, perf_dir_flash_attn, "FmhaFwd", "FmhaBwd", "FlashAttention", df_times)
-
- if perf_dir_fused_attn:
- ck_v3_success = False
- if use_ck_bwd_v3:
- ck_v3_success = parse_helper(model, perf_dir_fused_ck, "FmhaFwd", "kernel_func", "FusedAttention", df_times)
- if not ck_v3_success:
- parse_helper(model, perf_dir_fused_ck, "FmhaFwd", "FmhaBwd", "FusedAttention", df_times)
-
- if perf_dir_fused_attn:
- ck_v3_success = False
- if use_ck_bwd_v3:
- ck_v3_success = parse_helper(model, perf_dir_fused_ck, "FmhaFwd", "kernel_func", "FusedAttention CK", df_times)
- if not ck_v3_success:
- parse_helper(model, perf_dir_fused_ck, "FmhaFwd", "FmhaBwd", "FusedAttention CK", df_times)
+ parse_helper(model, perf_dir_flash_attn, KERNEL_PATTERNS["flash_fwd"], KERNEL_PATTERNS["flash_bwd"], "FlashAttention", df_times)
+
+ # Parse FusedAttention CK (use v3 or v2 patterns based on flags)
+ if perf_dir_fused_ck:
+ fwd_pattern = KERNEL_PATTERNS["ck_fwd_v3"] if use_ck_fwd_v3 else KERNEL_PATTERNS["ck_fwd_v2"]
+ bwd_pattern = KERNEL_PATTERNS["ck_bwd_v3"] if use_ck_bwd_v3 else KERNEL_PATTERNS["ck_bwd_v2"]
+ parse_helper(model, perf_dir_fused_ck, fwd_pattern, bwd_pattern, "FusedAttention CK", df_times)
+ # Parse AOTriton
if perf_dir_fused_aotriton:
- parse_helper(model, perf_dir_fused_aotriton, "attn_fwd", "bwd", "FusedAttention AOTriton", df_times)
-
- if perf_dir_flash_attn and perf_dir_fused_attn:
- df_times.loc[model, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
- df_times.loc[model, "FlashAttention Kernels (fwd+bwd)"]
- / df_times.loc[model, "FusedAttention Kernels (fwd+bwd)"]
- )
+ parse_helper(model, perf_dir_fused_aotriton, KERNEL_PATTERNS["aotriton_fwd"], KERNEL_PATTERNS["aotriton_bwd"], "FusedAttention AOTriton", df_times)
+
+ # Calculate speedup if both Flash and Fused CK results exist
+ if perf_dir_flash_attn and perf_dir_fused_ck:
+ flash_time = df_times.loc[model, "FlashAttention Kernels (fwd+bwd)"]
+ fused_time = df_times.loc[model, "FusedAttention CK Kernels (fwd+bwd)"]
+ if flash_time > 0 and fused_time > 0:
+ df_times.loc[model, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = flash_time / fused_time
-###############################################################################
# Post-benchmark sanity checks
-###############################################################################
def sanity_checks(
- profiler_root: str = "profiler_outputs",
- csv_path: str = "times.csv",
+ profiler_root: str = None,
+ csv_path: str = None,
tolerance_pct: float = 5.0,
):
"""
• Verifies that every model/backend that *should* have run produced
profiler_root/
/results.stats.csv
- • Checks FusedAttention vs FusedAttention-CK timing within ±tolerance_pct
• Non-zero exit code on any failure (CI friendly)
"""
+ if profiler_root is None:
+ profiler_root = output_dir_name
+ if csv_path is None:
+ csv_path = output_csv
print("\n============= Sanity-check results =============")
ok_overall = True
times_csv_path = os.path.join(cwd, csv_path)
@@ -212,7 +298,6 @@ def sanity_checks(
dir_pattern = {
"FlashAttention": "prof_flash_{model}",
- "FusedAttention": "prof_fused_{model}",
"FusedAttention CK": "prof_fused_ck_{model}",
"FusedAttention AOTriton": "prof_fused_aotriton_{model}",
}
@@ -231,7 +316,6 @@ def sanity_checks(
if flash_ok:
expected["FlashAttention"] = dir_pattern["FlashAttention"]
if fused_ok:
- expected["FusedAttention"] = dir_pattern["FusedAttention"]
if NVTE_Fused_Attn_Backend.NVTE_CK in fused_bes:
expected["FusedAttention CK"] = dir_pattern["FusedAttention CK"]
if NVTE_Fused_Attn_Backend.NVTE_AOTriton in fused_bes:
@@ -247,63 +331,40 @@ def sanity_checks(
ok_overall = False
raise FileNotFoundError(f"Error while profiling {model} [{be}], results.stats.csv not found")
- # Fused Vs Fused CK trace
- if "FusedAttention" in expected and "FusedAttention CK" in expected:
- f_fwd, f_bwd = df.loc[model, ["FusedAttention Kernels (fwd)",
- "FusedAttention Kernels (bwd)"]]
- c_fwd, c_bwd = df.loc[model, ["FusedAttention CK Kernels (fwd)",
- "FusedAttention CK Kernels (bwd)"]]
- if min(f_fwd, f_bwd, c_fwd, c_bwd) > 0:
- rel_fwd = abs(f_fwd - c_fwd) / max(f_fwd, c_fwd)
- rel_bwd = abs(f_bwd - c_bwd) / max(f_bwd, c_bwd)
- if rel_fwd < tol and rel_bwd < tol:
- print(f" [OK ] Fused vs CK diff <= {tolerance_pct}% "
- f"(fwd {rel_fwd*100:.2f} %, bwd {rel_bwd*100:.2f} %)")
- else:
- ok_overall = False
- raise AssertionError(f" Fused vs CK kernel time diff > {tolerance_pct}% "
- f"(fwd {rel_fwd*100:.2f} %, bwd {rel_bwd*100:.2f} %)")
print("-" * 60)
return ok_overall
def main(args):
- output_dir = "profiler_outputs/"
- run_dir = os.path.dirname(__file__)
-
- # Remove from current working directory
- if os.path.exists(os.path.join(cwd, output_dir)):
- shutil.rmtree(os.path.join(cwd, output_dir))
- if os.path.exists(os.path.join(cwd, output_csv)):
- os.remove(os.path.join(cwd, output_csv))
-
- # Remove from run directory
- if os.path.exists(os.path.join(run_dir, output_dir)):
- shutil.rmtree(os.path.join(run_dir, output_dir))
- if os.path.exists(os.path.join(run_dir, output_csv)):
- os.remove(os.path.join(run_dir, output_csv))
-
- os.chdir(run_dir)
+ output_dir = os.path.join(cwd, output_dir_name + "/")
+ output_csv_path = os.path.join(cwd, output_csv)
+
+ # Clean up old outputs in cwd
+ if os.path.exists(output_dir):
+ shutil.rmtree(output_dir)
+ if os.path.exists(output_csv_path):
+ os.remove(output_csv_path)
+
os.makedirs(output_dir)
df_times = pd.DataFrame(index=indices, columns=columns)
df_times = df_times.infer_objects(copy=False)
df_times.fillna(0.0, inplace=True)
df_times.index.name = "Model"
- df_times.to_csv(output_csv)
+ df_times.to_csv(output_csv_path)
device_id = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(device_id)
print(
f"Device {device_id}: "
f"{device_properties.name} GPU, "
- f"sm{device_properties.major}{device_properties.minor} compute capability, "
+ f"{device_properties.gcnArchName.split(':')[0]} architecture, "
f"{device_properties.total_memory/1024**3:.1f}GB memory"
)
# Benchmarking starts..
for model in model_configs.keys():
config = model_configs[model]
- available_backends,_, fused_attn_backends = _get_attention_backends(
+ available_backends, _, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
@@ -313,86 +374,73 @@ def main(args):
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
if not(fused_attn_supported or flash_attn_supported):
- print("No attention backend's detected for ", model)
+ print(f"No attention backend detected for {model}")
continue
print(
- f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
+ f'Running {model} with {"Fused Attention" if fused_attn_supported else ""}'
f'{" and flash-attention" if flash_attn_supported else ""}...'
)
- perf_dir_flash_attn, perf_dir_fused_attn, perf_dir_fused_ck, perf_dir_fused_aotriton = None, None, None, None
+ perf_dir_flash_attn, perf_dir_fused_ck, perf_dir_fused_aotriton = None, None, None
- # Benchmark for each attention backend
+ # Benchmark Flash Attention
if flash_attn_supported:
- os.environ.update({
- "NVTE_FUSED_ATTN": "0", "NVTE_FLASH_ATTN": "1",
- "NVTE_FUSED_ATTN_AOTRITON": "0", "NVTE_FUSED_ATTN_CK": "0" , "NVTE_UNFUSED_ATTN": "0"
- })
- perf_dir_flash_attn = os.path.join("profiler_outputs/", f"prof_flash_{model}")
+ setup_backend_env("flash")
+ perf_dir_flash_attn = os.path.join(output_dir, f"prof_flash_{model}")
benchmark_dot_product_attention(model, "FlashAttention", "FlashAttention Module", perf_dir_flash_attn)
- if fused_attn_supported:
-
- os.environ.update({
- "NVTE_FUSED_ATTN": "1", "NVTE_FLASH_ATTN": "0",
- "NVTE_FUSED_ATTN_AOTRITON": "0", "NVTE_FUSED_ATTN_CK": "1", "NVTE_UNFUSED_ATTN": "0"
- })
- if args.use_ck_bwd_v3:
- os.environ["NVTE_CK_USES_BWD_V3"] = "1"
-
- # FusedAttention run
- perf_dir_fused_attn = os.path.join("profiler_outputs/", f"prof_fused_{model}")
- benchmark_dot_product_attention(model, "FusedAttention", "FusedAttention Module", perf_dir_fused_attn)
-
- #FusedAttention CK run
- if NVTE_Fused_Attn_Backend.NVTE_CK in fused_attn_backends:
- perf_dir_fused_ck = os.path.join("profiler_outputs/", f"prof_fused_ck_{model}")
- benchmark_dot_product_attention(model, "FusedAttention", "FusedAttention CK Module", perf_dir_fused_ck)
-
- if NVTE_Fused_Attn_Backend.NVTE_AOTriton in fused_attn_backends:
- #AOTRITON Backend
- os.environ.update({
- "NVTE_FUSED_ATTN_AOTRITON": "1", "NVTE_FUSED_ATTN_CK": "0",
- "NVTE_CK_USES_BWD_V3": "0", "NVTE_UNFUSED_ATTN": "0"
- })
- perf_dir_fused_aotriton = os.path.join("profiler_outputs/", f"prof_fused_aotriton_{model}")
- benchmark_dot_product_attention(model, "FusedAttention", "FusedAttention AOTriton Module", perf_dir_fused_aotriton)
-
- for var in ["NVTE_CK_USES_BWD_V3", "NVTE_FUSED_ATTN_AOTRITON", "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN", "NVTE_FLASH_ATTN", "NVTE_UNFUSED_ATTN"]:
- os.environ.pop(var, None)
-
- df_times = pd.read_csv("times.csv", index_col=0)
- parse_results(model, df_times, perf_dir_flash_attn, perf_dir_fused_attn, perf_dir_fused_ck, perf_dir_fused_aotriton, args.use_ck_bwd_v3)
- df_times.to_csv("times.csv")
-
- df_times = pd.read_csv("times.csv")
+ # Benchmark Fused Attention CK (with v2/v3 based on flags)
+ if fused_attn_supported and NVTE_Fused_Attn_Backend.NVTE_CK in fused_attn_backends:
+ setup_backend_env("fused_ck", use_ck_bwd_v3=args.use_ck_bwd_v3, use_ck_fwd_v3=args.use_ck_fwd_v3, use_ck_v3_a16=args.use_ck_v3_a16)
+ perf_dir_fused_ck = os.path.join(output_dir, f"prof_fused_ck_{model}")
+ benchmark_dot_product_attention(model, "FusedAttention", "FusedAttention CK Module", perf_dir_fused_ck)
+
+ # AOTriton Backend
+ if fused_attn_supported and NVTE_Fused_Attn_Backend.NVTE_AOTriton in fused_attn_backends:
+ setup_backend_env("fused_aotriton")
+ perf_dir_fused_aotriton = os.path.join(output_dir, f"prof_fused_aotriton_{model}")
+ benchmark_dot_product_attention(model, "FusedAttention", "FusedAttention AOTriton Module", perf_dir_fused_aotriton)
+
+ df_times = pd.read_csv(output_csv_path, index_col=0)
+ parse_results(model, df_times, perf_dir_flash_attn, perf_dir_fused_ck, perf_dir_fused_aotriton, args.use_ck_bwd_v3, args.use_ck_fwd_v3)
+ df_times.to_csv(output_csv_path)
+
+ df_times = pd.read_csv(output_csv_path)
df_times.index = list(model_configs.keys())
- a = df_times[
+ timing_df = df_times[
[
- "FusedAttention Kernels (fwd+bwd)",
+ "FusedAttention CK Kernels (fwd)",
+ "FusedAttention CK Kernels (bwd)",
+ "FusedAttention CK Kernels (fwd+bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
+ ].copy()
+ timing_df.columns = [
+ "CK fwd (ms)",
+ "CK bwd (ms)",
+ "CK fwd+bwd (ms)",
+ "Flash fwd+bwd (ms)",
+ "CK/Flash Speedup",
]
- a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
+ print(timing_df)
print()
- print(a)
- if cwd != run_dir:
- final_profiler_dir = os.path.join(cwd, "profiler_outputs")
- if os.path.exists(final_profiler_dir):
- shutil.rmtree(final_profiler_dir)
- shutil.move("profiler_outputs", final_profiler_dir)
-
- final_csv_path = os.path.join(cwd, output_csv)
- if os.path.exists(final_csv_path):
- os.remove(final_csv_path)
- shutil.move(output_csv, final_csv_path)
+ tflops_df = df_times[
+ [
+ "FusedAttention CK TFLOPs (fwd)",
+ "FusedAttention CK TFLOPs (bwd)",
+ ]
+ ].copy()
+ tflops_df.columns = ["CK FWD TFLOPs", "CK BWD TFLOPs"]
+ print(tflops_df)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--use_ck_bwd_v3", action="store_true", help="Use NVTE_CK_USES_BWD_V3=1 for CK bwd kernels")
- parser.add_argument("--run_sanity_checks", action="store_true", help="After benchmarking, verify profiler outputs and Fused vs CK timing parity")
+ parser.add_argument("--no_ck_bwd_v3", action="store_false", dest="use_ck_bwd_v3", help="Set NVTE_CK_USES_BWD_V3=0 for CK bwd kernels")
+ parser.add_argument("--no_ck_fwd_v3", action="store_false", dest="use_ck_fwd_v3", help="Set NVTE_CK_USES_FWD_V3=0 for CK fwd kernels")
+ parser.add_argument("--use_ck_v3_a16", action="store_true", help="Use NVTE_CK_IS_V3_ATOMIC_FP32=0 for atomic16. Default is 1")
+ parser.add_argument("--run_sanity_checks", action="store_true", help="After benchmarking, verify profiler outputs.")
args = parser.parse_args()
main(args)
if args.run_sanity_checks:
diff --git a/build_tools/jax.py b/build_tools/jax.py
index ae8e696c8..4e587b965 100644
--- a/build_tools/jax.py
+++ b/build_tools/jax.py
@@ -21,7 +21,12 @@ def xla_path() -> str:
Throws FileNotFoundError if XLA source is not found."""
try:
- from jax.extend import ffi
+ import jax
+ from packaging import version
+ if version.parse(jax.__version__) >= version.parse("0.5.0"):
+ from jax import ffi
+ else:
+ from jax.extend import ffi
except ImportError:
if os.getenv("XLA_HOME"):
xla_home = Path(os.getenv("XLA_HOME"))
diff --git a/build_tools/utils.py b/build_tools/utils.py
index 9414d778c..f848bad74 100644
--- a/build_tools/utils.py
+++ b/build_tools/utils.py
@@ -535,7 +535,8 @@ def hipify(base_dir, src_dir, sources, include_dirs):
extra_files=[],
is_pytorch_extension=True,
hipify_extra_files_only=False,
- show_detailed=False)
+ show_detailed=False,
+ no_math_replace=True)
# Because hipify output_directory == project_directory
# Original sources list may contain previous hipifying results that ends up with duplicated entries
diff --git a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86
index 2b78544df..cf5dbb3bc 100644
--- a/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86
+++ b/build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86
@@ -2,22 +2,32 @@
#
# See LICENSE for license information.
-# This Dockerfile is used to build TransformerEngine wheels for ROCm on x86_64 architecture.
-# It is based on the manylinux_2_28_x86_64 based image with ROCm installed.
-ARG BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64:non_existent_rocm_tag
+# This Dockerfile is used to build TransformerEngine wheels for ROCm on x86_64 architecture
+# on top of the manylinux_2_28_x86_64 base image.
+
+# Build args:
+# BASE_IMAGE - Base manylinux image to use. Default: quay.io/pypa/manylinux_2_28_x86_64
+# ROCM_REPO_URL - ROCm repository URL. Default: https://repo.radeon.com/rocm/rhel8/latest/main/
+# GPU_TARGETS - Semicolon separated list of target GPU architectures. Default: "gfx942;gfx950"
+# TARGET_BRANCH - Target branch for TransformerEngine. Default: none (use git default)
+# GPU_TARGETS and TARGET_BRANCH can be overriden when start a container with NVTE_ROCM_ARCH and TARGET_BRANCH environment variables.
+
+# Set base image
+ARG BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64
FROM $BASE_IMAGE
-# Setup the build_system repo
-RUN echo -e "[build_system]\nname=ROCm\nbaseurl=https://repo.almalinux.org/build_system/8/x86_64/\nenabled=1\ngpgcheck=0" >/etc/yum.repos.d/build_system.repo
+ARG ROCM_REPO_URL=https://repo.radeon.com/rocm/rhel8/latest/main/
-# Add and enable repos
-RUN dnf update -y || true
-RUN dnf install -y epel-release elrepo-release
-RUN dnf config-manager --set-enabled build_system powertools extras epel elrepo
+# Set up ROCm repo
+RUN echo -e "[rocm]\nname=ROCm\nbaseurl=${ROCM_REPO_URL}\nenabled=1\ngpgcheck=0" > /etc/yum.repos.d/rocm.repo
+
+# Setup packages
+RUN dnf install -y --disablerepo=epel rocm-dev hipblaslt hipblaslt-devel hipcub hipcub-devel
+RUN dnf group install -y "Development Tools" && dnf install -y git cmake llvm-toolset gcc-toolset-12
+
+#Uncomment the next line for ROCm 6.4 cmake workaround: remove newer incomnpatible cmake preinstalled on base image
+#RUN rm /usr/local/bin/cmake || true
-# Setup dev packages
-RUN dnf group install -y "Development Tools" && \
- dnf install -y git cmake llvm-toolset hipblaslt hipblaslt-devel gcc-toolset-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh
index 5320b8a39..5d37ae1d9 100644
--- a/build_tools/wheel_utils/build_wheels.sh
+++ b/build_tools/wheel_utils/build_wheels.sh
@@ -30,11 +30,17 @@ fi
ROCM_BUILD=`${PYBINDIR}python -c "import build_tools.utils as u; print(int(u.rocm_build()))"`
+if [ "$LOCAL_TREE_BUILD" != "1" ]; then
+ if [ "$ROCM_BUILD" = "1" ]; then
+ git pull
+ fi
+ git checkout $TARGET_BRANCH
+ git submodule update --init --recursive
+fi
+
if [ "$ROCM_BUILD" = "1" ]; then
- git pull
+ ${PYBINDIR}pip install setuptools wheel
fi
-git checkout $TARGET_BRANCH
-git submodule update --init --recursive
if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
@@ -50,10 +56,10 @@ if $BUILD_COMMON ; then
WHL_BASE="transformer_engine-${VERSION}"
if [ "$ROCM_BUILD" = "1" ]; then
TE_CUDA_VERS="rocm"
- ${PYBINDIR}pip install ninja dataclasses
- if [ -n "$PYBINDIR" ]; then
- PATH="$PYBINDIR:$PATH" #hipify expects python in PATH
- fi
+ #dataclasses, psutil are needed for AITER
+ ${PYBINDIR}pip install ninja dataclasses psutil
+ #hipify expects python in PATH, also ninja may be installed to python bindir
+ test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true
else
TE_CUDA_VERS="cu12"
PYBINDIR=/opt/python/cp38-cp38/bin/
diff --git a/ci/jax.sh b/ci/jax.sh
index 80c61ce9b..356588f5b 100755
--- a/ci/jax.sh
+++ b/ci/jax.sh
@@ -50,12 +50,18 @@ run_lbl() {
_test_label=""
}
+run_default_fa_lbl() {
+ if [ $_fus_attn = "$_DEFAULT_FUSED_ATTN" ]; then
+ run_lbl "$@"
+ fi
+}
+
run_test_config() {
echo ==== Run with Fused attention backend: $_fus_attn ====
run_default_fa 1 test_custom_call_compute.py
run_default_fa 1 test_functions.py
run 1 test_fused_attn.py
- NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_lbl "v3" 1 test_fused_attn.py # Using FAv3 for forward and backward pass
+ NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
run_default_fa 1 test_helper.py
run_default_fa 1 test_layer.py #it effectevly always uses unfused attention
run_default_fa 1 test_sanity_import.py
@@ -66,33 +72,23 @@ run_test_config() {
run_test_config_mgpu() {
echo ==== Run mGPU with Fused attention backend: $_fus_attn ====
- _JAX_DISABLE_JIT_FLAG=${JAX_DISABLE_JIT:-0}
_ver=$(pip show jaxlib | grep Version)
case "$_ver" in
*0.4.35*)
- # Workaround for distributed tests hang with JIT enabled
- JAX_DISABLE_JIT=1 run 3 test_distributed_fused_attn.py -k 'not (test_context_parallel_allgather_attn[BALANCED or test_context_parallel_ring_attn)'
- _JAX_DISABLE_JIT_FLAG=1
-
- # Run tests that fail with JIT disabled
- #run_lbl "allgather_balanced" 3 test_distributed_fused_attn.py -k 'test_context_parallel_allgather_attn[BALANCED'
-
+ # Workaround for distributed tests hang with xla_flag
+ XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
+
# Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
- # TODO: remove this flag after jax/xla update
- XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
- ;;
- *0.6.*)
- # Workaround for distributed tests hang with JIT enabled
- JAX_DISABLE_JIT=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_allgather_attn[BALANCED'
- _JAX_DISABLE_JIT_FLAG=1
+ XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
;;
*)
- run 3 test_distributed_fused_attn.py
+ # Workaround for distributed tests hang with xla_flag
+ XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
;;
esac
run_default_fa 3 test_distributed_layernorm.py
- JAX_DISABLE_JIT=$_JAX_DISABLE_JIT_FLAG run_default_fa 3 test_distributed_layernorm_mlp.py
+ XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run_default_fa 3 test_distributed_layernorm_mlp.py
run_default_fa 3 test_distributed_softmax.py
run_default_fa 3 test_sanity_import.py
diff --git a/ci/pytorch.sh b/ci/pytorch.sh
index 207949ee5..519238521 100755
--- a/ci/pytorch.sh
+++ b/ci/pytorch.sh
@@ -12,7 +12,7 @@ TEST_DIR=${TE_PATH}tests/pytorch
#: ${TEST_WORKERS:=4}
install_prerequisites() {
- pip install 'numpy>=1.22.4,<2.0' pandas
+ pip install 'numpy>=1.22.4' pandas
rc=$?
if [ $rc -ne 0 ]; then
script_error "Failed to install test prerequisites"
@@ -72,7 +72,6 @@ run_test_config(){
run 1 test_sanity.py
run_default_fa 1 test_sanity_import.py
run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test
- NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run_default_fa_lbl "v3" 1 fused_attn/test_fused_attn.py # Using FAv3 for forward and backward pass
run_default_fa 1 triton_kernels/test_cast.py
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
@@ -81,6 +80,7 @@ run_test_config(){
run_default_fa 1 test_parallel_cross_entropy.py
NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_numerics.py
NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
+ NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
}
run_test_config_mgpu(){
@@ -93,6 +93,7 @@ run_test_config_mgpu(){
run 3 distributed/test_fusible_ops.py
run 3 distributed/test_numerics.py
run 3 distributed/test_torch_fsdp2.py
+ run 3 distributed/test_torch_fsdp2_fp8.py
run 3 fused_attn/test_fused_attn_with_cp.py
fi
}
@@ -111,7 +112,7 @@ run_benchmark() {
return
fi
- python "$BENCH_SCRIPT" --use_ck_bwd_v3 --run_sanity_checks || test_run_error $BENCH_SCRIPT
+ python "$BENCH_SCRIPT" --run_sanity_checks || test_run_error $BENCH_SCRIPT
}
# Single config mode, run it and return result
diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt
index e2a1938ce..9d3ff719b 100644
--- a/tests/cpp/operator/CMakeLists.txt
+++ b/tests/cpp/operator/CMakeLists.txt
@@ -60,6 +60,7 @@ else()
hipify(CUDA_SOURCE_DIR ${cuda_source_dir}
HEADER_INCLUDE_DIR ${header_include_dir}
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
+ NO_MATH_REPLACE
)
get_hipified_list("${test_cuda_sources}" test_hip_sources)
message("${message_line}")
diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu
index 33a9b8629..b0a847d7d 100644
--- a/tests/cpp/operator/test_cast_mxfp8.cu
+++ b/tests/cpp/operator/test_cast_mxfp8.cu
@@ -51,6 +51,9 @@ void scale_block(const ProcessingMethod processing_method,
const size_t j_min,
const size_t j_max,
const size_t cols) {
+#ifdef __HIP_PLATFORM_AMD__
+ using std::isnan, std::isinf;
+#endif
float amax = 0.0f;
// Find the absolute maximum value in the block
@@ -71,17 +74,10 @@ void scale_block(const ProcessingMethod processing_method,
elt *= static_cast(grad[idx]);
}
dbias[j] += elt;
-#ifndef __HIP_PLATFORM_AMD__
if (isinf(elt) || isnan(elt)) {
continue;
}
amax = std::max(amax, std::abs(elt));
-#else
- if (std::isinf(elt) || std::isnan(elt)) {
- continue;
- }
- amax = fmaxf(amax, fabsf(elt));
-#endif
}
}
@@ -312,6 +308,23 @@ void performTest_x1(const ProcessingMethod processing_method,
block_size_cols,
scales_stride);
+
+#ifdef __HIP_PLATFORM_AMD__
+ if (processing_method != ProcessingMethod::CAST_ONLY) {
+ std::vector> mismatch_idx;
+ compare_e8m0_scaling_factors("scales", output_c, ref_output_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, rowwise, mismatch_idx);
+
+ if (mismatch_idx.size()) {
+ adjust_ref(mismatch_idx, ref_output_c.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype);
+ }
+
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
+ }
+ else
+#endif // #ifdef __HIP_PLATFORM_AMD__
+ {
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol);
@@ -321,6 +334,7 @@ void performTest_x1(const ProcessingMethod processing_method,
compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
+ }
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
@@ -454,7 +468,29 @@ void performTest_x2(const ProcessingMethod processing_method,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
+#ifdef __HIP_PLATFORM_AMD__
+ if (processing_method != ProcessingMethod::CAST_ONLY) {
+ std::vector> mismatch_idx_r;
+ compare_e8m0_scaling_factors("scales_rowwise", output, ref_scales_rowwise.get(),
+ unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r);
+
+ if (mismatch_idx_r.size()) {
+ adjust_ref(mismatch_idx_r, ref_output_c_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype);
+ }
+ std::vector> mismatch_idx_c;
+ compare_e8m0_scaling_factors("scales_colwise", output, ref_scales_colwise.get(),
+ unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c);
+
+ if (mismatch_idx_c.size()) {
+ adjust_ref(mismatch_idx_c, ref_output_c_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype);
+ }
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
+ compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
+ } else
+#endif // #ifdef __HIP_PLATFORM_AMD__
+ {
auto [atol, rtol] = getTolerances(otype);
compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol);
@@ -464,6 +500,7 @@ void performTest_x2(const ProcessingMethod processing_method,
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
+ }
if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) {
auto [atol_dbias, rtol_dbias] = getTolerances(itype);
@@ -563,7 +600,7 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) {
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
-#endif
+#endif // #ifdef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
using namespace test;
diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
index f93c8c9e0..4acbac4fb 100644
--- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
+++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
@@ -262,9 +262,24 @@ void performTest_x1(const size_t rows,
block_size_rows,
block_size_cols,
scales_stride);
+#ifdef __HIP_PLATFORM_AMD__
+ std::vector> mismatch_idx;
+ if (rowwise) {
+ compare_e8m0_scaling_factors("rowwise scales", output, ref_output_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, true, mismatch_idx);
+ } else {
+ compare_e8m0_scaling_factors("colwise scales", output, ref_output_scales.get(),
+ unpadded_blocks_Y, unpadded_blocks_X, scales_stride, 0.01, false, mismatch_idx);
+ }
+ if (mismatch_idx.size()) {
+ adjust_ref(mismatch_idx, ref_output.get(), unpadded_blocks_Y, unpadded_blocks_X, rows, cols, otype);
+ }
auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
+#else // #ifdef __HIP_PLATFORM_AMD__
+ auto [atol, rtol] = getTolerances(otype);
+ compareResults("output", output, ref_output.get(), rowwise, atol, rtol);
const uint8_t * const gpu_scales_ptr = rowwise
? output.rowwise_cpu_scale_inv_ptr()
@@ -276,6 +291,7 @@ void performTest_x1(const size_t rows,
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
}
+#endif // #ifdef __HIP_PLATFORM_AMD__
}
/**
@@ -361,17 +377,41 @@ void performTest_x2(const size_t rows,
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);
+#ifdef __HIP_PLATFORM_AMD__
+ std::vector> mismatch_idx_r;
+ compare_e8m0_scaling_factors("scales_rowwise", output,
+ ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
+ unpadded_blocks_X_rowwise, scales_stride_rowwise, 0.01, true, mismatch_idx_r);
+
+ if (mismatch_idx_r.size()) {
+ adjust_ref(mismatch_idx_r, ref_output_colwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, rows, cols, otype);
+ }
+
+ std::vector> mismatch_idx_c;
+ compare_e8m0_scaling_factors("scales_colwise", output,
+ ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
+ unpadded_blocks_X_colwise, scales_stride_colwise, 0.01, false, mismatch_idx_c);
+
+ if (mismatch_idx_c.size()) {
+ adjust_ref(mismatch_idx_c, ref_output_rowwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, rows, cols, otype);
+ }
auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
+#else // #ifdef __HIP_PLATFORM_AMD__
+ auto [atol, rtol] = getTolerances(otype);
+ auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
+ compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
+ compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(),
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(),
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
+#endif // #ifdef __HIP_PLATFORM_AMD__
}
std::vector> matrix_sizes = {
@@ -418,12 +458,12 @@ class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam
TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) {
#ifdef __HIP_PLATFORM_AMD__
omp_set_num_threads(std::min(128, omp_get_max_threads())); // Using threads = # of vcpus causes occasional errors.
-#else
+#else // #ifdef __HIP_PLATFORM_AMD__
// Skip tests for pre-Blackwell architectures
if (getDeviceComputeCapability() < blackwellComputeCapability) {
GTEST_SKIP();
}
-#endif
+#endif // #ifdef __HIP_PLATFORM_AMD__
using namespace transformer_engine;
diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu
index b731cc701..071470bdf 100644
--- a/tests/cpp/operator/test_cublaslt_gemm.cu
+++ b/tests/cpp/operator/test_cublaslt_gemm.cu
@@ -93,7 +93,7 @@ void compute_ref(
// update ref_d_amax if in fp8
DType dtype = TypeInfo::dtype;
if(isFp8Type(dtype)){
- ref_d_amax = std::max(ref_d_amax, std::fabs(val));
+ ref_d_amax = std::max(ref_d_amax, std::fabs(val));
}
}
}
@@ -127,10 +127,8 @@ void compute_mxfp8_ref(
for(size_t kk = 0; kk < k; kk++){
size_t a_idx = transa ? (ii*k + kk) : (kk*m + ii);
size_t b_idx = transb ? (kk*n + jj) : (jj*k + kk);
- float a_scale_inv_val = (float)std::pow(2,
- a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127);
- float b_scale_inv_val = (float)std::pow(2,
- b_scale_inv_data[transb ? (kk/32 * n + jj) : b_idx/32] - 127);
+ float a_scale_inv_val = std::exp2f(a_scale_inv_data[transa ? a_idx/32 : (kk/32 * m + ii)] - 127);
+ float b_scale_inv_val = std::exp2f(b_scale_inv_data[transb ? (kk/32 * n + jj) : b_idx/32] - 127);
val += a_scale_inv_val * (float)a_data[a_idx] * b_scale_inv_val * (float)b_data[b_idx];
}
if(bias_data){
@@ -144,7 +142,7 @@ void compute_mxfp8_ref(
// update ref_d_amax if in fp8
DType dtype = TypeInfo::dtype;
if(isFp8Type(dtype)){
- ref_d_amax = std::max(ref_d_amax, std::fabs(val));
+ ref_d_amax = std::max(ref_d_amax, std::fabs(val));
}
}
}
@@ -177,16 +175,11 @@ std::pair getTestTolerances(const DType type, bool use_fp8, bool
// relax for certain FP8 gemm with hipblaslt
if (use_mxfp8) {
atol = 5e-4;
- /*During hipifying std::max is converted to ::max
- to w/a HIP bug with using std:: in device functions.
- W/o explicitlit , compiler uses non-templated int method variant from HIP headers
- TODO: remove when switch to new hipify version after fixing HIP bug */
- rtol = std::max(rtol, 1e-3);
+ rtol = std::max(rtol, 1e-3);
}
else if (use_fp8) {
atol = 1e-3;
- //TODO: remove (see comment above)
- rtol = std::max(rtol, 5e-3);
+ rtol = std::max(rtol, 1e-2);
}
else if (type == DType::kBFloat16) {
//relax for certain prime number TN gemm
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index 28a122b13..5ad6e084c 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -547,22 +547,9 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
-#ifndef __HIP_PLATFORM_AMD__
const double cast_mean_p = static_cast(static_cast(mean_p));
const double cast_mean_m = static_cast(static_cast(mean_m));
assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
-#else
- const double cast_mean_p =
- static_cast(static_cast(static_cast(static_cast(mean_p))));
- const double cast_mean_m =
- static_cast(static_cast(static_cast(static_cast(mean_m))));
- /*During hipifying std::max and std::min are converted to ::max and ::min
- to w/a HIP bug with using std:: in device functions.
- W/o explicitlit , compiler uses non-templated int method variant from HIP headers
- TODO: remove when switch to new hipify version after fixing HIP bug */
- assertion =
- !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
-#endif
}
std::string direction = rowwise ? "rowwise" : "columnwise";
ASSERT_FALSE(assertion) << "Error in tensor " << name << " in "
@@ -603,21 +590,9 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
const double mean = (t + r) / 2;
const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
-#ifndef __HIP_PLATFORM_AMD__
const double cast_mean_p = static_cast(static_cast(mean_p));
const double cast_mean_m = static_cast(static_cast(mean_m));
assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
-#else
- const double cast_mean_p =
- static_cast(static_cast(static_cast(static_cast(mean_p))));
- const double cast_mean_m =
- static_cast(static_cast(static_cast(static_cast(mean_m))));
- /*During hipifying std::max and std::min are converted to ::max and ::min
- to w/a HIP bug with using std:: in device functions.
- W/o explicitlit , compiler uses non-templated int method variant from HIP headers
- TODO: remove when switch to new hipify version after fixing HIP bug */
- assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
-#endif
}
if (assertion && i < first_mismatch_idx) {
first_mismatch_idx = i;
@@ -714,6 +689,74 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
}
}
+#ifdef __HIP_PLATFORM_AMD__
+void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref,
+ const size_t row_blocks, const size_t col_blocks, const size_t stride,
+ double tol, bool rowwise, std::vector> &mismatch_idx) {
+ const uint8_t *const test = rowwise ? output.rowwise_cpu_scale_inv_ptr()
+ : output.columnwise_cpu_scale_inv_ptr();
+
+ const double scale_tol = std::max(1., row_blocks * col_blocks * tol);
+
+ for (int i = 0; i < row_blocks; i++) {
+ for (int j = 0; j < col_blocks; j++) {
+ const int idx = i * stride + j;
+ if (test[idx] != ref[idx]) {
+ int t_scale = static_cast(test[idx]);
+ int r_scale = static_cast(ref[idx]);
+ if (std::abs(t_scale - r_scale) == 1) {
+ mismatch_idx.emplace_back(i, j, r_scale-t_scale);
+ } else {
+ GTEST_FAIL() << "Error in " << name << std::endl
+ << "Mismatch: " << t_scale << " vs "
+ << r_scale << " at index " << idx;
+ }
+ }
+ }
+ }
+ const size_t scale_mismatches = mismatch_idx.size();
+
+ ASSERT_FALSE(scale_mismatches > scale_tol)
+ << "Error in " << name << std::endl << std::setprecision(4)
+ << "Total scale mismatches: " << scale_mismatches << " (" << 100.*(double)scale_mismatches/(double)(row_blocks*col_blocks)
+ << "%) Exceeds tolerance of " << scale_tol << " (" << 100.*tol << "%) mismatches";
+
+ if (scale_mismatches) {
+ std::cout << "\x1b[33mWARNING:\x1b[0m " << scale_mismatches
+ << " scale mismatches were found. This does not imply an accuracy issue." << std::endl;
+ }
+}
+
+void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks,
+ const size_t col_blocks, const size_t rows, const size_t cols, DType otype) {
+ TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( otype, T,
+ T *ref_data = reinterpret_cast(ref);
+ double scale_val;
+ const size_t col_blocks_size = cols / col_blocks;
+ const size_t row_blocks_size = rows / row_blocks;
+ for (const auto &[i, j, scale_diff] : mismatch_idx) {
+ if (scale_diff == 1) {
+ scale_val = 2.;
+ } else if (scale_diff == -1) {
+ scale_val = .5;
+ } else { // Shouldn't ever reach this
+ GTEST_FAIL() << "Error in adjust_ref, |scale_diff| > 1";
+ }
+ size_t ii_min = i * row_blocks_size;
+ const size_t ii_max = std::min(ii_min + row_blocks_size, rows);
+ for (; ii_min < ii_max; ii_min++) {
+ size_t jj_min = j * col_blocks_size;
+ const size_t jj_max = std::min(jj_min + col_blocks_size, cols);
+ for (; jj_min < jj_max; jj_min++) {
+ const size_t data_idx = ii_min * cols + jj_min;
+ ref_data[data_idx] = static_cast(static_cast(ref_data[data_idx]) * scale_val);
+ }
+ }
+ }
+ ); // NOLINT(*)
+}
+#endif // #ifdef __HIP_PLATFORM_AMD__
+
std::pair getTolerances(const DType type) {
switch(type) {
case DType::kFloat32:
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index 7ac2b75a6..6b9514d38 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -19,6 +19,7 @@
#else
#include
#include "amd_detail/hip_float8.h"
+#include
#endif
#include
@@ -461,6 +462,14 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test,
const size_t row_blocks, const size_t col_blocks, const size_t stride);
void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref,
const size_t N);
+#ifdef USE_ROCM
+void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref,
+ const size_t row_blocks, const size_t col_blocks, const size_t stride,
+ double tol, bool rowwise, std::vector> &mismatch_idx);
+
+void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks,
+ const size_t col_blocks, const size_t rows, const size_t cols, DType otype);
+#endif
std::array get_scale_tensor_dims(const size_t rows, const size_t cols,
const size_t block_size_rows, const size_t block_size_cols);
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index f5b8bbe68..ab32afe40 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -360,6 +360,18 @@ def _check_configs(self):
"is either BSHD_BSHD_BSHD or THD_THD_THD"
)
+ if self.head_dim_qk == 192 and self.head_dim_v == 128:
+ if self.attn_bias_type != AttnBiasType.NO_BIAS or self.bias_shape is not None:
+ pytest.skip("Aiter currently supports MLA hd192_hd128 only without bias.")
+ if self.attn_mask_type not in (AttnMaskType.CAUSAL_MASK, AttnMaskType.NO_MASK):
+ pytest.skip("Aiter currently supports MLA hd192_hd128 only for CAUSAL or NO_MASK.")
+ if self.dropout_prob != 0.0:
+ pytest.skip("Aiter currently supports MLA hd192_hd128 only without dropout.")
+ if self.qkv_layout != QKVLayout.BSHD_BSHD_BSHD:
+ pytest.skip("Aiter currently supports MLA hd192_hd128 only with BSHD_BSHD_BSHD layout.")
+ if self.seq_desc_format != SeqDescFormat.Mask:
+ pytest.skip("Aiter currently supports MLA hd192_hd128 only with mask-based SeqDescFormat.")
+
self.backend = FusedAttnHelper(
self.dtype,
self.dtype,
@@ -995,6 +1007,12 @@ def check_dqkv(primitive, reference, pad, idx):
pytest.param(
2, 2048, 2048, 12, 6, 128, 64, jnp.float16, id="2-2048-2048-12-6-128-64-FP16-GQA"
),
+ pytest.param(
+ 10, 4096, 4096, 16, 16, 192, 128, jnp.float16, id="10-4096-4096-16-16-192-128-FP16-MLA",
+ ),
+ pytest.param(
+ 10, 4096, 4096, 16, 16, 192, 128, jnp.bfloat16, id="10-4096-4096-16-16-192-128-BF16-MLA",
+ ),
],
)
@pytest.mark.parametrize(
diff --git a/tests/pytorch/distributed/run_fsdp2_fp8_model.py b/tests/pytorch/distributed/run_fsdp2_fp8_model.py
new file mode 100644
index 000000000..430ee5fed
--- /dev/null
+++ b/tests/pytorch/distributed/run_fsdp2_fp8_model.py
@@ -0,0 +1,308 @@
+#!/usr/bin/python3
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# See LICENSE for license information.
+
+
+import os
+import sys
+import argparse
+
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Float8CurrentScaling, Format, DelayedScaling, MXFP8BlockScaling
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch import nn, optim
+from torch.distributed import DeviceMesh
+from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import init_device_mesh
+from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.fp8 import fp8_model_init
+from torch.nn.parallel import DistributedDataParallel as DDP
+from pathlib import Path
+
+class SimpleNet(nn.Module):
+ def __init__(self, input_size, hidden_size, output_size, use_fsdp2=False):
+ super(SimpleNet, self).__init__()
+
+ # LayerNormLinear: fuses LayerNorm + Linear
+ self.ln_linear = te.LayerNormLinear(
+ in_features=input_size,
+ out_features=hidden_size,
+ eps=1e-5,
+ use_fsdp2=use_fsdp2,
+ keep_fp8_weight_transpose_cache=False
+ )
+
+ # LayerNormMLP: fuses LayerNorm + FC1 + Activation + FC2
+ self.ln_mlp = te.LayerNormMLP(
+ hidden_size=hidden_size,
+ ffn_hidden_size=hidden_size * 4, # Typical 4x expansion
+ use_fsdp2=use_fsdp2,
+ keep_fp8_weight_transpose_cache=False
+ )
+
+ # Regular Linear for final projection
+ self.fc_out = te.Linear(
+ hidden_size,
+ output_size,
+ use_fsdp2=use_fsdp2,
+ keep_fp8_weight_transpose_cache=False
+ )
+
+ def forward(self, x):
+ # LayerNormLinear: applies LayerNorm then Linear
+ x = self.ln_linear(x)
+
+ # LayerNormMLP: applies LayerNorm + FC1 + GELU + FC2
+ x = self.ln_mlp(x)
+
+ # Final Linear projection
+ x = self.fc_out(x)
+
+ return x
+
+def save_custom_attrs(module, _SKIP_KEYS = {"_data", "_module", "_transpose"}):
+ custom_attrs = {}
+ for name, param in module.named_parameters():
+ attrs = vars(param)
+ custom_attrs[name] = {k: v for k, v in attrs.items()}
+ for k in _SKIP_KEYS:
+ custom_attrs[name].pop(k, None)
+ return custom_attrs
+
+
+def restore_custom_attrs(module, custom_attrs):
+ for name, param in module.named_parameters():
+ if name in custom_attrs:
+ for attr_name, attr_value in custom_attrs[name].items():
+ setattr(param, attr_name, attr_value)
+
+
+def _parse_args(argv=None, namespace=None):
+ parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
+ parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
+ parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
+ parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
+ parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
+ parser.add_argument(
+ "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
+ )
+ parser.add_argument(
+ "--iter", type=int, default=10, help="Number of iterations for forward pass"
+ )
+ parser.add_argument('--profile', action='store_true',
+ help='Enable pytorch profiling.')
+ parser.add_argument('--profile-step-start', type=int, default=6,
+ help='Global step to start profiling.')
+ parser.add_argument('--profile-step-end', type=int, default=7,
+ help='Global step to stop profiling.')
+ parser.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
+ help='Global ranks to profile.')
+ parser.add_argument('--tensorboard-dir', type=str, default='./fsdp2_tensorboard',
+ help='Write TensorBoard logs to this directory.')
+ parser.add_argument('--gradients-save-file', type=str, default='all_iters.pt',
+ help='Write all the gradients across all the iterations to this file.')
+ parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
+ parser.add_argument("--use-fsdp2", action='store_true',
+ help='Enable New FSDP2 training.')
+ parser.add_argument("--memory-profile", action='store_true',
+ help='profile memory traces')
+ parser.add_argument(
+ "--recipe",
+ type=str,
+ choices=["delayed", "mxfp8", "current"],
+ default="delayed",
+ help="Select the training recipe to use: 'delayed', 'mxfp8', or 'current'."
+ )
+
+ # Adding hsdp_dim as a list argument, comma-separated
+ parser.add_argument(
+ "--sharding-dims",
+ type=int,
+ nargs="+",
+ help='FSDP/HSDP sharding dimensions ("replicate", "shard")',
+ )
+ args = parser.parse_args(argv, namespace)
+ if args.sharding_dims:
+ assert len(args.sharding_dims) <= 2
+ return args
+
+
+sub_modules_to_wrap = [te.Linear, te.LayerNormLinear, te.LayerNormMLP]
+
+
+def _train(args):
+ assert "TORCHELASTIC_RUN_ID" in os.environ
+ WORLD_RANK = int(os.getenv("RANK", "0"))
+ WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+ LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
+ LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+ assert LOCAL_SIZE == WORLD_SIZE
+
+ # Set device and initialize RNG states
+ torch.cuda.set_device(WORLD_RANK)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+
+ # Initialize torch.distributed global process group and get DP/TP groups
+ dist_init_kwargs = {
+ "backend": "nccl",
+ "rank": WORLD_RANK,
+ "world_size": WORLD_SIZE,
+ }
+ assert dist.is_nccl_available()
+ dist.init_process_group(**dist_init_kwargs)
+ nccl_world = dist.new_group(backend="nccl")
+ device = torch.device(f"cuda:{LOCAL_RANK}")
+
+ # FP8 Configuration
+ if args.recipe == "current":
+ fp8_recipe = Float8CurrentScaling()
+ elif args.recipe == "mxfp8":
+ fp8_recipe = MXFP8BlockScaling()
+ elif args.recipe == "delayed":
+ fp8_recipe = DelayedScaling()
+ else:
+ raise ValueError(f"Unsupported recipe: {args.recipe}")
+
+ if args.memory_profile:
+ torch.cuda.memory._record_memory_history(enabled='all', context='all', stacks='all')
+ if args.fp8_init:
+ # Build the model with the specified context
+ with fp8_model_init(enabled = True):
+ model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
+ else:
+ model = SimpleNet(args.input_size, args.hidden_size, args.output_size, use_fsdp2=args.use_fsdp2)
+ # Move the model to the correct device
+ if not args.memory_profile:
+ model.load_state_dict(torch.load('fsdp_model.pth'))
+ model.to(device)
+
+ # Creating a DeviceMesh for fully_shard
+ world_size = int(WORLD_SIZE)
+ device_ids = list(range(world_size))
+
+ # Apply FSDP/HSDP
+ if args.use_fsdp2:
+ custom_attrs = save_custom_attrs(model)
+ if LOCAL_RANK == 0:
+ print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...")
+ print(f"sharding-dims:{args.sharding_dims}")
+ # Setup the sharding mesh for FSDP/HSDP
+ if args.sharding_dims == None: # FSDP
+ mesh = DeviceMesh("cuda", device_ids)
+ elif len(args.sharding_dims) == 1:
+ assert args.sharding_dims[0] == device_ids[-1] + 1
+ mesh = DeviceMesh("cuda", device_ids)
+ elif len(args.sharding_dims) == 2: # HSDP
+ assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
+ mesh = init_device_mesh(
+ "cuda",
+ (args.sharding_dims[0], args.sharding_dims[1]),
+ mesh_dim_names=("replicate", "shard"),
+ )
+ else:
+ assert False
+ for sub_module in model.modules():
+ if any(
+ isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
+ ):
+ fully_shard(sub_module, mesh=mesh)
+ fully_shard(model, mesh=mesh, reshard_after_forward=True)
+ restore_custom_attrs(model, custom_attrs)
+ else:
+ model = DDP(model, device_ids=[LOCAL_RANK])
+
+ optimizer = te.optimizers.FusedAdam(model.parameters(), lr=1e-3)
+
+ input_path = Path("shared_input.pt")
+ if input_path.exists():
+ input_data = torch.load(input_path).to(device)
+ else:
+ input_data = torch.randn(args.batch_size, args.input_size, requires_grad=True).to(device)
+ torch.save(input_data.cpu(), input_path)
+ print("Generated and saved shared input tensor.")
+
+ out_tensors = []
+ prof = None
+ if (
+ args.profile
+ and torch.distributed.get_rank() in args.profile_ranks
+ ):
+ prof = torch.profiler.profile(
+ activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
+ schedule=torch.profiler.schedule(
+ wait=max(args.profile_step_start - 1, 0),
+ warmup=1 if args.profile_step_start > 0 else 0,
+ active=args.profile_step_end - args.profile_step_start,
+ repeat=1,
+ ),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ )
+ prof.start()
+ for iteration in range(args.iter):
+ if LOCAL_RANK == 0:
+ print(f"Starting iteration...{iteration}")
+ if args.profile and torch.distributed.get_rank() in args.profile_ranks:
+ prof.step()
+
+ # Zero the parameter gradients
+ optimizer.zero_grad()
+ with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+ output = model(input_data)
+ target = torch.randn(args.batch_size, args.output_size).to(device)
+ loss = F.mse_loss(output, target)
+ loss.backward()
+ optimizer.step()
+ if LOCAL_RANK == 0:
+ print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
+
+ if not args.profile and not args.memory_profile:
+ with torch.no_grad():
+ for name, p in model.named_parameters():
+ full_grad = None
+ if p.grad is not None and hasattr(p.grad, 'full_tensor'):
+ # This call is required to be executed on ALL ranks
+ # to complete the collective communication.
+ full_grad = p.grad.full_tensor().detach().clone()
+ elif p.grad is not None:
+ full_grad = p.grad.detach().clone()
+ # 2. Only Rank 0 stores the result
+ if LOCAL_RANK == 0 and p.requires_grad:
+ out_tensors.append((name, full_grad))
+ if (
+ args.profile
+ and iteration == args.profile_step_end
+ and torch.distributed.get_rank() in args.profile_ranks
+ ):
+ prof.stop()
+
+ if (not args.profile and not args.memory_profile) and LOCAL_RANK == 0:
+ torch.save(out_tensors, args.gradients_save_file)
+
+ if args.memory_profile:
+ snapshot = torch.cuda.memory._snapshot()
+ import pickle
+ with open('memory_snapshot.pickle', 'wb') as f:
+ pickle.dump(snapshot, f)
+ # To disable memory history recording when no longer needed
+ torch.cuda.memory._record_memory_history(enabled=None)
+
+ # NOTE: In PyTorch < 2.6 there’s a teardown race where one rank may call
+ # destroy_process_group() while other ranks still have in-flight NCCL ops,
+ # which can trigger a NCCL/RCCL comm error. Newer releases (>= 2.6) fixed
+ # this, but we kept a version-guarded barrier on older Torch for stability.
+ if torch_version() < (2, 6, 0):
+ dist.barrier(device_ids=[torch.cuda.current_device()])
+ dist.destroy_process_group()
+
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(_train(_parse_args()))
diff --git a/tests/pytorch/distributed/test_torch_fsdp2_fp8.py b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py
new file mode 100644
index 000000000..f5d8b03cc
--- /dev/null
+++ b/tests/pytorch/distributed/test_torch_fsdp2_fp8.py
@@ -0,0 +1,118 @@
+#!/usr/bin/python3
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# See LICENSE for license information.
+
+import os
+from typing import List
+import pytest
+import subprocess
+from pathlib import Path
+from transformer_engine.pytorch import torch_version
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import torch
+from run_fsdp2_fp8_model import SimpleNet
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+
+NUM_PROCS: int = torch.cuda.device_count()
+
+def assert_allclose(
+ l1: List[torch.Tensor], l2: List[torch.Tensor], atol: float, rtol: float = None
+) -> bool:
+ """Ensures two lists are equal."""
+ assert len(l1) == len(l2), "Unequal number of outputs."
+ for i, (t1, t2) in enumerate(zip(l1, l2)):
+ tols = dict(atol=atol)
+ if rtol is not None:
+ tols["rtol"] = rtol
+ result = torch.allclose(t1, t2, **tols)
+ if not result:
+ diff = torch.abs(t1 - t2)
+ tol = atol + (rtol * torch.abs(t2))
+ exceed_mask = diff > tol
+ if exceed_mask.any():
+ indices = torch.nonzero(exceed_mask, as_tuple=True)
+ max_diff = diff[exceed_mask].max()
+ max_idx = (diff[exceed_mask] == max_diff).nonzero(as_tuple=True)[0][0]
+ max_location = [idx[max_idx].item() for idx in indices]
+ msg = (
+ f"Outputs not close enough in tensor at idx={i}. "
+ f"Maximum difference at location {max_location} "
+ f"with {t1[exceed_mask][max_idx].item()} vs {t2[exceed_mask][max_idx].item()} "
+ f"(diff {max_diff.item()})."
+ )
+ raise AssertionError(msg)
+
+def _run_test(fp_init, recipe):
+ test_dir = Path(__file__).parent.resolve()
+ fsdp_script = test_dir / "run_fsdp2_fp8_model.py"
+
+ test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", "--master-port=29501", str(fsdp_script)]
+
+ if fp_init:
+ test_cmd += ["--fp8-init"]
+ test_cmd += ["--recipe", recipe]
+
+ subprocess.run(test_cmd + ['--use-fsdp2','--gradients-save-file', 'all_iters_fsdp2.pt'], env=os.environ, check=True)
+ subprocess.run(test_cmd + ['--gradients-save-file', 'all_iters_dp.pt'], env=os.environ, check=True)
+
+ # Load outputs
+ output_fsdp = torch.load("all_iters_fsdp2.pt", map_location="cpu")
+ output_dp = torch.load("all_iters_dp.pt", map_location="cpu")
+
+ for idx, (te_output_no_cache, te_output_cache) in enumerate(zip(output_fsdp, output_dp)):
+
+ print(f"Comparing FSDP {te_output_no_cache[0]}, DDP {te_output_cache[0]} at index {idx}...")
+ assert_allclose(te_output_no_cache[1], te_output_cache[1], atol=0, rtol=0)
+ print(f"Tensor at index {idx} passed comparison.")
+
+
+@pytest.fixture
+def cleanup_artifacts():
+ yield # run the test first
+ for fname in ["all_iters_fsdp2.pt", "all_iters_dp.pt", "fsdp_model.pth", "shared_input.pt"]:
+ if os.path.exists(fname):
+ os.remove(fname)
+
+@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
+@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
+@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.parametrize("fp8_init", ([False]))
+@pytest.mark.parametrize("recipe", (["delayed", "current", "mxfp8"]))
+@pytest.mark.usefixtures("cleanup_artifacts")
+def test_distributed(fp8_init, recipe):
+
+ batch_size = 2048
+ input_size = 2048
+ from pathlib import Path
+
+ input_path = Path("shared_input.pt")
+ if input_path.exists():
+ input_data = torch.load(input_path).to('cuda')
+ else:
+ input_data = torch.randn(batch_size, input_size, requires_grad=True).to('cuda')
+ torch.save(input_data.cpu(), input_path)
+ print("Generated and saved shared input tensor.")
+
+ model = SimpleNet(input_size, 2048, 2048)
+ torch.save(model.state_dict(), 'fsdp_model.pth')
+
+ if torch.cuda.device_count() < 4:
+ pytest.skip("FSDP2 test requires at least 4 GPUs")
+
+ if fp8_init and not fp8_available:
+ pytest.skip(reason_for_no_fp8)
+ if recipe == "mxfp8" and not mxfp8_available:
+ pytest.skip(reason_for_no_mxfp8)
+
+ _run_test(fp8_init, recipe)
+
+
+def test_dummy() -> None:
+ """Dummy test
+
+ pytest returns exit code 5 if all tests are skipped.
+
+ """
+ pass
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index 7c2a09a7d..a9e980288 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -92,7 +92,8 @@ def __del__(self):
@pytest.fixture(autouse=True)
def reset_attn_backend():
env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN",
- "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON"])
+ "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON",
+ "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"])
yield
@@ -375,7 +376,7 @@ def test_dot_product_attention(
if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
- is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128
+ is_training = config.head_dim_qk <= 192 and config.head_dim_v <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
@@ -421,6 +422,8 @@ def test_dot_product_attention(
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
os.environ["NVTE_FUSED_ATTN_CK"] = "1"
os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0"
+ os.environ["NVTE_CK_USES_FWD_V3"] = "1"
+ os.environ["NVTE_CK_USES_BWD_V3"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
dtype,
config,
@@ -432,8 +435,21 @@ def test_dot_product_attention(
is_training,
share_cu_seqlens_ref,
)
- del os.environ["NVTE_FUSED_ATTN_CK"]
- del os.environ["NVTE_FUSED_ATTN_AOTRITON"]
+ if IS_HIP_EXTENSION:
+ os.environ["NVTE_CK_USES_FWD_V3"] = "0"
+ os.environ["NVTE_CK_USES_BWD_V3"] = "0"
+ fused_attn_fwd_2, fused_attn_bwd_2 = _run_dot_product_attention(
+ dtype,
+ config,
+ "FusedAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ pad_between_seqs,
+ is_training,
+ share_cu_seqlens_ref,
+ )
+
# FlashAttention backend
if flash_attn_supported:
@@ -469,6 +485,11 @@ def test_dot_product_attention(
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
+ if IS_HIP_EXTENSION:
+ logging.info("[test_dot_product_attention]: fused attn backend 0 vs 2")
+ torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_2, **tols)
+ for i, _ in enumerate(fused_attn_bwd):
+ torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_2[i], **tols)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@@ -500,6 +521,12 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"mla_3_1": ModelConfig(
8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128
), # inference
+ "mla_4_0": ModelConfig(
+ 10, 16, 16, 192, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=128
+ ),
+ "mla_4_1": ModelConfig(
+ 10, 16, 16, 192, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=128
+ ),
}
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 5367cc0ee..73c35eace 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -737,10 +737,6 @@ def test_gpt_full_activation_recompute(
pytest.skip(reason_for_no_fp8)
if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if IS_HIP_EXTENSION:
- use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) )
- if fp8 and recipe.float8_current_scaling() and use_cast_transpose_triton:
- pytest.skip("Float8 Current Scaling unsupported for full recompute.")
if recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
@@ -1959,9 +1955,6 @@ def test_grouped_linear_accuracy(
if IS_HIP_EXTENSION:
if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8:
pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
- use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) )
- if fp8 and recipe.float8_current_scaling() and use_cast_transpose_triton:
- pytest.skip("Float8 Current Scaling unsupported for grouped linear accuracy.")
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
if fp8 and recipe.mxfp8() and not mxfp8_available:
diff --git a/tests/pytorch/triton_kernels/test_cast.py b/tests/pytorch/triton_kernels/test_cast.py
index 81f01901e..1335e1dc2 100644
--- a/tests/pytorch/triton_kernels/test_cast.py
+++ b/tests/pytorch/triton_kernels/test_cast.py
@@ -1,16 +1,20 @@
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information
-import os
import pytest
import torch
from transformer_engine.pytorch.triton_kernels.cast import te_quantize_triton
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.triton_kernels.cast_transpose import _compute_scale_from_amax_triton
+from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from transformer_engine.pytorch.triton_kernels.common import te_dtype_to_torch_dtype
import transformer_engine_torch as tex
from test_common import te_compare_results, fill_uniform, get_tolerances
+from transformer_engine.pytorch.fp8 import fp8_autocast
+from transformer_engine.common import recipe
+from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type
+@pytest.mark.parametrize("scaling", ("delayed", "current"))
@pytest.mark.parametrize("shape",
[
(16 ),
@@ -32,17 +36,30 @@
])
@pytest.mark.parametrize("in_dtype", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("out_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
-def test_quantize(shape, in_dtype, out_dtype):
+def test_quantize(scaling, shape, in_dtype, out_dtype):
input_tensor = fill_uniform(shape, dtype=in_dtype)
- scale_tensor = torch.rand(1, dtype=torch.float32, device='cuda') * 3.0 - 2.0
- amax_tensor = torch.zeros(1, dtype=torch.float32, device='cuda')
- triton_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
+ if scaling == "current":
+ triton_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda")
+ tex_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda")
+
+ with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
+ quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer)
+ quantized_out_tex = tex.quantize(input_tensor, tex_quantizer)
+
+ elif scaling == "delayed":
+ scale_tensor = torch.rand(1, dtype=torch.float32, device='cuda') * 3.0 - 2.0
+ amax_tensor = torch.zeros(1, dtype=torch.float32, device='cuda')
+
+ triton_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
+ tex_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
+
+ quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer)
+ quantized_out_tex = tex.quantize(input_tensor, tex_quantizer)
+
+ else:
+ raise ValueError(f"unknown scaling method {scaling}")
- quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer)
-
- tex_quantizer = Float8Quantizer(scale=scale_tensor, amax=amax_tensor, fp8_dtype=out_dtype)
- quantized_out_tex = tex.quantize(input_tensor, tex_quantizer)
torch_out_dtype = te_dtype_to_torch_dtype(out_dtype)
atol_q, rtol_q = get_tolerances(torch_out_dtype)
@@ -112,3 +129,41 @@ def test_quantize_bad_transpose(t_shape, fp8_dtype):
quantized_output._transpose = torch.empty(t_shape, device='cuda')
te_quantize_triton(input_tensor, quantizer=quantizer, output=quantized_output)
+
+
+@pytest.mark.parametrize("amax_val", (0.0, float('nan'), float('inf'), -float('inf'), 1.0, 1e-8, 123.456))
+@pytest.mark.parametrize("force_pow_2_scales", (False, True))
+@pytest.mark.parametrize("epsilon", (0.0, 1e-3, 100.0))
+@pytest.mark.parametrize("fp8_dtype", (get_torch_float8_e4m3_type(), get_torch_float8_e5m2_type()))
+def test_compute_scale_from_amax(amax_val, force_pow_2_scales, epsilon, fp8_dtype):
+ max_fp8 = torch.finfo(fp8_dtype).max
+ value_for_inf = float(torch.finfo(torch.float32).max)
+
+ amax_list = [torch.tensor(amax_val, dtype=torch.float32, device="cuda")]
+
+ # TEX path - TEX expects lists for (amaxes, scales, inv_scales)
+ scale_ref = [torch.empty((), dtype=torch.float32, device="cuda")]
+ scale_inv_ref = [torch.empty((), dtype=torch.float32, device="cuda")]
+
+ chunk_size = 2048 * 32 # arbitrary
+ overflow_buf = torch.zeros(1, dtype=torch.int32, device="cuda")
+ tex.multi_tensor_compute_scale_and_scale_inv(
+ chunk_size,
+ overflow_buf,
+ [amax_list, scale_ref, scale_inv_ref],
+ max_fp8,
+ force_pow_2_scales,
+ epsilon,
+ )
+
+ # Triton path & comparison
+ scale_triton = torch.empty((), dtype=torch.float32, device="cuda")
+ scale_inv_triton = torch.empty((), dtype=torch.float32, device="cuda")
+ _compute_scale_from_amax_triton[(1,)](
+ amax_list[0], scale_triton, scale_inv_triton,
+ float(max_fp8), float(epsilon), float(value_for_inf),
+ FORCE_POW_2_SCALES=force_pow_2_scales,
+ )
+
+ torch.testing.assert_close(scale_triton, scale_ref[0], rtol=0.0, atol=0.0)
+ torch.testing.assert_close(scale_inv_triton, scale_inv_ref[0], rtol=0.0, atol=0.0)
diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt
index f70c9f8bb..5b0f1981d 100644
--- a/transformer_engine/common/CMakeLists.txt
+++ b/transformer_engine/common/CMakeLists.txt
@@ -239,6 +239,7 @@ else()
IGNORES "*/aotriton/*"
IGNORES "*/ck_fused_attn/*"
CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json"
+ NO_MATH_REPLACE
)
get_hipified_list("${transformer_engine_SOURCES}" te_hip_sources)
message("${message_line}")
@@ -487,8 +488,16 @@ install(TARGETS transformer_engine DESTINATION .)
set_target_properties(transformer_engine PROPERTIES INSTALL_RPATH "$ORIGIN/lib;$ORIGIN/transformer_engine/lib")
if (USE_ROCM)
+ if("$ENV{ROCM_PATH}" STREQUAL "")
+ set(ROCM_PATH "/opt/rocm")
+ else()
+ set(ROCM_PATH "$ENV{ROCM_PATH}")
+ endif()
+ file(READ "${ROCM_PATH}/.info/version" ROCM_VER)
+ string(STRIP "${ROCM_VER}" ROCM_VER)
+ string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER}")
file(WRITE "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt"
- "ROCM_VERSION: ${hip_VERSION_MAJOR}.${hip_VERSION_MINOR}\n"
+ "ROCM_VERSION: ${ROCM_VER}\n"
"GPU_TARGETS: ${CMAKE_HIP_ARCHITECTURES}\n"
)
install(FILES "${CMAKE_CURRENT_BINARY_DIR}/build_info.txt" DESTINATION "transformer_engine/")
diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py
index 871723a0e..c72eca543 100644
--- a/transformer_engine/common/__init__.py
+++ b/transformer_engine/common/__init__.py
@@ -113,10 +113,9 @@ def _get_shared_object_file(library: str) -> Path:
# Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir:
- assert (
- so_path_in_install_dir is not None
- ), f"Could not find shared object file for Transformer Engine {library} lib."
- return so_path_in_install_dir
+ if so_path_in_install_dir is not None:
+ return so_path_in_install_dir
+ raise FileNotFoundError(f"Could not find shared object file for Transformer Engine {library} lib.")
# Case 2: ERR! Both locations are different but returned a valid result.
# NOTE: Unlike for source installations, pip does not wipe out artifacts from
@@ -139,7 +138,7 @@ def _get_shared_object_file(library: str) -> Path:
if so_path_in_default_dir is not None:
return so_path_in_default_dir
- raise RuntimeError(f"Could not find shared object file for Transformer Engine {library} lib.")
+ raise FileNotFoundError(f"Could not find shared object file for Transformer Engine {library} lib.")
@functools.lru_cache(maxsize=None)
diff --git a/transformer_engine/common/ck_fused_attn/CMakeLists.txt b/transformer_engine/common/ck_fused_attn/CMakeLists.txt
index c44a930e6..bc34444a0 100644
--- a/transformer_engine/common/ck_fused_attn/CMakeLists.txt
+++ b/transformer_engine/common/ck_fused_attn/CMakeLists.txt
@@ -32,21 +32,40 @@ message(STATUS "AITER V3_ASM_ARCHS: ${V3_ASM_ARCHS}")
list(JOIN V3_ASM_ARCHS ";" V3_ASM_ARCHS_STR)
set(ENV{GPU_ARCHS} "${V3_ASM_ARCHS_STR}")
-if(NOT DEFINED AITER_MHA_PATH)
- # delete the existing aiter/jit/build dir for a clean build
- file(REMOVE_RECURSE "${__AITER_SOURCE_DIR}/aiter/jit/build")
- # compile the libmha_fwd.so and libmha_bwd.so
- set(ENV{AITER_LOG_MORE} 1)
- # fp32 to bf16 cvt env still required for MI300X
- set(ENV{CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT} ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT})
- execute_process(
- COMMAND python3 ${__AITER_TEST_DIR}/compile.py
- )
- # libmha_fwd.so and libmha_bwd.so will be under 3rdparty/aiter/op_tests/cpp/mha
- set(__AITER_MHA_PATH ${__AITER_TEST_DIR})
-else()
+if(DEFINED AITER_MHA_PATH)
+ message(STATUS "[AITER-PREBUILT] Using AITER_MHA_PATH=${AITER_MHA_PATH}")
# use pre-built libmha_fwd.so libmha_bwd.so
set(__AITER_MHA_PATH ${AITER_MHA_PATH})
+else()
+ set(AITER_CACHE_VALID FALSE)
+ set(AITER_PREBUILT_DOWNLOAD_SUCCESS FALSE)
+ include("${CMAKE_CURRENT_LIST_DIR}/aiter_prebuilt.cmake")
+
+ # If cached path already exists and is valid
+ is_aiter_cache_valid(AITER_CACHE_VALID)
+
+ if(NOT AITER_CACHE_VALID)
+ # Try downloading prebuilt files if NVTE_AITER_PREBUILT_BASE_URL is set.
+ download_aiter_prebuilt(AITER_PREBUILT_DOWNLOAD_SUCCESS)
+
+ # If not downloaded, Fallback: Build from source
+ if(NOT AITER_PREBUILT_DOWNLOAD_SUCCESS)
+ message(STATUS " [AITER-PREBUILT] Building aiter from source.")
+ # delete the existing aiter/jit/build dir for a clean build
+ file(REMOVE_RECURSE "${__AITER_SOURCE_DIR}/aiter/jit/build")
+ # compile the libmha_fwd.so and libmha_bwd.so
+ set(ENV{AITER_LOG_MORE} 1)
+ # fp32 to bf16 cvt env still required for MI300X
+ set(ENV{CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT} ${CK_FUSED_ATTN_FLOAT_TO_BFLOAT16_DEFAULT})
+ execute_process(
+ COMMAND python3 ${__AITER_TEST_DIR}/compile.py
+ )
+ # libmha_fwd.so and libmha_bwd.so will be under 3rdparty/aiter/op_tests/cpp/mha
+ cache_local_aiter_build(${__AITER_TEST_DIR})
+ endif()
+ endif()
+ set(__AITER_MHA_PATH "${EXTRACT_DIR}")
+ message(STATUS "[AITER-PREBUILT] Using __AITER_MHA_PATH=${__AITER_MHA_PATH}")
endif()
set(ck_fused_attn_SOURCES)
@@ -112,4 +131,3 @@ foreach(ARCH IN LISTS V3_ASM_ARCHS)
DESTINATION ${CMAKE_INSTALL_PREFIX}/${AITER_MHA_INSTALL_PREFIX}/lib/aiter/${ARCH}/
PATTERN "codegen.py" EXCLUDE)
endforeach()
-
diff --git a/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
new file mode 100644
index 000000000..894e3a350
--- /dev/null
+++ b/transformer_engine/common/ck_fused_attn/aiter_prebuilt.cmake
@@ -0,0 +1,125 @@
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+
+cmake_minimum_required(VERSION 3.21)
+include(FetchContent)
+
+# Extract ROCm version
+set(ROCM_PATH "$ENV{ROCM_PATH}")
+if("${ROCM_PATH}" STREQUAL "")
+ set(ROCM_PATH "/opt/rocm")
+endif()
+file(READ "${ROCM_PATH}/.info/version" ROCM_VER_CONTENT)
+string(STRIP "${ROCM_VER_CONTENT}" ROCM_VER_CONTENT)
+string(REGEX MATCH "^[0-9]+\\.[0-9]+" ROCM_VER "${ROCM_VER_CONTENT}")
+
+# AITER commit
+file(REAL_PATH "${CMAKE_CURRENT_LIST_DIR}/../../../3rdparty/aiter" AITER_DIR)
+execute_process(
+ COMMAND sh -c "git config --global --add safe.directory ${AITER_DIR} 2>/dev/null || true && git -C ${AITER_DIR} rev-parse HEAD"
+ OUTPUT_VARIABLE AITER_SHA
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+# Cache key & local paths
+set(KEY "rocm-${ROCM_VER}_aiter-${AITER_SHA}")
+set(CACHE_ROOT "${CMAKE_CURRENT_LIST_DIR}/../../../build/aiter-prebuilts")
+set(EXTRACT_DIR "${CACHE_ROOT}/${KEY}")
+
+# Validate existing cache path
+function(is_aiter_cache_valid CACHE_VALID)
+ if(EXISTS "${EXTRACT_DIR}/libmha_fwd.so" AND EXISTS "${EXTRACT_DIR}/libmha_bwd.so")
+ set(${CACHE_VALID} TRUE PARENT_SCOPE)
+ message(STATUS "[AITER-PREBUILT] Found Cached build files at ${EXTRACT_DIR}")
+ return()
+ endif()
+
+ # Cache is invalid/outdated - clean it
+ file(REMOVE_RECURSE "${CACHE_ROOT}")
+ file(REMOVE_RECURSE "${CMAKE_BINARY_DIR}/_deps")
+endfunction()
+
+# Cache locally built libs
+function(cache_local_aiter_build SOURCE_DIR)
+ file(MAKE_DIRECTORY "${EXTRACT_DIR}")
+ message(STATUS "[AITER-PREBUILT] Caching locally built libs to ${EXTRACT_DIR}")
+ file(COPY "${SOURCE_DIR}/libmha_fwd.so" "${SOURCE_DIR}/libmha_bwd.so" DESTINATION "${EXTRACT_DIR}")
+endfunction()
+
+# Download prebuilt tgz file
+function(download_aiter_prebuilt DOWNLOAD_SUCCESS)
+ if(NOT DEFINED ENV{NVTE_AITER_PREBUILT_BASE_URL} OR "$ENV{NVTE_AITER_PREBUILT_BASE_URL}" STREQUAL "")
+ return()
+ endif()
+
+ set(FILE_URL "$ENV{NVTE_AITER_PREBUILT_BASE_URL}/${KEY}.tar.gz")
+ message(STATUS "[AITER-PREBUILT] NVTE_AITER_PREBUILT_BASE_URL is set - Attempting to download ${KEY}.tar.gz ...")
+
+ # Check if ${KEY}.tar.gz exists in the URL provided.
+ file(DOWNLOAD "${FILE_URL}.sha256" "/tmp/aiter_prebuilt_sha256.txt" STATUS sha_status LOG sha_log)
+ list(GET sha_status 0 sha_code)
+ if(NOT sha_code EQUAL 0)
+ message(WARNING " [AITER-PREBUILT] Prebuild file with Key=${KEY} not available in the NVTE_AITER_PREBUILT_BASE_URL provided.")
+ return()
+ endif()
+ file(READ "/tmp/aiter_prebuilt_sha256.txt" AITER_SHA_CONTENT)
+ string(STRIP "${AITER_SHA_CONTENT}" AITER_SHA_CONTENT)
+
+ file(MAKE_DIRECTORY "${CACHE_ROOT}")
+ FetchContent_Declare(
+ aiter_prebuilt
+ URL "${FILE_URL}"
+ URL_HASH SHA256=${AITER_SHA_CONTENT}
+ SOURCE_DIR "${EXTRACT_DIR}"
+ DOWNLOAD_EXTRACT_TIMESTAMP FALSE
+ )
+
+ # Download & extract prebuilt files
+ FetchContent_MakeAvailable(aiter_prebuilt)
+ message(STATUS "[AITER-PREBUILT] Successfully downloaded.")
+ set(${DOWNLOAD_SUCCESS} TRUE PARENT_SCOPE)
+endfunction()
+
+# Create prebuilt tgz file to upload
+function(create_upload_files)
+ # Locate .so files
+ if (NOT EXISTS "${EXTRACT_DIR}/libmha_fwd.so")
+ message(FATAL_ERROR "[AITER-PREBUILT] Missing libmha_fwd.so")
+ endif()
+ if (NOT EXISTS "${EXTRACT_DIR}/libmha_fwd.so")
+ message(FATAL_ERROR "[AITER-PREBUILT] Missing libmha_bwd.so")
+ endif()
+
+ # Output paths
+ set(OUTPUT_TGZ "/tmp/${KEY}.tar.gz")
+ set(OUTPUT_SHA "/tmp/${KEY}.tar.gz.sha256")
+
+ message(STATUS "[AITER-PREBUILT] Creating prebuilt files...")
+ # Create archive
+ file(ARCHIVE_CREATE
+ OUTPUT "${OUTPUT_TGZ}"
+ PATHS "${KEY}"
+ WORKING_DIRECTORY "${CACHE_ROOT}"
+ FORMAT "gnutar"
+ COMPRESSION "GZip")
+
+ # Compute SHA256
+ file(SHA256 "${OUTPUT_TGZ}" ARCHIVE_HASH)
+ file(WRITE "${OUTPUT_SHA}" "${ARCHIVE_HASH}")
+ message(STATUS "[AITER-PREBUILT] tgz and sha256 files generated successfully:")
+ message(STATUS " ${OUTPUT_TGZ}")
+ message(STATUS " ${OUTPUT_SHA}")
+endfunction()
+
+# ------------------------------------------------------
+# Script-mode entry point (to create upload files)
+# Usage: cmake -DACTION=upload -P /path/to/aiter_prebuilt.cmake
+# ------------------------------------------------------
+if (CMAKE_SCRIPT_MODE_FILE)
+ if (DEFINED ACTION AND ACTION STREQUAL "upload")
+ create_upload_files()
+ else()
+ message(FATAL_ERROR "[AITER-PREBUILT] Invalid ACTION=${ACTION}. Use upload.")
+ endif()
+endif()
\ No newline at end of file
diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
index b38249f5b..dff1a7626 100644
--- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
+++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp
@@ -556,7 +556,7 @@ void fused_attn_ck_fwd_impl(
if (env_p != nullptr && std::string(env_p) == "1")
nvte_log_ck_config = true;
}
- bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 0);
+ bool nvte_ck_uses_fwd_v3 = getenv("NVTE_CK_USES_FWD_V3", 1);
bool is_ragged = nvte_get_qkv_format(layout)==NVTE_QKV_Format::NVTE_THD;
@@ -1037,7 +1037,7 @@ void fused_attn_ck_bwd_impl(
// bwd v3 is optional by enabling the following envs
// default values follows the ck example setting
- bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 0);
+ bool nvte_ck_uses_bwd_v3 = getenv("NVTE_CK_USES_BWD_V3", 1);
bool nvte_ck_is_v3_atomic_fp32 = getenv("NVTE_CK_IS_V3_ATOMIC_FP32", 1);
int nvte_ck_how_v3_bf16_cvt = getenv("NVTE_CK_HOW_V3_BF16_CVT", 1);
if (nvte_log_ck_config) {
diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu
index 574e8ab7e..9de4cfad7 100644
--- a/transformer_engine/common/gemm/rocm_gemm.cu
+++ b/transformer_engine/common/gemm/rocm_gemm.cu
@@ -750,8 +750,8 @@ protected:
std::getline(is, type_b, csv_sep);
std::getline(is, type_d, csv_sep);
std::getline(is, bias_type, csv_sep);
- is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c;
std::getline(is, aux_type, csv_sep);
+ is >> cfg.lda >> c >> cfg.ldb >> c >> cfg.ldd >> c >> cfg.scaling_mode >> c;
std::getline(is, epi, csv_sep);
std::getline(is, comp, csv_sep);
std::getline(is, scale, csv_sep);
diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h
index 241a3b77b..fb525c9db 100644
--- a/transformer_engine/common/normalization/common.h
+++ b/transformer_engine/common/normalization/common.h
@@ -459,15 +459,20 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params)
scale_dim_Y_colwise, SCALE_DIM_Y,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
launch_params.z_tensor->dtype(), OType,
- cast_mxfp8_2D_kernel<<>>(
- reinterpret_cast(launch_params.params.z),
- nullptr,
- reinterpret_cast(launch_params.z_tensor->data.dptr),
- reinterpret_cast(launch_params.z_tensor->columnwise_data.dptr),
- scales_rowwise_ptr, scales_colwise_ptr,
- nullptr, nullptr, nullptr,
- rows, cols, scale_stride_rowwise, scale_stride_colwise);););
+ TRANSFORMER_ENGINE_SWITCH_CONDITION(
+ !(cols % (32 * sizeof(compute_t))), IS_ALIGNED,
+ cast_mxfp8_2D_kernel<<>>(
+ reinterpret_cast(launch_params.params.z),
+ nullptr,
+ reinterpret_cast(launch_params.z_tensor->data.dptr),
+ reinterpret_cast(launch_params.z_tensor->columnwise_data.dptr),
+ scales_rowwise_ptr, scales_colwise_ptr,
+ nullptr, nullptr, nullptr,
+ rows, cols, scale_stride_rowwise, scale_stride_colwise);
+ );
+ );
+ );
}
#endif
diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh
index 562ed47a5..4dfd45b3e 100644
--- a/transformer_engine/common/util/cast_gated_kernels.cuh
+++ b/transformer_engine/common/util/cast_gated_kernels.cuh
@@ -847,8 +847,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
-
#ifdef __HIP_PLATFORM_AMD__
+ TRANSFORMER_ENGINE_SWITCH_CONDITION(
+ !(cols % (32 * sizeof(IType))), IS_ALIGNED,
const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr;
const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr);
const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols;
@@ -918,11 +919,19 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
(const void*)cast_mxfp8_gated_kernel,
+ SCALE_DIM_Y, SCALE_DIM_X
+#ifdef __HIP_PLATFORM_AMD__
+ , IS_ALIGNED
+#endif
+ >,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
cast_mxfp8_gated_kernel
+ SCALE_DIM_Y, SCALE_DIM_X
+#ifdef __HIP_PLATFORM_AMD__
+ , IS_ALIGNED
+#endif
+ >
<<>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
@@ -932,6 +941,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
+#ifdef __HIP_PLATFORM_AMD__
+ ); // NOLINT(*)
+#endif
}
template
diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh
index 14818b5a1..5dfe85ed0 100644
--- a/transformer_engine/common/util/cast_kernels.cuh
+++ b/transformer_engine/common/util/cast_kernels.cuh
@@ -999,8 +999,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
#ifdef __HIP_PLATFORM_AMD__
- cast_mxfp8_2D_kernel<<>>(
+ TRANSFORMER_ENGINE_SWITCH_CONDITION(
+ !(cols % (32 * sizeof(IType))), IS_ALIGNED,
+ cast_mxfp8_2D_kernel<<>>(
reinterpret_cast(input.data.dptr),
(IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr,
reinterpret_cast(output->data.dptr),
@@ -1051,6 +1053,9 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
+#ifdef __HIP_PLATFORM_AMD__
+ ); // NOLINT(*)
+#endif
}
namespace detail {
diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh
index 91e0fc69e..76e05c2d9 100644
--- a/transformer_engine/common/util/dequantize_kernels.cuh
+++ b/transformer_engine/common/util/dequantize_kernels.cuh
@@ -310,8 +310,10 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
output->dtype(), OType,
#ifdef __HIP_PLATFORM_AMD__
- dequantize_mxfp8_kernel
- <<>>(reinterpret_cast(input_data.dptr), reinterpret_cast(output->data.dptr), scales_ptr,
+ TRANSFORMER_ENGINE_SWITCH_CONDITION(
+ !(cols % (32 * sizeof(OType))), IS_ALIGNED,
+ dequantize_mxfp8_kernel
+ <<>>(reinterpret_cast(input_data.dptr), reinterpret_cast(output->data.dptr), scales_ptr,
rows, cols, scales_stride);); // NOLINT(*)
#else // #ifdef __HIP_PLATFORM_AMD__
alignas(64) CUtensorMap tensor_map_input{};
@@ -329,6 +331,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
+#ifdef __HIP_PLATFORM_AMD__
+ ); // NOLINT(*)
+#endif
}
} // namespace dequantization
diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh
index db1b18af4..b8fee6862 100644
--- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh
+++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh
@@ -43,7 +43,7 @@ __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf
template
+ size_t SCALE_DIM_Y, size_t SCALE_DIM_X, bool IS_ALIGNED>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
cast_mxfp8_gated_kernel(const IType *grad_ptr,
const IType *input_act,
@@ -76,7 +76,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
- constexpr size_t VECTOR_WIDTH = 16 / sizeof(OType);
+ constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType);
const int thread_offset_Y = tid_Y;
const int thread_offset_X = tid_X;
@@ -136,16 +136,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Initiate bulk tensor copy
if constexpr (IS_DGATED) {
- copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
+ copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
}
// Act
- copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
+ copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
// Gate
- copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
+ copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
__syncthreads();
@@ -347,19 +347,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
__syncthreads();
if constexpr (USE_ROWWISE_SCALING) {
- bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
+ bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
if constexpr (IS_DGATED) {
- bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x,
+ bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
}
}
if constexpr (USE_COLWISE_SCALING) {
- bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
+ bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
if constexpr (IS_DGATED) {
- bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x,
+ bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x,
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
}
}
diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh
index e7424237e..00c408183 100644
--- a/transformer_engine/common/util/rocm_cast_kernels.cuh
+++ b/transformer_engine/common/util/rocm_cast_kernels.cuh
@@ -27,9 +27,6 @@ constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1;
constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X;
constexpr size_t MXFP8_THREADS_PER_CHUNK = 64;
-constexpr size_t MXFP8_BUFFERS_NUM = 2;
-constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1;
-static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM);
constexpr size_t ELEMS_PER_THREAD = 16;
constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported
@@ -45,11 +42,10 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64
constexpr size_t MXFP8_BUFF_STAGES_NUM =
MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16
constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32
-static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM);
template
+ size_t SCALE_DIM_X, bool IS_ALIGNED, bool IS_NORM = false>
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
cast_mxfp8_2D_kernel(const IType *input_ptr,
const IType *act_input_ptr,
@@ -83,7 +79,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
- constexpr size_t VECTOR_WIDTH = 16 / sizeof(OType);
+ constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType);
const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y;
const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X;
@@ -161,11 +157,11 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
const int chunk_it_offset_x = chunk_offset_X;
const size_t row_base = chunk_it_offset_y;
if constexpr (IS_DACT) {
- copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr,
+ copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr,
chunk_it_offset_x, chunk_it_offset_y, cols,
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols);
}
- copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x,
+ copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x,
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, rows, cols);
__syncthreads();
@@ -301,12 +297,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
__syncthreads();
if constexpr (USE_ROWWISE_SCALING) {
- bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x,
+ bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x,
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, rows, cols);
}
if constexpr (USE_COLWISE_SCALING) {
- bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x,
+ bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x,
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
MXFP8_SHMEM_DIM_X, rows, cols);
}
diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh
index f77a3ef2c..ae5cb4bbd 100644
--- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh
+++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh
@@ -42,7 +42,7 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X;
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16
static_assert(ITERATIONS >= 1);
-template
+template
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
dequantize_mxfp8_kernel(const IType *input_ptr,
OType *output_ptr,
@@ -59,7 +59,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
- constexpr size_t VECTOR_WIDTH = 16 / sizeof(OType);
+ constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(IType);
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
@@ -86,7 +86,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y;
const int chunk_it_offset_x = chunk_offset_X;
- copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x,
+ copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x,
chunk_it_offset_y, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, rows, cols);
__syncthreads();
@@ -127,7 +127,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
__syncthreads();
- bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x,
+ bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x,
chunk_it_offset_y, cols, SHMEM_DIM_Y,
SHMEM_DIM_X, rows, cols);
diff --git a/transformer_engine/common/util/rocm_vectorized_2d.cuh b/transformer_engine/common/util/rocm_vectorized_2d.cuh
index e1e9e9ec4..5877ddd87 100644
--- a/transformer_engine/common/util/rocm_vectorized_2d.cuh
+++ b/transformer_engine/common/util/rocm_vectorized_2d.cuh
@@ -10,13 +10,11 @@
namespace transformer_engine {
// These 2d copy functions replace TMA tensormap async copies for AMD GPUs.
-template
+template
__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col,
size_t g_start_row, size_t g_stride, size_t chunk_dim_y,
size_t chunk_dim_x, size_t total_rows,
size_t total_cols) {
-// TODO: Manage edge cases where "aligned = true" causes into issues
- constexpr bool ALIGNED_ACCESS = aligned;
size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC;
const size_t l_idx = threadIdx.x;
@@ -51,12 +49,11 @@ __device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t
}
}
-template
+template
__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col,
size_t g_start_row, size_t g_stride, size_t chunk_dim_y,
size_t chunk_dim_x, size_t total_rows,
size_t total_cols) {
- constexpr bool ALIGNED_ACCESS = aligned;
const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC;
const size_t l_idx = threadIdx.x;
diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py
index d1f701489..495284e11 100644
--- a/transformer_engine/jax/cpp_extensions/attention.py
+++ b/transformer_engine/jax/cpp_extensions/attention.py
@@ -357,7 +357,10 @@ def abstract(
softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, config.max_segments_per_seq)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
elif backend == NVTE_Fused_Attn_Backend.NVTE_CK:
- softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
+ if config.qkv_layout.is_thd():
+ softmax_shape = (*batch_shape, q_max_seqlen, attn_heads, 1)
+ else:
+ softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1)
softmax_dtype = dtypes.canonicalize_dtype(jnp.float32)
else:
raise ValueError(f"Unsupported {backend=}")
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 12e7927db..c76a40807 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -132,6 +132,7 @@ class FP8GlobalStateManager:
HIGH_PRECISION_INIT_VAL = False
IS_FIRST_FP8_MODULE = False
FP8_GRAPH_CAPTURING = False
+ SKIP_FP8_REDUCTION_FOR_FSDP2 = False
FP8_AUTOCAST_DEPTH = 0
global_amax_buffer = {}
global_amax_history_buffer = {}
@@ -494,7 +495,7 @@ def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None:
# Reduce only the non-FP8 weight modules here.
# FP8 weight modules are reduced at the end of the optimizer
# step after the weight amax is populated.
- if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
+ if not cls.SKIP_FP8_REDUCTION_FOR_FSDP2 and enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
# delayed scaling only function, for other recipes (current scaling with any granularity),
# this is noop for other recipes because cls.global_amax_buffer is empty list
cls.reduce_and_update_fp8_tensors(forward=True)
diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py
index 00f555fc2..87d50ccc9 100644
--- a/transformer_engine/pytorch/graph.py
+++ b/transformer_engine/pytorch/graph.py
@@ -464,6 +464,8 @@ def backward(ctx, *grads):
# Update FP8 scale factors if needed
if ctx.is_first_module:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+ if FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2:
+ FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True)
# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 815c808b2..781c20417 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -43,6 +43,8 @@
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
+if IS_HIP_EXTENSION:
+ from ..tensor.fsdp2_allgather_tensor import FSDPAGTensor
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import get_device_compute_capability, torch_get_autocast_gpu_dtype
@@ -596,7 +598,8 @@ def __init__(self) -> None:
self.fsdp_group = None
self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
self.activation_dtype: Optional[torch.dtype] = None,
- self.keep_fp8_weight_transpose_cache: bool = True
+ self.keep_fp8_weight_transpose_cache: bool = True,
+ self.use_fsdp2 = False
if not TEDebugState.debug_enabled:
TEDebugState.initialize()
@@ -905,6 +908,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
fp8_enabled = self.fp8 or self.fp8_calibration
self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration
+ if IS_HIP_EXTENSION and not FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 and hasattr(self, 'use_fsdp2') and self.use_fsdp2:
+ FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2 = True
+
if self.fp8_parameters or fp8_enabled:
if (
self.fp8_initialized
@@ -936,6 +942,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
self.fp8_initialized = True
self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
+ if self.fp8_meta["recipe"].mxfp8():
+ self.keep_fp8_weight_transpose_cache = True
@contextmanager
def prepare_forward(
@@ -1153,6 +1161,14 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
if IS_HIP_EXTENSION and not self.keep_fp8_weight_transpose_cache:
quantizer.columnwise_usage=False
param = quantizer(param)
+ if IS_HIP_EXTENSION and self.use_fsdp2 and not self.primary_weights_in_fp8 and fp8_meta_index is not None:
+ self.keep_fp8_weight_transpose_cache = False
+ param = FSDPAGTensor(
+ param,
+ module=self,
+ fp8_meta_index=fp8_meta_index,
+ keep_fp8_weight_transpose_cache=self.keep_fp8_weight_transpose_cache
+ )
# Redo parameter wrap in case we broke it above
# NOTE: Currently this can only be broken when primary weights are in Fp8 but
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 1164d97cc..6225d3119 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -8,6 +10,7 @@
import functools
import torch
+import os
import transformer_engine_torch as tex
@@ -49,6 +52,7 @@
prepare_for_saving,
restore_from_saved,
)
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["GroupedLinear"]
@@ -125,9 +129,20 @@ def forward(
recipe = FP8GlobalStateManager.get_fp8_recipe()
if hasattr(recipe, "fp8_gemm_fprop"):
fprop_gemm_use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator
- inputmats = tex.fused_multi_quantize(
- inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
- )
+
+ if IS_HIP_EXTENSION and bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ):
+ # The Triton path has no equivalent for tex.fused_multi_quantize()
+ inputmats = []
+ for i, x in enumerate(inputmats_no_fp8):
+ qi = input_quantizers[i]
+ dst = qi.make_empty(x.shape, dtype=x.dtype, device=x.device, requires_grad=False)
+ qi.update_quantized(x, dst, noop_flag=None)
+ inputmats.append(dst)
+ else:
+ inputmats = tex.fused_multi_quantize(
+ inputmats_no_fp8, None, input_quantizers, TE_DType[activation_dtype]
+ )
+
weights_fp8 = []
bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
# FP8 cast to workspace buffer
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 0403906c9..3e0844a7a 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -137,6 +137,7 @@ def forward(
symmetric_ar_type: str,
debug: Optional[bool] = False,
keep_fp8_weight_transpose_cache: bool = True,
+ use_fsdp2: bool = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
@@ -418,7 +419,7 @@ def forward(
ln_out.update_usage(rowwise_usage=False)
# Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache.
- if inp.requires_grad and keep_fp8_weight_transpose_cache:
+ if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2:
if isinstance(weightmat, QuantizedTensorBase):
weightmat.update_usage(columnwise_usage=True)
@@ -498,15 +499,17 @@ def forward(
ctx.requires_dgrad = inp_requires_grad
ctx.normalization = normalization
ctx.reduce_and_update_bwd_fp8_tensors = False
+ ctx.autocast_fp8_reduction_skipped = False
if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
+ ctx.autocast_fp8_reduction_skipped = FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2
ctx.wgrad_store = wgrad_store
ctx.debug = debug
ctx.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
-
+ ctx.use_fsdp2 = use_fsdp2
# ------------------------------------------------------
# Cached state for backward pass is ready...
# ------------------------------------------------------
@@ -975,6 +978,8 @@ def wgrad_gemm(
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+ if ctx.autocast_fp8_reduction_skipped:
+ FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True)
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
@@ -1026,6 +1031,7 @@ def wgrad_gemm(
None, # skip_fp8_weight_update
None, # symmetric_ar_type
None, # keep_fp8_weight_transpose_cache
+ None, # use_fsdp2
)
@@ -1171,6 +1177,7 @@ def __init__(
symmetric_ar_type: Optional[str] = None,
name: str = None,
keep_fp8_weight_transpose_cache: bool = True,
+ use_fsdp2: bool = False
) -> None:
super().__init__()
@@ -1192,7 +1199,8 @@ def __init__(
self.name = name
if TEDebugState.debug_enabled:
self._turn_off_unsupported_features_in_debug() # turn off userbuffers
- self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
+ self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
+ self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False
if tp_group is None:
self.tp_size = tp_size
@@ -1606,7 +1614,8 @@ def forward(
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
- self.keep_fp8_weight_transpose_cache
+ self.keep_fp8_weight_transpose_cache,
+ self.use_fsdp2
)
out = fwd_fn(*args)
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index b1128b73f..2edea459c 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -205,6 +205,7 @@ def forward(
symmetric_ar_type: str,
debug: Optional[bool] = False,
keep_fp8_weight_transpose_cache: bool = True,
+ use_fsdp2: bool = False,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring
@@ -527,7 +528,7 @@ def forward(
if is_grad_enabled:
# Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache.
- if inp.requires_grad and keep_fp8_weight_transpose_cache:
+ if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2:
if isinstance(fc1_weight_final, QuantizedTensorBase):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensorBase):
@@ -635,7 +636,9 @@ def forward(
)
ctx.normalization = normalization
ctx.reduce_and_update_bwd_fp8_tensors = False
+ ctx.autocast_fp8_reduction_skipped = False
ctx.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
+ ctx.use_fsdp2 = use_fsdp2
if ctx.fp8 and requires_grad(
inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias
):
@@ -643,6 +646,7 @@ def forward(
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
+ ctx.autocast_fp8_reduction_skipped = FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2
ctx.wgrad_store = wgrad_store
@@ -1314,6 +1318,8 @@ def fc1_wgrad_gemm(
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+ if ctx.autocast_fp8_reduction_skipped:
+ FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True)
# FIX THIS
# Scatter Fp8 tranposed-weight buffers
@@ -1377,6 +1383,7 @@ def fc1_wgrad_gemm(
None, # symmetric_ar_type
None, # debug
None, # keep_fp8_weight_transpose_cache
+ None, # use_fsdp2
)
@@ -1531,6 +1538,7 @@ def __init__(
delay_wgrad_compute: bool = False,
symmetric_ar_type: Optional[str] = None,
keep_fp8_weight_transpose_cache: bool = True,
+ use_fsdp2: bool = False
) -> None:
super().__init__()
@@ -1551,7 +1559,7 @@ def __init__(
self.zero_centered_gamma = zero_centered_gamma
self.symmetric_ar_type = symmetric_ar_type
self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
-
+ self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False
# GEMM-GELU fusion is currently only supported with split GEMM-AG overlap
self.gemm_gelu_fusion = (
bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0")))
@@ -1879,6 +1887,7 @@ def forward(
self.symmetric_ar_type,
debug,
self.keep_fp8_weight_transpose_cache,
+ self.use_fsdp2
)
out = fwd_fn(*args)
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 8591fc83e..a8361ef94 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -121,6 +121,7 @@ def forward(
symmetric_ar_type: str,
debug: Optional[bool] = False,
keep_fp8_weight_transpose_cache: bool = True,
+ use_fsdp2: bool = False,
) -> torch.Tensor:
# pylint: disable=missing-function-docstring
@@ -360,8 +361,8 @@ def forward(
assert not isinstance(inputmat, QuantizedTensorBase)
saved_inputmat = inputmat
- # Weight with column-wise usage is needed for dgrad GEMM.
- if inp.requires_grad and keep_fp8_weight_transpose_cache:
+ # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache.
+ if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2:
if isinstance(weightmat, QuantizedTensorBase):
weightmat.update_usage(columnwise_usage=True)
@@ -430,14 +431,17 @@ def forward(
ctx.requires_dgrad = inp.requires_grad
ctx.requires_wgrad = weight.requires_grad
ctx.reduce_and_update_bwd_fp8_tensors = False
+ ctx.autocast_fp8_reduction_skipped = False
ctx.owns_input = saved_inputmat is not inp
ctx.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
+ ctx.use_fsdp2 = use_fsdp2
if ctx.fp8 and requires_grad(inp, weight, bias):
_first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE
ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module()
if in_fp8_activation_recompute_phase():
FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module
+ ctx.autocast_fp8_reduction_skipped = FP8GlobalStateManager.SKIP_FP8_REDUCTION_FOR_FSDP2
ctx.wgrad_store = wgrad_store
# ------------------------------------------------------
@@ -858,6 +862,8 @@ def wgrad_gemm(
if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing():
nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors")
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)
+ if ctx.autocast_fp8_reduction_skipped:
+ FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=True)
nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors")
# Scatter fp8 weight buffers
@@ -900,6 +906,7 @@ def wgrad_gemm(
None, # symmetric_ar_type
None, # debug
None, # keep_fp8_weight_transpose_cache
+ None, # use_fsdp2
)
@@ -1024,6 +1031,7 @@ def __init__(
symmetric_ar_type: Optional[str] = None,
name: Optional[str] = None,
keep_fp8_weight_transpose_cache: bool = True,
+ use_fsdp2: bool = False
) -> None:
super().__init__()
@@ -1044,6 +1052,7 @@ def __init__(
self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad)
self.keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache if IS_HIP_EXTENSION else True
+ self.use_fsdp2 = use_fsdp2 if IS_HIP_EXTENSION else False
if device == "meta":
assert parameters_split is None, "Cannot split module parameters on 'meta' device."
@@ -1410,6 +1419,7 @@ def forward(
self.symmetric_ar_type,
debug,
self.keep_fp8_weight_transpose_cache,
+ self.use_fsdp2
)
out = linear_fn(*args)
if self.gemm_bias_unfused_add:
diff --git a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py
index 64ec0a28d..e5daa4c64 100644
--- a/transformer_engine/pytorch/optimizers/multi_tensor_apply.py
+++ b/transformer_engine/pytorch/optimizers/multi_tensor_apply.py
@@ -1,10 +1,12 @@
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Multi-tensor apply entry."""
from torch.distributed._tensor import DTensor
-
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
class MultiTensorApply: # pylint: disable=too-few-public-methods
"""Multi-tensor apply entry."""
@@ -16,7 +18,7 @@ def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
for i, ts in enumerate(tensor_lists):
for j, t in enumerate(ts):
if isinstance(t, DTensor):
- tensor_lists[i][j] = t._local_tensor
+ tensor_lists[i][j] = t._local_tensor.data if IS_HIP_EXTENSION else t._local_tensor
return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index b55ac577c..cce37dde8 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -247,7 +247,12 @@ def update_quantized(
src = src.contiguous()
# Launch cast kernel
- tex.quantize(src, self, dst, noop_flag)
+ if IS_HIP_EXTENSION:
+ use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) )
+ quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize
+ quantize_func(src, self, dst, noop_flag)
+ else:
+ tex.quantize(src, self, dst, noop_flag)
# Update FP8 dtype
dst._fp8_dtype = self.dtype
diff --git a/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py
new file mode 100644
index 000000000..00d2cd97b
--- /dev/null
+++ b/transformer_engine/pytorch/tensor/fsdp2_allgather_tensor.py
@@ -0,0 +1,199 @@
+#!/usr/bin/python3
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+# See LICENSE for license information.
+
+from typing import Any, Optional, Tuple
+import torch
+import torch.nn as nn
+from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer
+from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
+import torch.utils._pytree as pytree
+
+_ops_to_preserve_subclass = {
+ torch.ops.aten.empty_like.default,
+ torch.ops.aten.new_zeros.default,
+ torch.ops.aten.slice.Tensor,
+ torch.ops.aten.copy_.default,
+ torch.ops.aten.view.default,
+ torch.ops.aten.as_strided.default,
+ torch.ops.aten._to_copy.default,
+ torch.ops.aten._pin_memory.default,
+ torch.ops.aten.split.Tensor,
+ torch.ops.aten.clone.default,
+}
+
+
+# A wrapper subclass for stateful FSDP transport
+class FSDPAGTensor(torch.Tensor):
+
+ @staticmethod
+ def __new__(cls, elem: torch.Tensor, **kwargs):
+ # Build an "empty" wrapper with the same meta as elem
+ return torch.Tensor._make_wrapper_subclass(
+ cls,
+ elem.size(),
+ strides=elem.stride(),
+ storage_offset=elem.storage_offset(),
+ dtype=elem.dtype,
+ layout=elem.layout,
+ requires_grad=elem.requires_grad,
+ device=elem.device,
+ )
+
+ def __init__(
+ self,
+ tensor: torch.Tensor,
+ *,
+ module: nn.Module,
+ fp8_meta_index: str,
+ keep_fp8_weight_transpose_cache: bool,
+ ):
+ #The underlying tensor
+ self._data = tensor
+ # Where quantizers are present
+ self._module = module
+ # Which quantizer to use within module.quantizers["scaling_fwd"][idx]
+ self._fp8_meta_index = fp8_meta_index
+ # Disable or enable transpose cache for fp8 weights
+ self._keep_fp8_weight_transpose_cache = keep_fp8_weight_transpose_cache
+
+ @property
+ def data(self) -> torch.Tensor:
+ return self._data.detach()
+
+ def __repr__(self):
+ return (
+ f"FSDPAGTensor("
+ f"elem={self._data}, "
+ f"module={self._module.__class__.__name__}, "
+ f"fp8_meta_index={self._fp8_meta_index})"
+ )
+
+ def __tensor_flatten__(self):
+ """
+ Makes some ops (view/as_strided, etc.) and serialization friendlier for wrapper subclasses.
+ Return (names_of_inner_tensors, flatten_spec_metadata).
+ """
+ # We only carry the one inner tensor.
+ # We store (module, fp8_meta_index, keep_fp8_weight_transpose_cache) as metadata to reconstruct.
+ return ["_data"], (self._module, self._fp8_meta_index, self._keep_fp8_weight_transpose_cache)
+
+
+ @staticmethod
+ def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride):
+ module, fp8_meta_index, keep_fp8_weight_transpose_cache = flatten_spec
+ return FSDPAGTensor(
+ inner_tensors["_data"],
+ module=module,
+ fp8_meta_index=fp8_meta_index,
+ keep_fp8_weight_transpose_cache=keep_fp8_weight_transpose_cache
+ )
+
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args, kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ # detach
+ if func is torch.ops.aten.detach.default:
+ t = args[0]
+ assert isinstance(t, cls), f"Unexpected detach input type: {type(t)}"
+ detached = t._data.detach()
+ return cls(detached, module=t._module, fp8_meta_index=t._fp8_meta_index, keep_fp8_weight_transpose_cache=t._keep_fp8_weight_transpose_cache)
+
+ # Unwrap only our subclass; capture shared metadata for rewrapping
+ meta: Optional[tuple[nn.Module, str, bool]] = None
+
+ def unwrap(x):
+ nonlocal meta
+ if isinstance(x, cls):
+ if meta is None:
+ meta = (x._module, x._fp8_meta_index, x._keep_fp8_weight_transpose_cache)
+ return x._data
+ return x
+
+ unwrapped_args, unwrapped_kwargs = pytree.tree_map_only(cls, unwrap, (args, kwargs))
+
+ # Run the actual op on internal tensors
+ out = func(*unwrapped_args, **unwrapped_kwargs)
+
+ # Rewrap outputs only for ops that need to preserve subclass identity
+ if func not in _ops_to_preserve_subclass or meta is None:
+ return out
+
+ def rewrap(x):
+ if isinstance(x, torch.Tensor):
+ mod, idx, keep_transpose = meta
+ return cls(x, module=mod, fp8_meta_index=idx, keep_fp8_weight_transpose_cache=keep_transpose)
+ return x
+
+ out = pytree.tree_map_only(torch.Tensor, rewrap, out)
+ return out
+
+ # Must return (list_of_tensors_to_all_gather, user_metadata)
+ def fsdp_pre_all_gather(self, mesh):
+ # If metadata isn't initialized yet, we can't access the quantizers
+ if not self._module.fp8:
+ module_class_name = self._module.__class__.__name__
+ if "LayerNormMLP" in module_class_name:
+ num_gemms = 2
+ else: # Linear, LayerNormLinear, etc.
+ num_gemms = 1
+
+ self._module.init_fp8_metadata(num_gemms=num_gemms)
+ if not self._module.fp8:
+ return (self._data,), (self._data.requires_grad,)
+ # Use the actual data
+ base = self._data
+ # Access the quantizer using fp8_meta_index
+ quantizer = self._module.quantizers["scaling_fwd"][self._fp8_meta_index]
+ if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache:
+ quantizer.set_usage(columnwise=False)
+ if isinstance(quantizer, Float8CurrentScalingQuantizer):
+ quantizer.with_amax_reduction = True
+ sharded_fp8_tensor = quantizer(base)
+ if isinstance(quantizer, MXFP8Quantizer):
+ rowwise_data = sharded_fp8_tensor._rowwise_data if quantizer.rowwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device)
+ rowwise_scale_inv = sharded_fp8_tensor._rowwise_scale_inv if quantizer.rowwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device)
+ columnwise_data = sharded_fp8_tensor._columnwise_data if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device)
+ columnwise_scale_inv = sharded_fp8_tensor._columnwise_scale_inv if quantizer.columnwise_usage else torch.empty(0, dtype=torch.uint8, device=base.device)
+ return (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv, ), (base.requires_grad,)
+ return (sharded_fp8_tensor._data,), (base.requires_grad,)
+
+ def fsdp_post_all_gather(
+ self,
+ all_gather_outputs: Tuple[torch.Tensor, ...],
+ metadata: Any,
+ param_dtype: torch.dtype,
+ *,
+ out: Optional[torch.Tensor] = None,
+ ):
+ (requires_grad, ) = metadata
+ if not self._module.fp8:
+ (data,) = all_gather_outputs
+ return data, all_gather_outputs
+ # Retrieve the same quantizer you used in pre_all_gather
+ quantizer = self._module.quantizers["scaling_fwd"][self._fp8_meta_index]
+ shape = None
+ if not isinstance(quantizer, MXFP8Quantizer) and not self._keep_fp8_weight_transpose_cache:
+ quantizer.set_usage(columnwise=False)
+ if isinstance(quantizer, MXFP8Quantizer):
+ (rowwise_data, rowwise_scale_inv, columnwise_data, columnwise_scale_inv,) = all_gather_outputs
+ shape = rowwise_data.shape
+ else:
+ (data,) = all_gather_outputs
+ shape = data.shape
+
+ # Construct a new low precision tensor subclass that will wrap the gathered data
+ if out is None:
+ out = quantizer.make_empty(shape = shape, dtype=param_dtype, requires_grad=requires_grad)
+
+ if isinstance(quantizer, MXFP8Quantizer):
+ out._rowwise_data = rowwise_data
+ out._rowwise_scale_inv = rowwise_scale_inv
+ out._columnwise_data = None if columnwise_data.numel() == 0 else columnwise_data
+ out._columnwise_scale_inv = None if columnwise_scale_inv.numel() == 0 else columnwise_scale_inv
+ else:
+ out._scale_inv = 1 / quantizer.scale
+ out._data = data
+ return out, all_gather_outputs
diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py
index 67875ed71..4c7033132 100644
--- a/transformer_engine/pytorch/triton_kernels/cast.py
+++ b/transformer_engine/pytorch/triton_kernels/cast.py
@@ -96,6 +96,10 @@ def te_quantize_triton(
cast_out = out._data
trans_out = out._transpose
scale_inv_out = out._scale_inv
+
+ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer
+ is_current_scaling = isinstance(quantizer, Float8CurrentScalingQuantizer)
+
te_cast_transpose_noop_triton(
input_tensor,
noop_flag,
@@ -104,7 +108,10 @@ def te_quantize_triton(
trans_out=trans_out,
amax_out=amax_out,
scale_inv_out=scale_inv_out,
- otype=otype
+ otype=otype,
+ current_scaling=is_current_scaling,
+ eps = getattr(quantizer, "amax_epsilon", 0.0),
+ force_pow_2_scales = getattr(quantizer, "force_pow_2_scales", False),
)
else:
diff --git a/transformer_engine/pytorch/triton_kernels/cast_transpose.py b/transformer_engine/pytorch/triton_kernels/cast_transpose.py
index bf380f203..8d35c3681 100644
--- a/transformer_engine/pytorch/triton_kernels/cast_transpose.py
+++ b/transformer_engine/pytorch/triton_kernels/cast_transpose.py
@@ -16,6 +16,80 @@
#### cast_transpose
##########################################
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
+ ],
+ key=['M', 'N'],
+)
+@triton.jit
+def _amax_reduce_triton(
+ A,
+ stride_am, stride_an,
+ M, N,
+ amax_ptr, # float32[1], initialize to -inf on host
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ pid = tl.program_id(0)
+
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // group_size
+
+ rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+
+ A_ptrs = A + rm[:, None] * stride_am + rn[None, :] * stride_an
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+
+ a = tl.load(A_ptrs, mask=mask, other=0).to(tl.float32)
+ tile_amax = tl.max(tl.abs(a))
+ # accumulate tile-wise max into global amax
+ tl.atomic_max(amax_ptr, tile_amax, sem='relaxed')
+
+
+@triton.jit
+def _compute_scale_from_amax_triton(
+ amax_ptr,
+ scale_ptr,
+ inv_ptr,
+ max_fp8,
+ epsilon,
+ value_for_inf,
+ FORCE_POW_2_SCALES: tl.constexpr,
+):
+ # This implementation mimics transformer_engine::compute_scale_from_amax()
+
+ a = tl.load(amax_ptr).to(tl.float32)
+
+ # amax < epsilon -> epsilon (NaNs pass through)
+ a = tl.where(a < epsilon, epsilon, a)
+
+ # bad amax (NaN, inf, 0.0) -> scale = 1.0
+ bad = (a != a) | (tl.abs(a) == float('inf')) | (a == 0.0)
+
+ if bad:
+ s = tl.full((), 1.0, tl.float32)
+ else:
+ s = max_fp8 / a
+ # inf -> scale = value_for_inf
+ s = tl.where(tl.abs(a) == float('inf'), value_for_inf, s)
+ if FORCE_POW_2_SCALES:
+ s = tl.math.exp2(tl.floor(tl.log2(s)))
+
+ tl.store(scale_ptr, s)
+ tl.store(inv_ptr, 1.0 / s)
+
+
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
@@ -69,6 +143,52 @@ def _cast_transpose_triton(A, noop_ptr, C, T, stride_am, stride_an, stride_bn, s
scale_inv_out = tl.fdiv(1.0, scale)
tl.store(scale_inv_ptr, scale_inv_out)
+
+@triton.autotune(
+ configs=[
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 1}, num_warps=4),
+ triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'GROUP_M': 8}, num_warps=4),
+ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=8),
+ ],
+ key=['M', 'N']
+)
+@triton.jit
+def _cast_transpose_triton_current_scaling(A, C, T, stride_am, stride_an, stride_bn, stride_bm, M, N, scale_ptr, max_fp8: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_M: tl.constexpr):
+ # Similar (but slightly optimized) version of the delayed scaling kernel
+ # implemented in _cast_transpose_triton().
+ pid = tl.program_id(0)
+ scale = tl.load(scale_ptr)
+
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // group_size
+
+ rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+ A = A + rm[:, None] * stride_am + rn[None, :] * stride_an
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ a = tl.load(A, mask=mask)
+ a = a.to(tl.float32)
+
+ scaled_a = a * scale
+ scaled_a = tl.clamp(scaled_a, -max_fp8, max_fp8)
+ fp8_a = scaled_a.to(C.type.element_ty)
+ C = C + rm[:, None] * stride_am + rn[None, :] * stride_an
+ tl.store(C, fp8_a, mask=mask)
+
+ # rematerialize to save registers
+ rm = pid_m.to(tl.int64) * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n.to(tl.int64) * BLOCK_N + tl.arange(0, BLOCK_N)
+ T = T + rm[:, None] * stride_bm + rn[None, :] * stride_bn
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ tl.store(T, fp8_a, mask=mask)
+
+
FP32_EXPONENT_BIAS = tl.constexpr(127)
FP32_MANTISSA_BITS = tl.constexpr(23)
@triton.jit
@@ -232,7 +352,7 @@ def _dequantize_mxfp8_triton(
# Reshapes input of any given shape to 2D for processing,
# then uses the Triton kernel to perform casting and transposition efficiently.
-def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans_out, amax_out, scale_inv_out, otype):
+def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans_out, amax_out, scale_inv_out, otype, current_scaling, eps, force_pow_2_scales):
row_length = input.shape[-1] if len(input.shape) > 0 else 1
num_rows = input.numel() // row_length
@@ -254,7 +374,35 @@ def te_cast_transpose_noop_triton(input, noop_flag, input_scale, cast_out, trans
use_noop = False
grid = lambda META: (triton.cdiv(num_rows, META['BLOCK_M']) * triton.cdiv(row_length, META['BLOCK_N']),)
- _cast_transpose_triton[grid](input_2d_view, noop_flag, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, amax_out, scale_inv_out, get_fp8_max(otype), use_noop)
+
+ if current_scaling:
+ # Current scaling:
+ # 1) global amax reduction
+ # 2) compute current scale
+ # 3) cast+transpose with that current scale (otherwise same as delayed)
+
+ # global amax
+ amax_out.fill_(-float("inf"))
+ _amax_reduce_triton[grid](
+ input_2d_view,
+ input_stride_M, input_stride_N,
+ num_rows, row_length,
+ amax_out,
+ )
+
+ # Compute scale
+ fp8_max = get_fp8_max(otype)
+
+ _compute_scale_from_amax_triton[(1,)](
+ amax_out, input_scale, scale_inv_out,
+ fp8_max, eps, torch.finfo(torch.float32).max,
+ FORCE_POW_2_SCALES=force_pow_2_scales,
+ )
+
+ _cast_transpose_triton_current_scaling[grid](input_2d_view, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, get_fp8_max(otype))
+ else:
+ # Delayed scaling
+ _cast_transpose_triton[grid](input_2d_view, noop_flag, triton.reinterpret(cast_out_2d_view, tl_dtype), triton.reinterpret(trans_out_2d_view, tl_dtype), input_stride_M, input_stride_N, trans_out_stride_M, trans_out_stride_N, num_rows, row_length, input_scale, amax_out, scale_inv_out, get_fp8_max(otype), use_noop)
def te_cast_transpose_mxfp8_triton(input, out, noop_flag=None):
row_length = input.shape[-1] if len(input.shape) > 0 else 1